From 11ac8e6cb9f306c4a92dc93ac7cce25cfcebb8ce Mon Sep 17 00:00:00 2001 From: twiby Date: Wed, 21 Jun 2023 13:59:23 +0200 Subject: [PATCH] feat(trivium): add bench for casting and packing --- apps/trivium/Cargo.toml | 2 +- apps/trivium/benches/kreyvium_shortint.rs | 47 +++++++++++--- apps/trivium/benches/trivium_shortint.rs | 45 ++++++++++--- tfhe/benches/shortint/bench.rs | 9 +++ tfhe/benches/shortint/casting.rs | 78 +++++++++++++++++++++++ 5 files changed, 162 insertions(+), 19 deletions(-) create mode 100644 tfhe/benches/shortint/casting.rs diff --git a/apps/trivium/Cargo.toml b/apps/trivium/Cargo.toml index 80cd6876e..704f361d1 100644 --- a/apps/trivium/Cargo.toml +++ b/apps/trivium/Cargo.toml @@ -21,4 +21,4 @@ criterion = { version = "0.4", features = [ "html_reports" ]} [[bench]] name = "trivium" -harness = false \ No newline at end of file +harness = false diff --git a/apps/trivium/benches/kreyvium_shortint.rs b/apps/trivium/benches/kreyvium_shortint.rs index 4a3705dbe..b14694b69 100644 --- a/apps/trivium/benches/kreyvium_shortint.rs +++ b/apps/trivium/benches/kreyvium_shortint.rs @@ -1,6 +1,6 @@ use tfhe::prelude::*; use tfhe::shortint::prelude::*; -use tfhe::shortint::CastingKey; +use tfhe::shortint::KeySwitchingKey; use tfhe::{generate_keys, ConfigBuilder, FheUint64}; use tfhe_trivium::{KreyviumStreamShortint, TransCiphering}; @@ -12,8 +12,16 @@ pub fn kreyvium_shortint_warmup(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB000000000000".to_string(); let mut key = [0; 128]; @@ -40,8 +48,13 @@ pub fn kreyvium_shortint_warmup(c: &mut Criterion) { c.bench_function("kreyvium 1_1 warmup", |b| { b.iter(|| { let cipher_key = key.map(|x| client_key.encrypt(x)); - let _kreyvium = - KreyviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let _kreyvium = KreyviumStreamShortint::new( + cipher_key, + iv, + server_key.clone(), + ksk.clone(), + hl_server_key.clone(), + ); }) }); } @@ -51,8 +64,16 @@ pub fn kreyvium_shortint_gen(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB000000000000".to_string(); let mut key = [0; 128]; @@ -78,8 +99,7 @@ pub fn kreyvium_shortint_gen(c: &mut Criterion) { let cipher_key = key.map(|x| client_key.encrypt(x)); - let mut kreyvium = - KreyviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let mut kreyvium = KreyviumStreamShortint::new(cipher_key, iv, server_key, ksk, hl_server_key); c.bench_function("kreyvium 1_1 generate 64 bits", |b| { b.iter(|| kreyvium.next_64()) @@ -91,8 +111,16 @@ pub fn kreyvium_shortint_trans(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB000000000000".to_string(); let mut key = [0; 128]; @@ -119,8 +147,7 @@ pub fn kreyvium_shortint_trans(c: &mut Criterion) { let cipher_key = key.map(|x| client_key.encrypt(x)); let ciphered_message = FheUint64::try_encrypt(0u64, &hl_client_key).unwrap(); - let mut kreyvium = - KreyviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let mut kreyvium = KreyviumStreamShortint::new(cipher_key, iv, server_key, ksk, hl_server_key); c.bench_function("kreyvium 1_1 transencrypt 64 bits", |b| { b.iter(|| kreyvium.trans_encrypt_64(ciphered_message.clone())) diff --git a/apps/trivium/benches/trivium_shortint.rs b/apps/trivium/benches/trivium_shortint.rs index d79de75b4..5295fd4f3 100644 --- a/apps/trivium/benches/trivium_shortint.rs +++ b/apps/trivium/benches/trivium_shortint.rs @@ -1,6 +1,6 @@ use tfhe::prelude::*; use tfhe::shortint::prelude::*; -use tfhe::shortint::CastingKey; +use tfhe::shortint::KeySwitchingKey; use tfhe::{generate_keys, ConfigBuilder, FheUint64}; use tfhe_trivium::{TransCiphering, TriviumStreamShortint}; @@ -12,8 +12,16 @@ pub fn trivium_shortint_warmup(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB".to_string(); let mut key = [0; 80]; @@ -40,8 +48,13 @@ pub fn trivium_shortint_warmup(c: &mut Criterion) { c.bench_function("trivium 1_1 warmup", |b| { b.iter(|| { let cipher_key = key.map(|x| client_key.encrypt(x)); - let _trivium = - TriviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let _trivium = TriviumStreamShortint::new( + cipher_key, + iv, + server_key.clone(), + ksk.clone(), + hl_server_key.clone(), + ); }) }); } @@ -51,8 +64,16 @@ pub fn trivium_shortint_gen(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB".to_string(); let mut key = [0; 80]; @@ -78,7 +99,7 @@ pub fn trivium_shortint_gen(c: &mut Criterion) { let cipher_key = key.map(|x| client_key.encrypt(x)); - let mut trivium = TriviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let mut trivium = TriviumStreamShortint::new(cipher_key, iv, server_key, ksk, hl_server_key); c.bench_function("trivium 1_1 generate 64 bits", |b| { b.iter(|| trivium.next_64()) @@ -90,8 +111,16 @@ pub fn trivium_shortint_trans(c: &mut Criterion) { .enable_default_integers() .build(); let (hl_client_key, hl_server_key) = generate_keys(config); + let underlying_ck: tfhe::shortint::ClientKey = (*hl_client_key.as_ref()).clone().into(); + let underlying_sk: tfhe::shortint::ServerKey = (*hl_server_key.as_ref()).clone().into(); + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); - let ksk = CastingKey::new((&client_key, &server_key), (&hl_client_key, &hl_server_key)); + + let ksk = KeySwitchingKey::new( + (&client_key, &server_key), + (&underlying_ck, &underlying_sk), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); let key_string = "0053A6F94C9FF24598EB".to_string(); let mut key = [0; 80]; @@ -118,7 +147,7 @@ pub fn trivium_shortint_trans(c: &mut Criterion) { let cipher_key = key.map(|x| client_key.encrypt(x)); let ciphered_message = FheUint64::try_encrypt(0u64, &hl_client_key).unwrap(); - let mut trivium = TriviumStreamShortint::new(cipher_key, iv, &server_key, &ksk, &hl_server_key); + let mut trivium = TriviumStreamShortint::new(cipher_key, iv, server_key, ksk, hl_server_key); c.bench_function("trivium 1_1 transencrypt 64 bits", |b| { b.iter(|| trivium.trans_encrypt_64(ciphered_message.clone())) diff --git a/tfhe/benches/shortint/bench.rs b/tfhe/benches/shortint/bench.rs index bbac17343..0f5adb449 100644 --- a/tfhe/benches/shortint/bench.rs +++ b/tfhe/benches/shortint/bench.rs @@ -703,9 +703,18 @@ criterion_group!( scalar_not_equal ); +mod casting; +criterion_group!( + casting, + casting::pack_cast_64, + casting::pack_cast, + casting::cast +); + criterion_main!( // arithmetic_operation, // arithmetic_scalar_operation, + casting, default_ops, default_scalar_ops, ); diff --git a/tfhe/benches/shortint/casting.rs b/tfhe/benches/shortint/casting.rs new file mode 100644 index 000000000..f15e7a1b6 --- /dev/null +++ b/tfhe/benches/shortint/casting.rs @@ -0,0 +1,78 @@ +use tfhe::shortint::prelude::*; + +use rayon::prelude::*; + +use criterion::Criterion; + +pub fn pack_cast_64(c: &mut Criterion) { + let (client_key_1, server_key_1): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + let (client_key_2, server_key_2): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let ksk = KeySwitchingKey::new( + (&client_key_1, &server_key_1), + (&client_key_2, &server_key_2), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); + + let vec_ct = vec![client_key_1.encrypt(1); 64]; + + c.bench_function("pack_cast_64", |b| { + b.iter(|| { + let _ = (0..32) + .into_par_iter() + .map(|i| { + let byte_idx = 7 - i / 4; + let pair_idx = i % 4; + + let b0 = &vec_ct[8 * byte_idx + 2 * pair_idx]; + let b1 = &vec_ct[8 * byte_idx + 2 * pair_idx + 1]; + + ksk.cast( + &server_key_1.unchecked_add(b0, &server_key_1.unchecked_scalar_mul(b1, 2)), + ) + }) + .collect::>(); + }); + }); +} + +pub fn pack_cast(c: &mut Criterion) { + let (client_key_1, server_key_1): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + let (client_key_2, server_key_2): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let ksk = KeySwitchingKey::new( + (&client_key_1, &server_key_1), + (&client_key_2, &server_key_2), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); + + let ct_1 = client_key_1.encrypt(1); + let ct_2 = client_key_1.encrypt(1); + + c.bench_function("pack_cast", |b| { + b.iter(|| { + let _ = ksk.cast( + &server_key_1.unchecked_add(&ct_1, &server_key_1.unchecked_scalar_mul(&ct_2, 2)), + ); + }); + }); +} + +pub fn cast(c: &mut Criterion) { + let (client_key_1, server_key_1): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + let (client_key_2, server_key_2): (ClientKey, ServerKey) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let ksk = KeySwitchingKey::new( + (&client_key_1, &server_key_1), + (&client_key_2, &server_key_2), + PARAM_KEYSWITCH_1_1_TO_2_2, + ); + + let ct = client_key_1.encrypt(1); + + c.bench_function("cast", |b| { + b.iter(|| { + let _ = ksk.cast(&ct); + }); + }); +}