make code wasm friendly

This commit is contained in:
Vivek Bhupatiraju
2024-02-23 19:30:56 -07:00
committed by Vivek
parent 411e64dc9f
commit 5899260cec
2 changed files with 245 additions and 82 deletions

View File

@@ -3,13 +3,12 @@
The library contains the following components:
- `src`: Rust library for multi-party PSI using BFV
- `pkg`: JS-TS-WASM package
- `pkg`: JS-TS-WASM package
### Build
### Build
The rust library is used to build the JS-TS-WASM package using `wasm-pack` targeting `web` [guide](https://developer.mozilla.org/en-US/docs/WebAssembly/Rust_to_Wasm). When compiling to `web` the output can natively be included on a web page, and doesn't require any further postprocessing. The output is included as an ES module. For more information check [`wasm-bindgen` guide](https://rustwasm.github.io/docs/wasm-bindgen/reference/deployment.html)
```bash
wasm-pack build --target web
```
@@ -20,4 +19,4 @@ To test the rust library, run:
```bash
cargo test --release
```
```

View File

@@ -1,13 +1,18 @@
use bfv::{
BfvParameters, Ciphertext, CollectiveDecryption, CollectiveDecryptionShare,
CollectivePublicKeyGenerator, CollectivePublicKeyShare, CollectiveRlkAggTrimmedShare1,
CollectiveRlkGenerator, CollectiveRlkShare1, CollectiveRlkShare2, Encoding, EvaluationKey,
Evaluator, Plaintext, SecretKey,
BfvParameters, CiphertextProto, CollectiveDecryption, CollectiveDecryptionShare,
CollectiveDecryptionShareProto, CollectivePublicKeyGenerator, CollectivePublicKeyShareProto,
CollectiveRlkAggTrimmedShare1Proto, CollectiveRlkGenerator, CollectiveRlkShare1Proto,
CollectiveRlkShare2Proto, Encoding, EvaluationKey, Evaluator, Plaintext, SecretKey,
SecretKeyProto,
};
use itertools::{izip, Itertools};
use rand::thread_rng;
use traits::{TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithParameters};
use serde::{Deserialize, Serialize};
use traits::{
TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithLevelledParameters,
TryFromWithParameters,
};
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
mod bandwidth_benches;
@@ -28,18 +33,33 @@ fn params() -> BfvParameters {
params
}
/************* GENERATING KEYS *************/
#[derive(Serialize, Deserialize)]
struct PsiKeys {
s: SecretKey,
s_rlk: SecretKey,
s: SecretKeyProto,
s_rlk: SecretKeyProto,
}
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
struct MessageRound1 {
share_pk: CollectivePublicKeyShare,
share_rlk1: CollectiveRlkShare1,
share_pk: CollectivePublicKeyShareProto,
share_rlk1: CollectiveRlkShare1Proto,
}
fn gen_keys() -> (PsiKeys, MessageRound1) {
#[derive(Serialize, Deserialize)]
struct GenKeysOutput {
psi_keys: PsiKeys,
message_round1: MessageRound1,
}
#[wasm_bindgen]
pub fn gen_keys_js() -> JsValue {
let output = gen_keys();
serde_wasm_bindgen::to_value(&output).unwrap()
}
fn gen_keys() -> GenKeysOutput {
let params = params();
let mut rng = thread_rng();
let s = SecretKey::random_with_params(&params, &mut rng);
@@ -49,23 +69,55 @@ fn gen_keys() -> (PsiKeys, MessageRound1) {
let share_rlk1 =
CollectiveRlkGenerator::generate_share_1(&params, &s, &s_rlk, CRS_RLK, 0, &mut rng);
(
PsiKeys { s, s_rlk },
MessageRound1 {
share_pk,
share_rlk1,
GenKeysOutput {
psi_keys: PsiKeys {
s: convert(&s, &params),
s_rlk: convert(&s_rlk, &params),
},
)
message_round1: MessageRound1 {
share_pk: convert(&share_pk, &params),
share_rlk1: convert(&share_rlk1, &params),
},
}
}
/************* ROUND 1 *************/
#[derive(Serialize, Deserialize)]
struct Round1Output {
state_round2: StateRound2,
message_round2: MessageRound2,
}
#[derive(Serialize, Deserialize)]
struct StateRound2 {
rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1,
rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1Proto,
}
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
struct MessageRound2 {
share_rlk2: CollectiveRlkShare2,
cts: Vec<Ciphertext>,
share_rlk2: CollectiveRlkShare2Proto,
cts: Vec<CiphertextProto>,
}
#[wasm_bindgen]
pub fn round1_js(
gen_keys_output: JsValue,
other_message_round1: JsValue,
bit_vector: &[u32],
) -> JsValue {
let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output)
.expect("failed to deserialize gen_keys_output");
let other_message_round1: MessageRound1 = serde_wasm_bindgen::from_value(other_message_round1)
.expect("failed to deserialize other_message_round1");
let output = round1(
&gen_keys_output.psi_keys,
gen_keys_output.message_round1,
other_message_round1,
bit_vector,
);
serde_wasm_bindgen::to_value(&output).unwrap()
}
fn round1(
@@ -73,28 +125,35 @@ fn round1(
message: MessageRound1,
other_message: MessageRound1,
bit_vector: &[u32],
) -> (StateRound2, MessageRound2) {
) -> Round1Output {
let params = params();
let mut rng = thread_rng();
let self_s = convert(&psi_keys.s, &params);
let self_s_rlk = convert(&psi_keys.s_rlk, &params);
let self_share_pk = convert(&message.share_pk, &params);
let other_share_pk = convert(&other_message.share_pk, &params);
let self_share_rlk1 = convert(&message.share_rlk1, &params);
let other_share_rlk1 = convert(&other_message.share_rlk1, &params);
// generate pk
let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise(
&params,
&[message.share_pk, other_message.share_pk],
&[self_share_pk, other_share_pk],
CRS_PK,
);
// generate rlk share 2
let rlk_agg1 = CollectiveRlkGenerator::aggregate_shares_1(
&params,
&[message.share_rlk1, other_message.share_rlk1],
&[self_share_rlk1, other_share_rlk1],
0,
);
let share_rlk2 = CollectiveRlkGenerator::generate_share_2(
&params,
&psi_keys.s,
&self_s,
&rlk_agg1,
&psi_keys.s_rlk,
&self_s_rlk,
0,
&mut rng,
);
@@ -108,24 +167,61 @@ fn round1(
})
.collect_vec();
(
StateRound2 {
rlk_agg1_trimmed: rlk_agg1.trim(),
Round1Output {
state_round2: StateRound2 {
rlk_agg1_trimmed: convert(&rlk_agg1.trim(), &params),
},
MessageRound2 {
share_rlk2,
cts: ciphertexts,
message_round2: MessageRound2 {
share_rlk2: convert(&share_rlk2, &params),
cts: ciphertexts
.iter()
.map(|v| convert(v, &params))
.collect_vec(),
},
)
}
}
#[derive(Clone)]
/************* ROUND 2 *************/
#[derive(Serialize, Deserialize)]
struct Round2Output {
state_round3: StateRound3,
message_round3: MessageRound3,
}
#[derive(Clone, Serialize, Deserialize)]
struct MessageRound3 {
decryption_shares: Vec<CollectiveDecryptionShare>,
decryption_shares: Vec<CollectiveDecryptionShareProto>,
}
#[derive(Serialize, Deserialize)]
struct StateRound3 {
cts_res: Vec<Ciphertext>,
cts_res: Vec<CiphertextProto>,
}
#[wasm_bindgen]
pub fn round2_js(
gen_keys_output: JsValue,
round1_output: JsValue,
other_message_round2: JsValue,
is_a: bool,
) -> JsValue {
let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output)
.expect("failed to deserialize gen_keys_output");
let round1_output: Round1Output =
serde_wasm_bindgen::from_value(round1_output).expect("failed to deserialize round1_output");
let other_message_round2: MessageRound2 = serde_wasm_bindgen::from_value(other_message_round2)
.expect("failed to deserialize other_message_round2");
let output = round2(
&gen_keys_output.psi_keys,
round1_output.state_round2,
round1_output.message_round2,
other_message_round2,
is_a,
);
serde_wasm_bindgen::to_value(&output).unwrap()
}
fn round2(
@@ -134,22 +230,39 @@ fn round2(
message: MessageRound2,
other_message: MessageRound2,
is_a: bool,
) -> (StateRound3, MessageRound3) {
) -> Round2Output {
let params = params();
let mut rng = thread_rng();
let self_s = convert(&psi_keys.s, &params);
let self_rlk1_agg = convert(&state_round2.rlk_agg1_trimmed, &params);
let self_share_rlk2 = convert(&message.share_rlk2, &params);
let other_share_rlk2 = convert(&other_message.share_rlk2, &params);
let self_cts = message
.cts
.iter()
.map(|v| convert(v, &params))
.collect_vec();
let other_cts = other_message
.cts
.iter()
.map(|v| convert(v, &params))
.collect_vec();
// Create RLK
let rlk = CollectiveRlkGenerator::aggregate_shares_2(
&params,
&[message.share_rlk2, other_message.share_rlk2],
state_round2.rlk_agg1_trimmed,
&[self_share_rlk2, other_share_rlk2],
self_rlk1_agg,
0,
);
// perform PSI
let evaluator = Evaluator::new(params.clone());
let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
let cts_res = izip!(message.cts.iter(), other_message.cts.iter())
let cts_res = izip!(self_cts.iter(), other_cts.iter())
.map(|(ca, cb)| {
let ct_out = {
if is_a {
@@ -163,10 +276,38 @@ fn round2(
.collect_vec();
let decryption_shares = cts_res
.iter()
.map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &psi_keys.s, &mut rng))
.map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &self_s, &mut rng))
.collect_vec();
(StateRound3 { cts_res }, MessageRound3 { decryption_shares })
Round2Output {
state_round3: StateRound3 {
cts_res: cts_res.iter().map(|v| convert(v, &params)).collect_vec(),
},
message_round3: MessageRound3 {
decryption_shares: decryption_shares
.iter()
.map(|v| {
CollectiveDecryptionShareProto::try_from_with_levelled_parameters(v, &params, 0)
})
.collect_vec(),
},
}
}
/************* ROUND 3 *************/
#[wasm_bindgen]
pub fn round3_js(round2_output: JsValue, other_message: JsValue) -> Vec<u32> {
let round2_output: Round2Output =
serde_wasm_bindgen::from_value(round2_output).expect("failed to deserialize round2_output");
let other_message: MessageRound3 =
serde_wasm_bindgen::from_value(other_message).expect("failed to deserialize other_message");
round3(
round2_output.state_round3,
round2_output.message_round3,
other_message,
)
}
fn round3(
@@ -175,13 +316,32 @@ fn round3(
other_message: MessageRound3,
) -> Vec<u32> {
let params = params();
let self_cts = state_round3
.cts_res
.iter()
.map(|v| convert(v, &params))
.collect_vec();
izip!(
state_round3.cts_res.iter(),
self_cts.iter(),
message.decryption_shares.into_iter(),
other_message.decryption_shares.into_iter()
)
.flat_map(|(c, share_a, share_b)| {
let pt = CollectiveDecryption::aggregate_share_and_decrypt(&params, c, &[share_a, share_b]);
.flat_map(|(c, share_a_proto, share_b_proto)| {
let shares_vec = vec![
CollectiveDecryptionShare::try_from_with_levelled_parameters(
&share_a_proto,
&params,
0,
),
CollectiveDecryptionShare::try_from_with_levelled_parameters(
&share_b_proto,
&params,
0,
),
];
let pt = CollectiveDecryption::aggregate_share_and_decrypt(&params, c, &shares_vec);
Vec::<u32>::try_decoding_with_parameters(&pt, &params, Encoding::default())
})
.collect_vec()
@@ -225,52 +385,56 @@ mod tests {
let vector_size = RING_SIZE * 3;
// gen keys
let (a_psi_keys, a_message_round1) = gen_keys();
let (b_psi_keys, b_message_round1) = gen_keys();
let gen_keys_output_a = gen_keys();
let gen_keys_output_b = gen_keys();
// round1
let a_bit_vector = random_bit_vector(hamming_weight, vector_size);
let b_bit_vector = random_bit_vector(hamming_weight, vector_size);
let (a_state_round2, a_message_round2) = round1(
&a_psi_keys,
a_message_round1.clone(),
b_message_round1.clone(),
&a_bit_vector,
let bit_vector_a = random_bit_vector(hamming_weight, vector_size);
let bit_vector_b = random_bit_vector(hamming_weight, vector_size);
let round1_output_a = round1(
&gen_keys_output_a.psi_keys,
gen_keys_output_a.message_round1.clone(),
gen_keys_output_b.message_round1.clone(),
&bit_vector_a,
);
let (b_state_round2, b_message_round2) = round1(
&b_psi_keys,
b_message_round1,
a_message_round1,
&b_bit_vector,
let round1_output_b = round1(
&gen_keys_output_b.psi_keys,
gen_keys_output_b.message_round1,
gen_keys_output_a.message_round1,
&bit_vector_b,
);
// round2
let (a_state_round3, a_message_round3) = round2(
&a_psi_keys,
a_state_round2,
a_message_round2.clone(),
b_message_round2.clone(),
let round2_output_a = round2(
&gen_keys_output_a.psi_keys,
round1_output_a.state_round2,
round1_output_a.message_round2.clone(),
round1_output_b.message_round2.clone(),
true,
);
let (b_state_round3, b_message_round3) = round2(
&b_psi_keys,
b_state_round2,
b_message_round2,
a_message_round2,
let round2_output_b = round2(
&gen_keys_output_b.psi_keys,
round1_output_b.state_round2,
round1_output_b.message_round2,
round1_output_a.message_round2,
false,
);
// round3
let a_psi_output = round3(
a_state_round3,
a_message_round3.clone(),
b_message_round3.clone(),
let psi_output_a = round3(
round2_output_a.state_round3,
round2_output_a.message_round3.clone(),
round2_output_b.message_round3.clone(),
);
let psi_output_b = round3(
round2_output_b.state_round3,
round2_output_b.message_round3,
round2_output_a.message_round3,
);
let b_psi_output = round3(b_state_round3, b_message_round3, a_message_round3);
let expected_psi_output = plain_psi(&a_bit_vector, &b_bit_vector);
let expected_psi_output = plain_psi(&bit_vector_a, &bit_vector_b);
assert_eq!(expected_psi_output, a_psi_output[..vector_size]);
assert_eq!(a_psi_output, b_psi_output);
assert_eq!(expected_psi_output, psi_output_a[..vector_size]);
assert_eq!(psi_output_a, psi_output_b);
}
}