From 6c89f8ae7b2a879929ccb04b83079339d32537a8 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 12 Jan 2024 19:36:26 +0530 Subject: [PATCH] lib psi works --- Cargo.lock | 1 + Cargo.toml | 1 + src/lib.rs | 341 +++++++++++++++++++++++++++++++++++++++++++++++++--- src/main.rs | 4 +- 4 files changed, 327 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8c038a6..a02f144 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,6 +133,7 @@ dependencies = [ "bfv", "itertools", "rand", + "traits", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 2d55af9..cd0481c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,6 @@ edition = "2021" [dependencies] bfv = {path = "./../bfv/bfv"} +traits = {path = "./../bfv/traits"} rand = "0.8.5" itertools = "0.10.5" diff --git a/src/lib.rs b/src/lib.rs index 9e4b008..c62ead4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ use bfv::{ - BfvParameters, Ciphertext, CollectivePublicKeyGenerator, CollectiveRlkGenerator, - CollectiveRlkGeneratorState, Poly, SecretKey, + BfvParameters, Ciphertext, CollectiveDecryption, CollectivePublicKeyGenerator, + CollectiveRlkGenerator, CollectiveRlkGeneratorState, Encoding, EvaluationKey, Evaluator, + Plaintext, Poly, SecretKey, }; use rand::thread_rng; +use traits::{TryDecodingWithParameters, TryEncodingWithParameters}; static CRS_PK: [u8; 32] = [0u8; 32]; static CRS_RLK: [u8; 32] = [0u8; 32]; @@ -18,21 +20,53 @@ struct PrivateOutputAPostState0 { s_pk_a: SecretKey, s_rlk_a: SecretKey, } + +struct PublicOutputAPostState0 { + share_pk_a: Poly, + share_rlk_a_round1: (Vec, Vec), +} + struct MessageAToBPostState0 { share_pk_a: Poly, - share_rlk_a: (Vec, Vec), + share_rlk_a_round1: (Vec, Vec), } -struct PrivateOutputBPostState0 { +struct PrivateOutputBPostState1 { s_pk_b: SecretKey, } -struct MessageBToAPostState1 {} -struct MessageAToBPostState2 {} +struct PublicOutputBPostState1 { + ciphertext_b: Ciphertext, + share_rlk_b_round2: (Vec, Vec), + rlk_agg_round1_h1s: Vec, +} +struct MessageBToAPostState1 { + share_pk_b: Poly, + share_rlk_b_round1: (Vec, Vec), + share_rlk_b_round2: (Vec, Vec), + ciphertext_b: Ciphertext, +} -struct MessageBToAPostState3 {} +struct PublicOutputAPostState2 { + decryption_share_a: Poly, + ciphertext_res: Ciphertext, +} -fn state0() -> (PrivateOutputAPostState0, MessageAToBPostState0) { +struct MessageAToBPostState2 { + decryption_share_a: Poly, + ciphertext_a: Ciphertext, + share_rlk_a_round2: (Vec, Vec), +} + +struct MessageBToAPostState3 { + decryption_share_b: Poly, +} + +fn state0() -> ( + PrivateOutputAPostState0, + PublicOutputAPostState0, + MessageAToBPostState0, +) { let params = params(); let mut rng = thread_rng(); let s_pk_a = SecretKey::random_with_params(¶ms, &mut rng); @@ -41,23 +75,34 @@ fn state0() -> (PrivateOutputAPostState0, MessageAToBPostState0) { let share_pk_a = CollectivePublicKeyGenerator::generate_share(¶ms, &s_pk_a, CRS_PK, &mut rng); - let share_rlk_a = + let share_rlk_a_round1 = CollectiveRlkGenerator::generate_share_1(¶ms, &s_pk_a, &s_rlk_a, CRS_RLK, 0, &mut rng); let message_a_to_b = MessageAToBPostState0 { - share_pk_a, - share_rlk_a, + share_pk_a: share_pk_a.clone(), + share_rlk_a_round1: share_rlk_a_round1.clone(), }; let private_state_a = PrivateOutputAPostState0 { s_pk_a, s_rlk_a: s_rlk_a.0.clone(), }; + let public_output_a = PublicOutputAPostState0 { + share_pk_a, + share_rlk_a_round1, + }; - (private_state_a, message_a_to_b) + (private_state_a, public_output_a, message_a_to_b) } -fn state1(input: MessageAToBPostState0) -> () { +fn state1( + message_from_a: MessageAToBPostState0, + bit_vector: &[u32], +) -> ( + PrivateOutputBPostState1, + PublicOutputBPostState1, + MessageBToAPostState1, +) { let params = params(); let mut rng = thread_rng(); let s_pk_b = SecretKey::random_with_params(¶ms, &mut rng); @@ -66,18 +111,278 @@ fn state1(input: MessageAToBPostState0) -> () { let share_pk_b = CollectivePublicKeyGenerator::generate_share(¶ms, &s_pk_b, CRS_PK, &mut rng); - let share_rlk_b = + let share_rlk_b_round1 = CollectiveRlkGenerator::generate_share_1(¶ms, &s_pk_b, &s_rlk_b, CRS_RLK, 0, &mut rng); - // collective public key - let collective_pk_shares = vec![share_pk_b.clone(), input.share_pk_a]; + // rlk key part 1 + let h0s = vec![ + message_from_a.share_rlk_a_round1.0, + share_rlk_b_round1.0.clone(), + ]; + let h1s = vec![ + message_from_a.share_rlk_a_round1.1, + share_rlk_b_round1.1.clone(), + ]; + let rlk_agg_1 = CollectiveRlkGenerator::aggregate_shares_1(¶ms, &h0s, &h1s, 0); + + // B already has access to aggregate shares for rlk round 1 and can proceed with the second round of the protocol + let share_rlk_b_round2 = CollectiveRlkGenerator::generate_share_2( + ¶ms, + &s_pk_b, + &rlk_agg_1.0, + &rlk_agg_1.1, + &s_rlk_b, + 0, + &mut rng, + ); + + // generate collective public key and encryt b's input + let collective_pk_shares = vec![share_pk_b.clone(), message_from_a.share_pk_a]; let collecitve_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise( ¶ms, &collective_pk_shares, CRS_PK, ); + let pt = Plaintext::try_encoding_with_parameters(bit_vector, ¶ms, Encoding::default()); + let ciphertext_b = collecitve_pk.encrypt(¶ms, &pt, &mut rng); + + let message_to_a = MessageBToAPostState1 { + share_pk_b, + share_rlk_b_round1, + share_rlk_b_round2: share_rlk_b_round2.clone(), + ciphertext_b: ciphertext_b.clone(), + }; + + let private_output_b = PrivateOutputBPostState1 { s_pk_b }; + let public_output_b = PublicOutputBPostState1 { + ciphertext_b, + share_rlk_b_round2, + rlk_agg_round1_h1s: rlk_agg_1.1, + }; + + (private_output_b, public_output_b, message_to_a) } -fn state2() {} +fn state2( + private_output_a_state0: PrivateOutputAPostState0, + public_output_a_state0: PublicOutputAPostState0, + message_from_b: MessageBToAPostState1, + bit_vector: &[u32], +) -> (PublicOutputAPostState2, MessageAToBPostState2) { + let params = params(); + let mut rng = thread_rng(); -fn state3() {} + // aggrgegate shares of rlk round 1 + let h0s = vec![ + public_output_a_state0.share_rlk_a_round1.0, + message_from_b.share_rlk_b_round1.0, + ]; + let h1s = vec![ + public_output_a_state0.share_rlk_a_round1.1, + message_from_b.share_rlk_b_round1.1, + ]; + let rlk_agg_1 = CollectiveRlkGenerator::aggregate_shares_1(¶ms, &h0s, &h1s, 0); + + // generate share 2 for rlk round 2 + let share_rlk_a_round2 = CollectiveRlkGenerator::generate_share_2( + ¶ms, + &private_output_a_state0.s_pk_a, + &rlk_agg_1.0, + &rlk_agg_1.1, + &CollectiveRlkGeneratorState(private_output_a_state0.s_rlk_a), + 0, + &mut rng, + ); + + // aggregate rlk round 2 shares and generate rlk + let h0_dash_shares = vec![ + share_rlk_a_round2.0.clone(), + message_from_b.share_rlk_b_round2.0, + ]; + let h1_dash_shares = vec![ + share_rlk_a_round2.1.clone(), + message_from_b.share_rlk_b_round2.1, + ]; + let rlk = CollectiveRlkGenerator::aggregate_shares_2( + ¶ms, + &h0_dash_shares, + &h1_dash_shares, + rlk_agg_1.1, + 0, + ); + + // create public key and encrypt A's bit vector' + let collective_pk_shares = vec![public_output_a_state0.share_pk_a, message_from_b.share_pk_b]; + let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise( + ¶ms, + &collective_pk_shares, + CRS_PK, + ); + let pt = Plaintext::try_encoding_with_parameters(bit_vector, ¶ms, Encoding::default()); + let ciphertext_a = collective_pk.encrypt(¶ms, &pt, &mut rng); + + // perform PSI + let evaluator = Evaluator::new(params); + let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]); + let ciphertext_res = evaluator.mul(&ciphertext_a, &message_from_b.ciphertext_b); + let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key); + + // generate decryption share of ciphertext_res + let decryption_share_a = CollectiveDecryption::generate_share( + evaluator.params(), + &ciphertext_res, + &private_output_a_state0.s_pk_a, + &mut rng, + ); + + let public_output_a = PublicOutputAPostState2 { + decryption_share_a: decryption_share_a.clone(), + ciphertext_res, + }; + + let message_a_to_b = MessageAToBPostState2 { + decryption_share_a, + ciphertext_a, + share_rlk_a_round2, + }; + + (public_output_a, message_a_to_b) +} + +fn state3( + private_output_b_state1: PrivateOutputBPostState1, + public_output_b_state1: PublicOutputBPostState1, + message_from_a: MessageAToBPostState2, +) -> (MessageBToAPostState3, Vec) { + let params = params(); + let mut rng = thread_rng(); + + // create rlk + let h0_dash_shares = vec![ + message_from_a.share_rlk_a_round2.0, + public_output_b_state1.share_rlk_b_round2.0, + ]; + let h1_dash_shares = vec![ + message_from_a.share_rlk_a_round2.1, + public_output_b_state1.share_rlk_b_round2.1, + ]; + let rlk = CollectiveRlkGenerator::aggregate_shares_2( + ¶ms, + &h0_dash_shares, + &h1_dash_shares, + public_output_b_state1.rlk_agg_round1_h1s, + 0, + ); + + // perform PSI + let evaluator = Evaluator::new(params); + let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]); + let ciphertext_res = evaluator.mul( + &message_from_a.ciphertext_a, + &public_output_b_state1.ciphertext_b, + ); + let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key); + + // generate B's decryption share + let decryption_share_b = CollectiveDecryption::generate_share( + evaluator.params(), + &ciphertext_res, + &private_output_b_state1.s_pk_b, + &mut rng, + ); + + // decrypt ciphertext res + let decryption_shares_vec = vec![ + decryption_share_b.clone(), + message_from_a.decryption_share_a, + ]; + let psi_output = CollectiveDecryption::aggregate_share_and_decrypt( + evaluator.params(), + &ciphertext_res, + &decryption_shares_vec, + ); + let psi_output = Vec::::try_decoding_with_parameters( + &psi_output, + evaluator.params(), + Encoding::default(), + ); + + let message_b_to_a = MessageBToAPostState3 { decryption_share_b }; + + (message_b_to_a, psi_output) +} + +fn state4( + public_output_a_state2: PublicOutputAPostState2, + message_from_b: MessageBToAPostState3, +) -> Vec { + let params = params(); + + // decrypt ciphertext res + let decryption_shares_vec = vec![ + public_output_a_state2.decryption_share_a, + message_from_b.decryption_share_b, + ]; + let psi_output = CollectiveDecryption::aggregate_share_and_decrypt( + ¶ms, + &public_output_a_state2.ciphertext_res, + &decryption_shares_vec, + ); + let psi_output = + Vec::::try_decoding_with_parameters(&psi_output, ¶ms, Encoding::default()); + + psi_output +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{distributions::Uniform, Rng}; + + fn random_bit_vector(hamming_weight: usize, size: usize) -> Vec { + let mut rng = thread_rng(); + + let mut bit_vector = vec![0; size]; + (0..hamming_weight).into_iter().for_each(|_| { + let sample_index = rng.sample(Uniform::new(0, size)); + bit_vector[sample_index] = 1; + }); + + bit_vector + } + + #[test] + fn psi_works() { + let hamming_weight = 10; + let vector_size = 10; + + // A: state 0 + let (private_output_a_state0, public_output_a_state0, message_a_to_b_state0) = state0(); + + // B: state 1 + let bit_vector_b = random_bit_vector(hamming_weight, vector_size); + let (private_output_b_state1, public_output_b_state1, message_b_to_a_state1) = + state1(message_a_to_b_state0, &bit_vector_b); + + // A: state 2 + let bit_vector_a = random_bit_vector(hamming_weight, vector_size); + let (public_output_a_state2, message_a_to_b_state2) = state2( + private_output_a_state0, + public_output_a_state0, + message_b_to_a_state1, + &bit_vector_a, + ); + + // B: state 3 + let (message_b_to_a_state3, psi_output_b) = state3( + private_output_b_state1, + public_output_b_state1, + message_a_to_b_state2, + ); + + // A: state 4 + let psi_output_a = state4(public_output_a_state2, message_b_to_a_state3); + + assert_eq!(psi_output_a, psi_output_b); + } +} diff --git a/src/main.rs b/src/main.rs index 9dbc7e7..7e3f4fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -45,7 +45,7 @@ fn main() { let parties = vec![Party::random(¶ms, hw), Party::random(¶ms, hw)]; // Collective public key generation // - let crs = [0u8; 32]; + let crs = [1u8; 32]; let mut rng = thread_rng(); // Each party generates their share let shares = parties @@ -60,7 +60,7 @@ fn main() { // Collective relinearization key generation // // This is a 2 round protocol - let crs = [0u8; 32]; + let crs = [4u8; 32]; let level = 0; // Each party generates a ephemeral state let parties_state = parties