mirror of
https://github.com/cursive-team/2P-PSI.git
synced 2026-01-09 03:57:55 -05:00
make code wasm friendly
This commit is contained in:
@@ -3,13 +3,12 @@
|
||||
The library contains the following components:
|
||||
|
||||
- `src`: Rust library for multi-party PSI using BFV
|
||||
- `pkg`: JS-TS-WASM package
|
||||
- `pkg`: JS-TS-WASM package
|
||||
|
||||
### Build
|
||||
### Build
|
||||
|
||||
The rust library is used to build the JS-TS-WASM package using `wasm-pack` targeting `web` [guide](https://developer.mozilla.org/en-US/docs/WebAssembly/Rust_to_Wasm). When compiling to `web` the output can natively be included on a web page, and doesn't require any further postprocessing. The output is included as an ES module. For more information check [`wasm-bindgen` guide](https://rustwasm.github.io/docs/wasm-bindgen/reference/deployment.html)
|
||||
|
||||
|
||||
```bash
|
||||
wasm-pack build --target web
|
||||
```
|
||||
@@ -20,4 +19,4 @@ To test the rust library, run:
|
||||
|
||||
```bash
|
||||
cargo test --release
|
||||
```
|
||||
```
|
||||
|
||||
320
src/lib.rs
320
src/lib.rs
@@ -1,13 +1,18 @@
|
||||
use bfv::{
|
||||
BfvParameters, Ciphertext, CollectiveDecryption, CollectiveDecryptionShare,
|
||||
CollectivePublicKeyGenerator, CollectivePublicKeyShare, CollectiveRlkAggTrimmedShare1,
|
||||
CollectiveRlkGenerator, CollectiveRlkShare1, CollectiveRlkShare2, Encoding, EvaluationKey,
|
||||
Evaluator, Plaintext, SecretKey,
|
||||
BfvParameters, CiphertextProto, CollectiveDecryption, CollectiveDecryptionShare,
|
||||
CollectiveDecryptionShareProto, CollectivePublicKeyGenerator, CollectivePublicKeyShareProto,
|
||||
CollectiveRlkAggTrimmedShare1Proto, CollectiveRlkGenerator, CollectiveRlkShare1Proto,
|
||||
CollectiveRlkShare2Proto, Encoding, EvaluationKey, Evaluator, Plaintext, SecretKey,
|
||||
SecretKeyProto,
|
||||
};
|
||||
use itertools::{izip, Itertools};
|
||||
use rand::thread_rng;
|
||||
|
||||
use traits::{TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithParameters};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use traits::{
|
||||
TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithLevelledParameters,
|
||||
TryFromWithParameters,
|
||||
};
|
||||
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
|
||||
|
||||
mod bandwidth_benches;
|
||||
|
||||
@@ -28,18 +33,33 @@ fn params() -> BfvParameters {
|
||||
params
|
||||
}
|
||||
|
||||
/************* GENERATING KEYS *************/
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct PsiKeys {
|
||||
s: SecretKey,
|
||||
s_rlk: SecretKey,
|
||||
s: SecretKeyProto,
|
||||
s_rlk: SecretKeyProto,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
struct MessageRound1 {
|
||||
share_pk: CollectivePublicKeyShare,
|
||||
share_rlk1: CollectiveRlkShare1,
|
||||
share_pk: CollectivePublicKeyShareProto,
|
||||
share_rlk1: CollectiveRlkShare1Proto,
|
||||
}
|
||||
|
||||
fn gen_keys() -> (PsiKeys, MessageRound1) {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct GenKeysOutput {
|
||||
psi_keys: PsiKeys,
|
||||
message_round1: MessageRound1,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn gen_keys_js() -> JsValue {
|
||||
let output = gen_keys();
|
||||
serde_wasm_bindgen::to_value(&output).unwrap()
|
||||
}
|
||||
|
||||
fn gen_keys() -> GenKeysOutput {
|
||||
let params = params();
|
||||
let mut rng = thread_rng();
|
||||
let s = SecretKey::random_with_params(¶ms, &mut rng);
|
||||
@@ -49,23 +69,55 @@ fn gen_keys() -> (PsiKeys, MessageRound1) {
|
||||
let share_rlk1 =
|
||||
CollectiveRlkGenerator::generate_share_1(¶ms, &s, &s_rlk, CRS_RLK, 0, &mut rng);
|
||||
|
||||
(
|
||||
PsiKeys { s, s_rlk },
|
||||
MessageRound1 {
|
||||
share_pk,
|
||||
share_rlk1,
|
||||
GenKeysOutput {
|
||||
psi_keys: PsiKeys {
|
||||
s: convert(&s, ¶ms),
|
||||
s_rlk: convert(&s_rlk, ¶ms),
|
||||
},
|
||||
)
|
||||
message_round1: MessageRound1 {
|
||||
share_pk: convert(&share_pk, ¶ms),
|
||||
share_rlk1: convert(&share_rlk1, ¶ms),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/************* ROUND 1 *************/
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Round1Output {
|
||||
state_round2: StateRound2,
|
||||
message_round2: MessageRound2,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StateRound2 {
|
||||
rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1,
|
||||
rlk_agg1_trimmed: CollectiveRlkAggTrimmedShare1Proto,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
struct MessageRound2 {
|
||||
share_rlk2: CollectiveRlkShare2,
|
||||
cts: Vec<Ciphertext>,
|
||||
share_rlk2: CollectiveRlkShare2Proto,
|
||||
cts: Vec<CiphertextProto>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn round1_js(
|
||||
gen_keys_output: JsValue,
|
||||
other_message_round1: JsValue,
|
||||
bit_vector: &[u32],
|
||||
) -> JsValue {
|
||||
let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output)
|
||||
.expect("failed to deserialize gen_keys_output");
|
||||
let other_message_round1: MessageRound1 = serde_wasm_bindgen::from_value(other_message_round1)
|
||||
.expect("failed to deserialize other_message_round1");
|
||||
let output = round1(
|
||||
&gen_keys_output.psi_keys,
|
||||
gen_keys_output.message_round1,
|
||||
other_message_round1,
|
||||
bit_vector,
|
||||
);
|
||||
|
||||
serde_wasm_bindgen::to_value(&output).unwrap()
|
||||
}
|
||||
|
||||
fn round1(
|
||||
@@ -73,28 +125,35 @@ fn round1(
|
||||
message: MessageRound1,
|
||||
other_message: MessageRound1,
|
||||
bit_vector: &[u32],
|
||||
) -> (StateRound2, MessageRound2) {
|
||||
) -> Round1Output {
|
||||
let params = params();
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let self_s = convert(&psi_keys.s, ¶ms);
|
||||
let self_s_rlk = convert(&psi_keys.s_rlk, ¶ms);
|
||||
let self_share_pk = convert(&message.share_pk, ¶ms);
|
||||
let other_share_pk = convert(&other_message.share_pk, ¶ms);
|
||||
let self_share_rlk1 = convert(&message.share_rlk1, ¶ms);
|
||||
let other_share_rlk1 = convert(&other_message.share_rlk1, ¶ms);
|
||||
|
||||
// generate pk
|
||||
let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise(
|
||||
¶ms,
|
||||
&[message.share_pk, other_message.share_pk],
|
||||
&[self_share_pk, other_share_pk],
|
||||
CRS_PK,
|
||||
);
|
||||
|
||||
// generate rlk share 2
|
||||
let rlk_agg1 = CollectiveRlkGenerator::aggregate_shares_1(
|
||||
¶ms,
|
||||
&[message.share_rlk1, other_message.share_rlk1],
|
||||
&[self_share_rlk1, other_share_rlk1],
|
||||
0,
|
||||
);
|
||||
let share_rlk2 = CollectiveRlkGenerator::generate_share_2(
|
||||
¶ms,
|
||||
&psi_keys.s,
|
||||
&self_s,
|
||||
&rlk_agg1,
|
||||
&psi_keys.s_rlk,
|
||||
&self_s_rlk,
|
||||
0,
|
||||
&mut rng,
|
||||
);
|
||||
@@ -108,24 +167,61 @@ fn round1(
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
(
|
||||
StateRound2 {
|
||||
rlk_agg1_trimmed: rlk_agg1.trim(),
|
||||
Round1Output {
|
||||
state_round2: StateRound2 {
|
||||
rlk_agg1_trimmed: convert(&rlk_agg1.trim(), ¶ms),
|
||||
},
|
||||
MessageRound2 {
|
||||
share_rlk2,
|
||||
cts: ciphertexts,
|
||||
message_round2: MessageRound2 {
|
||||
share_rlk2: convert(&share_rlk2, ¶ms),
|
||||
cts: ciphertexts
|
||||
.iter()
|
||||
.map(|v| convert(v, ¶ms))
|
||||
.collect_vec(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
/************* ROUND 2 *************/
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Round2Output {
|
||||
state_round3: StateRound3,
|
||||
message_round3: MessageRound3,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
struct MessageRound3 {
|
||||
decryption_shares: Vec<CollectiveDecryptionShare>,
|
||||
decryption_shares: Vec<CollectiveDecryptionShareProto>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StateRound3 {
|
||||
cts_res: Vec<Ciphertext>,
|
||||
cts_res: Vec<CiphertextProto>,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn round2_js(
|
||||
gen_keys_output: JsValue,
|
||||
round1_output: JsValue,
|
||||
other_message_round2: JsValue,
|
||||
is_a: bool,
|
||||
) -> JsValue {
|
||||
let gen_keys_output: GenKeysOutput = serde_wasm_bindgen::from_value(gen_keys_output)
|
||||
.expect("failed to deserialize gen_keys_output");
|
||||
let round1_output: Round1Output =
|
||||
serde_wasm_bindgen::from_value(round1_output).expect("failed to deserialize round1_output");
|
||||
let other_message_round2: MessageRound2 = serde_wasm_bindgen::from_value(other_message_round2)
|
||||
.expect("failed to deserialize other_message_round2");
|
||||
|
||||
let output = round2(
|
||||
&gen_keys_output.psi_keys,
|
||||
round1_output.state_round2,
|
||||
round1_output.message_round2,
|
||||
other_message_round2,
|
||||
is_a,
|
||||
);
|
||||
|
||||
serde_wasm_bindgen::to_value(&output).unwrap()
|
||||
}
|
||||
|
||||
fn round2(
|
||||
@@ -134,22 +230,39 @@ fn round2(
|
||||
message: MessageRound2,
|
||||
other_message: MessageRound2,
|
||||
is_a: bool,
|
||||
) -> (StateRound3, MessageRound3) {
|
||||
) -> Round2Output {
|
||||
let params = params();
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let self_s = convert(&psi_keys.s, ¶ms);
|
||||
let self_rlk1_agg = convert(&state_round2.rlk_agg1_trimmed, ¶ms);
|
||||
|
||||
let self_share_rlk2 = convert(&message.share_rlk2, ¶ms);
|
||||
let other_share_rlk2 = convert(&other_message.share_rlk2, ¶ms);
|
||||
|
||||
let self_cts = message
|
||||
.cts
|
||||
.iter()
|
||||
.map(|v| convert(v, ¶ms))
|
||||
.collect_vec();
|
||||
let other_cts = other_message
|
||||
.cts
|
||||
.iter()
|
||||
.map(|v| convert(v, ¶ms))
|
||||
.collect_vec();
|
||||
|
||||
// Create RLK
|
||||
let rlk = CollectiveRlkGenerator::aggregate_shares_2(
|
||||
¶ms,
|
||||
&[message.share_rlk2, other_message.share_rlk2],
|
||||
state_round2.rlk_agg1_trimmed,
|
||||
&[self_share_rlk2, other_share_rlk2],
|
||||
self_rlk1_agg,
|
||||
0,
|
||||
);
|
||||
|
||||
// perform PSI
|
||||
let evaluator = Evaluator::new(params.clone());
|
||||
let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
|
||||
let cts_res = izip!(message.cts.iter(), other_message.cts.iter())
|
||||
let cts_res = izip!(self_cts.iter(), other_cts.iter())
|
||||
.map(|(ca, cb)| {
|
||||
let ct_out = {
|
||||
if is_a {
|
||||
@@ -163,10 +276,38 @@ fn round2(
|
||||
.collect_vec();
|
||||
let decryption_shares = cts_res
|
||||
.iter()
|
||||
.map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &psi_keys.s, &mut rng))
|
||||
.map(|c| CollectiveDecryption::generate_share(evaluator.params(), c, &self_s, &mut rng))
|
||||
.collect_vec();
|
||||
|
||||
(StateRound3 { cts_res }, MessageRound3 { decryption_shares })
|
||||
Round2Output {
|
||||
state_round3: StateRound3 {
|
||||
cts_res: cts_res.iter().map(|v| convert(v, ¶ms)).collect_vec(),
|
||||
},
|
||||
message_round3: MessageRound3 {
|
||||
decryption_shares: decryption_shares
|
||||
.iter()
|
||||
.map(|v| {
|
||||
CollectiveDecryptionShareProto::try_from_with_levelled_parameters(v, ¶ms, 0)
|
||||
})
|
||||
.collect_vec(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/************* ROUND 3 *************/
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn round3_js(round2_output: JsValue, other_message: JsValue) -> Vec<u32> {
|
||||
let round2_output: Round2Output =
|
||||
serde_wasm_bindgen::from_value(round2_output).expect("failed to deserialize round2_output");
|
||||
let other_message: MessageRound3 =
|
||||
serde_wasm_bindgen::from_value(other_message).expect("failed to deserialize other_message");
|
||||
|
||||
round3(
|
||||
round2_output.state_round3,
|
||||
round2_output.message_round3,
|
||||
other_message,
|
||||
)
|
||||
}
|
||||
|
||||
fn round3(
|
||||
@@ -175,13 +316,32 @@ fn round3(
|
||||
other_message: MessageRound3,
|
||||
) -> Vec<u32> {
|
||||
let params = params();
|
||||
|
||||
let self_cts = state_round3
|
||||
.cts_res
|
||||
.iter()
|
||||
.map(|v| convert(v, ¶ms))
|
||||
.collect_vec();
|
||||
|
||||
izip!(
|
||||
state_round3.cts_res.iter(),
|
||||
self_cts.iter(),
|
||||
message.decryption_shares.into_iter(),
|
||||
other_message.decryption_shares.into_iter()
|
||||
)
|
||||
.flat_map(|(c, share_a, share_b)| {
|
||||
let pt = CollectiveDecryption::aggregate_share_and_decrypt(¶ms, c, &[share_a, share_b]);
|
||||
.flat_map(|(c, share_a_proto, share_b_proto)| {
|
||||
let shares_vec = vec![
|
||||
CollectiveDecryptionShare::try_from_with_levelled_parameters(
|
||||
&share_a_proto,
|
||||
¶ms,
|
||||
0,
|
||||
),
|
||||
CollectiveDecryptionShare::try_from_with_levelled_parameters(
|
||||
&share_b_proto,
|
||||
¶ms,
|
||||
0,
|
||||
),
|
||||
];
|
||||
let pt = CollectiveDecryption::aggregate_share_and_decrypt(¶ms, c, &shares_vec);
|
||||
Vec::<u32>::try_decoding_with_parameters(&pt, ¶ms, Encoding::default())
|
||||
})
|
||||
.collect_vec()
|
||||
@@ -225,52 +385,56 @@ mod tests {
|
||||
let vector_size = RING_SIZE * 3;
|
||||
|
||||
// gen keys
|
||||
let (a_psi_keys, a_message_round1) = gen_keys();
|
||||
let (b_psi_keys, b_message_round1) = gen_keys();
|
||||
let gen_keys_output_a = gen_keys();
|
||||
let gen_keys_output_b = gen_keys();
|
||||
|
||||
// round1
|
||||
let a_bit_vector = random_bit_vector(hamming_weight, vector_size);
|
||||
let b_bit_vector = random_bit_vector(hamming_weight, vector_size);
|
||||
let (a_state_round2, a_message_round2) = round1(
|
||||
&a_psi_keys,
|
||||
a_message_round1.clone(),
|
||||
b_message_round1.clone(),
|
||||
&a_bit_vector,
|
||||
let bit_vector_a = random_bit_vector(hamming_weight, vector_size);
|
||||
let bit_vector_b = random_bit_vector(hamming_weight, vector_size);
|
||||
let round1_output_a = round1(
|
||||
&gen_keys_output_a.psi_keys,
|
||||
gen_keys_output_a.message_round1.clone(),
|
||||
gen_keys_output_b.message_round1.clone(),
|
||||
&bit_vector_a,
|
||||
);
|
||||
let (b_state_round2, b_message_round2) = round1(
|
||||
&b_psi_keys,
|
||||
b_message_round1,
|
||||
a_message_round1,
|
||||
&b_bit_vector,
|
||||
let round1_output_b = round1(
|
||||
&gen_keys_output_b.psi_keys,
|
||||
gen_keys_output_b.message_round1,
|
||||
gen_keys_output_a.message_round1,
|
||||
&bit_vector_b,
|
||||
);
|
||||
|
||||
// round2
|
||||
let (a_state_round3, a_message_round3) = round2(
|
||||
&a_psi_keys,
|
||||
a_state_round2,
|
||||
a_message_round2.clone(),
|
||||
b_message_round2.clone(),
|
||||
let round2_output_a = round2(
|
||||
&gen_keys_output_a.psi_keys,
|
||||
round1_output_a.state_round2,
|
||||
round1_output_a.message_round2.clone(),
|
||||
round1_output_b.message_round2.clone(),
|
||||
true,
|
||||
);
|
||||
let (b_state_round3, b_message_round3) = round2(
|
||||
&b_psi_keys,
|
||||
b_state_round2,
|
||||
b_message_round2,
|
||||
a_message_round2,
|
||||
let round2_output_b = round2(
|
||||
&gen_keys_output_b.psi_keys,
|
||||
round1_output_b.state_round2,
|
||||
round1_output_b.message_round2,
|
||||
round1_output_a.message_round2,
|
||||
false,
|
||||
);
|
||||
|
||||
// round3
|
||||
let a_psi_output = round3(
|
||||
a_state_round3,
|
||||
a_message_round3.clone(),
|
||||
b_message_round3.clone(),
|
||||
let psi_output_a = round3(
|
||||
round2_output_a.state_round3,
|
||||
round2_output_a.message_round3.clone(),
|
||||
round2_output_b.message_round3.clone(),
|
||||
);
|
||||
let psi_output_b = round3(
|
||||
round2_output_b.state_round3,
|
||||
round2_output_b.message_round3,
|
||||
round2_output_a.message_round3,
|
||||
);
|
||||
let b_psi_output = round3(b_state_round3, b_message_round3, a_message_round3);
|
||||
|
||||
let expected_psi_output = plain_psi(&a_bit_vector, &b_bit_vector);
|
||||
let expected_psi_output = plain_psi(&bit_vector_a, &bit_vector_b);
|
||||
|
||||
assert_eq!(expected_psi_output, a_psi_output[..vector_size]);
|
||||
assert_eq!(a_psi_output, b_psi_output);
|
||||
assert_eq!(expected_psi_output, psi_output_a[..vector_size]);
|
||||
assert_eq!(psi_output_a, psi_output_b);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user