From 5899260cec1dcc4aceb493777d0044e52838462f Mon Sep 17 00:00:00 2001 From: Vivek Bhupatiraju Date: Fri, 23 Feb 2024 19:30:56 -0700 Subject: [PATCH] make code wasm friendly --- README.md | 7 +- src/lib.rs | 320 ++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 245 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index afbb501..faf94cc 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,12 @@ The library contains the following components: - `src`: Rust library for multi-party PSI using BFV -- `pkg`: JS-TS-WASM package +- `pkg`: JS-TS-WASM package -### Build +### Build The rust library is used to build the JS-TS-WASM package using `wasm-pack` targeting `web` [guide](https://developer.mozilla.org/en-US/docs/WebAssembly/Rust_to_Wasm). When compiling to `web` the output can natively be included on a web page, and doesn't require any further postprocessing. The output is included as an ES module. For more information check [`wasm-bindgen` guide](https://rustwasm.github.io/docs/wasm-bindgen/reference/deployment.html) - ```bash wasm-pack build --target web ``` @@ -20,4 +19,4 @@ To test the rust library, run: ```bash cargo test --release -``` \ No newline at end of file +``` diff --git a/src/lib.rs b/src/lib.rs index eb48f87..9b69abe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,18 @@ use bfv::{ - BfvParameters, Ciphertext, CollectiveDecryption, CollectiveDecryptionShare, - CollectivePublicKeyGenerator, CollectivePublicKeyShare, CollectiveRlkAggTrimmedShare1, - CollectiveRlkGenerator, CollectiveRlkShare1, CollectiveRlkShare2, Encoding, EvaluationKey, - Evaluator, Plaintext, SecretKey, + BfvParameters, CiphertextProto, CollectiveDecryption, CollectiveDecryptionShare, + CollectiveDecryptionShareProto, CollectivePublicKeyGenerator, CollectivePublicKeyShareProto, + CollectiveRlkAggTrimmedShare1Proto, CollectiveRlkGenerator, CollectiveRlkShare1Proto, + CollectiveRlkShare2Proto, Encoding, EvaluationKey, Evaluator, Plaintext, SecretKey, + SecretKeyProto, }; use itertools::{izip, Itertools}; use rand::thread_rng; - -use traits::{TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithParameters}; +use serde::{Deserialize, Serialize}; +use traits::{ + TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithLevelledParameters, + TryFromWithParameters, +}; +use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; mod bandwidth_benches; @@ -28,18 +33,33 @@ fn params() -> BfvParameters { params } +/************* GENERATING KEYS *************/ + +#[derive(Serialize, Deserialize)] struct PsiKeys { - s: SecretKey, - s_rlk: SecretKey, + s: SecretKeyProto, + s_rlk: SecretKeyProto, } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] struct MessageRound1 { - share_pk: CollectivePublicKeyShare, - share_rlk1: CollectiveRlkShare1, + share_pk: CollectivePublicKeyShareProto, + share_rlk1: CollectiveRlkShare1Proto, } -fn gen_keys() -> (PsiKeys, MessageRound1) { +#[derive(Serialize, Deserialize)] +struct GenKeysOutput { + psi_keys: PsiKeys, + message_round1: MessageRound1, +} + +#[wasm_bindgen] +pub fn gen_keys_js() -> JsValue { + let output = gen_keys(); + serde_wasm_bindgen::to_value(&output).unwrap() +} + +fn gen_keys() -> GenKeysOutput { let params = params(); let mut rng = thread_rng(); let s = SecretKey::random_with_params(¶ms, &mut rng); @@ -49,23 +69,55 @@ fn gen_keys() -> (PsiKeys, MessageRound1) { let share_rlk1 = CollectiveRlkGenerator::generate_share_1(¶ms, &s, &s_rlk, CRS_RLK, 0, &mut rng); - ( - PsiKeys { s, s_rlk }, - MessageRound1 { - share_pk, - share_rlk1, + GenKeysOutput { + psi_keys: PsiKeys { + s: convert(&s, ¶ms), + s_rlk: convert(&s_rlk, ¶ms), }, - ) + message_round1: MessageRound1 { + share_pk: convert(&share_pk, ¶ms), + share_rlk1: convert(&share_rlk1, ¶ms), + }, + } } +/************* ROUND 1 *************/ + +#[derive(Serialize, Deserialize)] +struct Round1Output { + state_round2: StateRound2, + message_round2: MessageRound2, +} + +#[derive(Serialize, Deserialize)] struct StateRound2 { - rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1, + rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1Proto, } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] struct MessageRound2 { - share_rlk2: CollectiveRlkShare2, - cts: Vec, + share_rlk2: CollectiveRlkShare2Proto, + cts: Vec, +} + +#[wasm_bindgen] +pub fn round1_js( + gen_keys_output: JsValue, + other_message_round1: JsValue, + bit_vector: &[u32], +) -> JsValue { + let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output) + .expect("failed to deserialize gen_keys_output"); + let other_message_round1: MessageRound1 = serde_wasm_bindgen::from_value(other_message_round1) + .expect("failed to deserialize other_message_round1"); + let output = round1( + &gen_keys_output.psi_keys, + gen_keys_output.message_round1, + other_message_round1, + bit_vector, + ); + + serde_wasm_bindgen::to_value(&output).unwrap() } fn round1( @@ -73,28 +125,35 @@ fn round1( message: MessageRound1, other_message: MessageRound1, bit_vector: &[u32], -) -> (StateRound2, MessageRound2) { +) -> Round1Output { let params = params(); let mut rng = thread_rng(); + let self_s = convert(&psi_keys.s, ¶ms); + let self_s_rlk = convert(&psi_keys.s_rlk, ¶ms); + let self_share_pk = convert(&message.share_pk, ¶ms); + let other_share_pk = convert(&other_message.share_pk, ¶ms); + let self_share_rlk1 = convert(&message.share_rlk1, ¶ms); + let other_share_rlk1 = convert(&other_message.share_rlk1, ¶ms); + // generate pk let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise( ¶ms, - &[message.share_pk, other_message.share_pk], + &[self_share_pk, other_share_pk], CRS_PK, ); // generate rlk share 2 let rlk_agg1 = CollectiveRlkGenerator::aggregate_shares_1( ¶ms, - &[message.share_rlk1, other_message.share_rlk1], + &[self_share_rlk1, other_share_rlk1], 0, ); let share_rlk2 = CollectiveRlkGenerator::generate_share_2( ¶ms, - &psi_keys.s, + &self_s, &rlk_agg1, - &psi_keys.s_rlk, + &self_s_rlk, 0, &mut rng, ); @@ -108,24 +167,61 @@ fn round1( }) .collect_vec(); - ( - StateRound2 { - rlk_agg1_trimmed: rlk_agg1.trim(), + Round1Output { + state_round2: StateRound2 { + rlk_agg1_trimmed: convert(&rlk_agg1.trim(), ¶ms), }, - MessageRound2 { - share_rlk2, - cts: ciphertexts, + message_round2: MessageRound2 { + share_rlk2: convert(&share_rlk2, ¶ms), + cts: ciphertexts + .iter() + .map(|v| convert(v, ¶ms)) + .collect_vec(), }, - ) + } } -#[derive(Clone)] +/************* ROUND 2 *************/ + +#[derive(Serialize, Deserialize)] +struct Round2Output { + state_round3: StateRound3, + message_round3: MessageRound3, +} + +#[derive(Clone, Serialize, Deserialize)] struct MessageRound3 { - decryption_shares: Vec, + decryption_shares: Vec, } +#[derive(Serialize, Deserialize)] struct StateRound3 { - cts_res: Vec, + cts_res: Vec, +} + +#[wasm_bindgen] +pub fn round2_js( + gen_keys_output: JsValue, + round1_output: JsValue, + other_message_round2: JsValue, + is_a: bool, +) -> JsValue { + let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output) + .expect("failed to deserialize gen_keys_output"); + let round1_output: Round1Output = + serde_wasm_bindgen::from_value(round1_output).expect("failed to deserialize round1_output"); + let other_message_round2: MessageRound2 = serde_wasm_bindgen::from_value(other_message_round2) + .expect("failed to deserialize other_message_round2"); + + let output = round2( + &gen_keys_output.psi_keys, + round1_output.state_round2, + round1_output.message_round2, + other_message_round2, + is_a, + ); + + serde_wasm_bindgen::to_value(&output).unwrap() } fn round2( @@ -134,22 +230,39 @@ fn round2( message: MessageRound2, other_message: MessageRound2, is_a: bool, -) -> (StateRound3, MessageRound3) { +) -> Round2Output { let params = params(); let mut rng = thread_rng(); + let self_s = convert(&psi_keys.s, ¶ms); + let self_rlk1_agg = convert(&state_round2.rlk_agg1_trimmed, ¶ms); + + let self_share_rlk2 = convert(&message.share_rlk2, ¶ms); + let other_share_rlk2 = convert(&other_message.share_rlk2, ¶ms); + + let self_cts = message + .cts + .iter() + .map(|v| convert(v, ¶ms)) + .collect_vec(); + let other_cts = other_message + .cts + .iter() + .map(|v| convert(v, ¶ms)) + .collect_vec(); + // Create RLK let rlk = CollectiveRlkGenerator::aggregate_shares_2( ¶ms, - &[message.share_rlk2, other_message.share_rlk2], - state_round2.rlk_agg1_trimmed, + &[self_share_rlk2, other_share_rlk2], + self_rlk1_agg, 0, ); // perform PSI let evaluator = Evaluator::new(params.clone()); let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]); - let cts_res = izip!(message.cts.iter(), other_message.cts.iter()) + let cts_res = izip!(self_cts.iter(), other_cts.iter()) .map(|(ca, cb)| { let ct_out = { if is_a { @@ -163,10 +276,38 @@ fn round2( .collect_vec(); let decryption_shares = cts_res .iter() - .map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &psi_keys.s, &mut rng)) + .map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &self_s, &mut rng)) .collect_vec(); - (StateRound3 { cts_res }, MessageRound3 { decryption_shares }) + Round2Output { + state_round3: StateRound3 { + cts_res: cts_res.iter().map(|v| convert(v, ¶ms)).collect_vec(), + }, + message_round3: MessageRound3 { + decryption_shares: decryption_shares + .iter() + .map(|v| { + CollectiveDecryptionShareProto::try_from_with_levelled_parameters(v, ¶ms, 0) + }) + .collect_vec(), + }, + } +} + +/************* ROUND 3 *************/ + +#[wasm_bindgen] +pub fn round3_js(round2_output: JsValue, other_message: JsValue) -> Vec { + let round2_output: Round2Output = + serde_wasm_bindgen::from_value(round2_output).expect("failed to deserialize round2_output"); + let other_message: MessageRound3 = + serde_wasm_bindgen::from_value(other_message).expect("failed to deserialize other_message"); + + round3( + round2_output.state_round3, + round2_output.message_round3, + other_message, + ) } fn round3( @@ -175,13 +316,32 @@ fn round3( other_message: MessageRound3, ) -> Vec { let params = params(); + + let self_cts = state_round3 + .cts_res + .iter() + .map(|v| convert(v, ¶ms)) + .collect_vec(); + izip!( - state_round3.cts_res.iter(), + self_cts.iter(), message.decryption_shares.into_iter(), other_message.decryption_shares.into_iter() ) - .flat_map(|(c, share_a, share_b)| { - let pt = CollectiveDecryption::aggregate_share_and_decrypt(¶ms, c, &[share_a, share_b]); + .flat_map(|(c, share_a_proto, share_b_proto)| { + let shares_vec = vec![ + CollectiveDecryptionShare::try_from_with_levelled_parameters( + &share_a_proto, + ¶ms, + 0, + ), + CollectiveDecryptionShare::try_from_with_levelled_parameters( + &share_b_proto, + ¶ms, + 0, + ), + ]; + let pt = CollectiveDecryption::aggregate_share_and_decrypt(¶ms, c, &shares_vec); Vec::::try_decoding_with_parameters(&pt, ¶ms, Encoding::default()) }) .collect_vec() @@ -225,52 +385,56 @@ mod tests { let vector_size = RING_SIZE * 3; // gen keys - let (a_psi_keys, a_message_round1) = gen_keys(); - let (b_psi_keys, b_message_round1) = gen_keys(); + let gen_keys_output_a = gen_keys(); + let gen_keys_output_b = gen_keys(); // round1 - let a_bit_vector = random_bit_vector(hamming_weight, vector_size); - let b_bit_vector = random_bit_vector(hamming_weight, vector_size); - let (a_state_round2, a_message_round2) = round1( - &a_psi_keys, - a_message_round1.clone(), - b_message_round1.clone(), - &a_bit_vector, + let bit_vector_a = random_bit_vector(hamming_weight, vector_size); + let bit_vector_b = random_bit_vector(hamming_weight, vector_size); + let round1_output_a = round1( + &gen_keys_output_a.psi_keys, + gen_keys_output_a.message_round1.clone(), + gen_keys_output_b.message_round1.clone(), + &bit_vector_a, ); - let (b_state_round2, b_message_round2) = round1( - &b_psi_keys, - b_message_round1, - a_message_round1, - &b_bit_vector, + let round1_output_b = round1( + &gen_keys_output_b.psi_keys, + gen_keys_output_b.message_round1, + gen_keys_output_a.message_round1, + &bit_vector_b, ); // round2 - let (a_state_round3, a_message_round3) = round2( - &a_psi_keys, - a_state_round2, - a_message_round2.clone(), - b_message_round2.clone(), + let round2_output_a = round2( + &gen_keys_output_a.psi_keys, + round1_output_a.state_round2, + round1_output_a.message_round2.clone(), + round1_output_b.message_round2.clone(), true, ); - let (b_state_round3, b_message_round3) = round2( - &b_psi_keys, - b_state_round2, - b_message_round2, - a_message_round2, + let round2_output_b = round2( + &gen_keys_output_b.psi_keys, + round1_output_b.state_round2, + round1_output_b.message_round2, + round1_output_a.message_round2, false, ); // round3 - let a_psi_output = round3( - a_state_round3, - a_message_round3.clone(), - b_message_round3.clone(), + let psi_output_a = round3( + round2_output_a.state_round3, + round2_output_a.message_round3.clone(), + round2_output_b.message_round3.clone(), + ); + let psi_output_b = round3( + round2_output_b.state_round3, + round2_output_b.message_round3, + round2_output_a.message_round3, ); - let b_psi_output = round3(b_state_round3, b_message_round3, a_message_round3); - let expected_psi_output = plain_psi(&a_bit_vector, &b_bit_vector); + let expected_psi_output = plain_psi(&bit_vector_a, &bit_vector_b); - assert_eq!(expected_psi_output, a_psi_output[..vector_size]); - assert_eq!(a_psi_output, b_psi_output); + assert_eq!(expected_psi_output, psi_output_a[..vector_size]); + assert_eq!(psi_output_a, psi_output_b); } }