feat: state0 and state1 bindgen

This commit is contained in:
Enrico Bottazzi
2024-01-18 15:50:37 +01:00
parent a7f6147420
commit e5e4ece358
8 changed files with 650 additions and 295 deletions

12
Cargo.lock generated
View File

@@ -320,6 +320,7 @@ dependencies = [
"itertools 0.10.5",
"rand",
"serde",
"serde-wasm-bindgen",
"traits",
"wasm-bindgen",
]
@@ -678,6 +679,17 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
dependencies = [
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
name = "serde_derive"
version = "1.0.195"

View File

@@ -17,7 +17,8 @@ traits = {git = "https://github.com/Janmajayamall/bfv.git", branch = "wasm"}
rand = "0.8.5"
itertools = "0.10.5"
wasm-bindgen = "0.2"
serde = "1.0.195"
serde = { version = "1.0", features = ["derive"] }
serde-wasm-bindgen = "0.4"
# The `console_error_panic_hook` crate provides better debugging of panics by

View File

@@ -6,16 +6,30 @@
</head>
<body>
<script type="module">
import init, { state0_serialized } from "./pkg/mp_psi.js";
import init, { state0_bindgen, state1_bindgen } from "./pkg/mp_psi.js";
function randomBitVector(hammingWeight, size) {
let bitVector = new Array(size).fill(0);
for (let i = 0; i < hammingWeight; i++) {
let sampleIndex;
do {
sampleIndex = Math.floor(Math.random() * size);
} while (bitVector[sampleIndex] === 1);
bitVector[sampleIndex] = 1;
}
return bitVector;
}
init().then(() => {
const private_output_a_state_0 = state0_serialized();
// slice the array into 2 arrays of 513 elements each
const s_pk_a = private_output_a_state_0.slice(0, 513);
const s_rlk_a = private_output_a_state_0.slice(513, 1026);
print("s_pk_a", s_pk_a);
print("length of s_pk_a", s_pk_a.length)
print("s_rlk_a", s_rlk_a);
print("length of s_rlk_a", s_rlk_a.length)
const state0 = state0_bindgen();
// Generate a random bit vector
const hammingWeight = 3;
const size = 10;
const bitVector = randomBitVector(hammingWeight, size);
const state1 = state1_bindgen(state0.message_a_to_b, bitVector);
console.log(state1)
});
</script>
</body>

17
pkg/mp_psi.d.ts vendored
View File

@@ -1,17 +1,24 @@
/* tslint:disable */
/* eslint-disable */
/**
* @returns {Uint8Array}
* @returns {any}
*/
export function state0_serialized(): Uint8Array;
export function state0_bindgen(): any;
/**
* @param {any} message_from_a
* @param {Uint32Array} bit_vector
* @returns {any}
*/
export function state1_bindgen(message_from_a: any, bit_vector: Uint32Array): any;
export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;
export interface InitOutput {
readonly memory: WebAssembly.Memory;
readonly state0_serialized: (a: number) => void;
readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
readonly __wbindgen_free: (a: number, b: number, c: number) => void;
readonly state0_bindgen: () => number;
readonly state1_bindgen: (a: number, b: number, c: number) => number;
readonly __wbindgen_malloc: (a: number, b: number) => number;
readonly __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number;
readonly __wbindgen_exn_store: (a: number) => void;
}

View File

@@ -20,24 +20,6 @@ function takeObject(idx) {
return ret;
}
const cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } );
if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); };
let cachedUint8Memory0 = null;
function getUint8Memory0() {
if (cachedUint8Memory0 === null || cachedUint8Memory0.byteLength === 0) {
cachedUint8Memory0 = new Uint8Array(wasm.memory.buffer);
}
return cachedUint8Memory0;
}
function getStringFromWasm0(ptr, len) {
ptr = ptr >>> 0;
return cachedTextDecoder.decode(getUint8Memory0().subarray(ptr, ptr + len));
}
function addHeapObject(obj) {
if (heap_next === heap.length) heap.push(heap.length + 1);
const idx = heap_next;
@@ -47,6 +29,19 @@ function addHeapObject(obj) {
return idx;
}
function isLikeNone(x) {
return x === undefined || x === null;
}
let cachedFloat64Memory0 = null;
function getFloat64Memory0() {
if (cachedFloat64Memory0 === null || cachedFloat64Memory0.byteLength === 0) {
cachedFloat64Memory0 = new Float64Array(wasm.memory.buffer);
}
return cachedFloat64Memory0;
}
let cachedInt32Memory0 = null;
function getInt32Memory0() {
@@ -56,25 +51,176 @@ function getInt32Memory0() {
return cachedInt32Memory0;
}
function getArrayU8FromWasm0(ptr, len) {
let WASM_VECTOR_LEN = 0;
let cachedUint8Memory0 = null;
function getUint8Memory0() {
if (cachedUint8Memory0 === null || cachedUint8Memory0.byteLength === 0) {
cachedUint8Memory0 = new Uint8Array(wasm.memory.buffer);
}
return cachedUint8Memory0;
}
const cachedTextEncoder = (typeof TextEncoder !== 'undefined' ? new TextEncoder('utf-8') : { encode: () => { throw Error('TextEncoder not available') } } );
const encodeString = (typeof cachedTextEncoder.encodeInto === 'function'
? function (arg, view) {
return cachedTextEncoder.encodeInto(arg, view);
}
: function (arg, view) {
const buf = cachedTextEncoder.encode(arg);
view.set(buf);
return {
read: arg.length,
written: buf.length
};
});
function passStringToWasm0(arg, malloc, realloc) {
if (realloc === undefined) {
const buf = cachedTextEncoder.encode(arg);
const ptr = malloc(buf.length, 1) >>> 0;
getUint8Memory0().subarray(ptr, ptr + buf.length).set(buf);
WASM_VECTOR_LEN = buf.length;
return ptr;
}
let len = arg.length;
let ptr = malloc(len, 1) >>> 0;
const mem = getUint8Memory0();
let offset = 0;
for (; offset < len; offset++) {
const code = arg.charCodeAt(offset);
if (code > 0x7F) break;
mem[ptr + offset] = code;
}
if (offset !== len) {
if (offset !== 0) {
arg = arg.slice(offset);
}
ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0;
const view = getUint8Memory0().subarray(ptr + offset, ptr + len);
const ret = encodeString(arg, view);
offset += ret.written;
}
WASM_VECTOR_LEN = offset;
return ptr;
}
const cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } );
if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); };
function getStringFromWasm0(ptr, len) {
ptr = ptr >>> 0;
return getUint8Memory0().subarray(ptr / 1, ptr / 1 + len);
return cachedTextDecoder.decode(getUint8Memory0().subarray(ptr, ptr + len));
}
function debugString(val) {
// primitive types
const type = typeof val;
if (type == 'number' || type == 'boolean' || val == null) {
return `${val}`;
}
if (type == 'string') {
return `"${val}"`;
}
if (type == 'symbol') {
const description = val.description;
if (description == null) {
return 'Symbol';
} else {
return `Symbol(${description})`;
}
}
if (type == 'function') {
const name = val.name;
if (typeof name == 'string' && name.length > 0) {
return `Function(${name})`;
} else {
return 'Function';
}
}
// objects
if (Array.isArray(val)) {
const length = val.length;
let debug = '[';
if (length > 0) {
debug += debugString(val[0]);
}
for(let i = 1; i < length; i++) {
debug += ', ' + debugString(val[i]);
}
debug += ']';
return debug;
}
// Test for built-in
const builtInMatches = /\[object ([^\]]+)\]/.exec(toString.call(val));
let className;
if (builtInMatches.length > 1) {
className = builtInMatches[1];
} else {
// Failed to match the standard '[object ClassName]'
return toString.call(val);
}
if (className == 'Object') {
// we're a user defined class or Object
// JSON.stringify avoids problems with cycles, and is generally much
// easier than looping through ownProperties of `val`.
try {
return 'Object(' + JSON.stringify(val) + ')';
} catch (_) {
return 'Object';
}
}
// errors
if (val instanceof Error) {
return `${val.name}: ${val.message}\n${val.stack}`;
}
// TODO we could test for more things here, like `Set`s and `Map`s.
return className;
}
/**
* @returns {Uint8Array}
* @returns {any}
*/
export function state0_serialized() {
try {
const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
wasm.state0_serialized(retptr);
var r0 = getInt32Memory0()[retptr / 4 + 0];
var r1 = getInt32Memory0()[retptr / 4 + 1];
var v1 = getArrayU8FromWasm0(r0, r1).slice();
wasm.__wbindgen_free(r0, r1 * 1, 1);
return v1;
} finally {
wasm.__wbindgen_add_to_stack_pointer(16);
export function state0_bindgen() {
const ret = wasm.state0_bindgen();
return takeObject(ret);
}
let cachedUint32Memory0 = null;
function getUint32Memory0() {
if (cachedUint32Memory0 === null || cachedUint32Memory0.byteLength === 0) {
cachedUint32Memory0 = new Uint32Array(wasm.memory.buffer);
}
return cachedUint32Memory0;
}
function passArray32ToWasm0(arg, malloc) {
const ptr = malloc(arg.length * 4, 4) >>> 0;
getUint32Memory0().set(arg, ptr / 4);
WASM_VECTOR_LEN = arg.length;
return ptr;
}
/**
* @param {any} message_from_a
* @param {Uint32Array} bit_vector
* @returns {any}
*/
export function state1_bindgen(message_from_a, bit_vector) {
const ptr0 = passArray32ToWasm0(bit_vector, wasm.__wbindgen_malloc);
const len0 = WASM_VECTOR_LEN;
const ret = wasm.state1_bindgen(addHeapObject(message_from_a), ptr0, len0);
return takeObject(ret);
}
function handleError(f, args) {
@@ -119,15 +265,72 @@ async function __wbg_load(module, imports) {
function __wbg_get_imports() {
const imports = {};
imports.wbg = {};
imports.wbg.__wbg_crypto_d05b68a3572bb8ca = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
imports.wbg.__wbindgen_object_drop_ref = function(arg0) {
takeObject(arg0);
};
imports.wbg.__wbindgen_is_object = function(arg0) {
const val = getObject(arg0);
const ret = typeof(val) === 'object' && val !== null;
return ret;
};
imports.wbg.__wbindgen_is_undefined = function(arg0) {
const ret = getObject(arg0) === undefined;
return ret;
};
imports.wbg.__wbindgen_in = function(arg0, arg1) {
const ret = getObject(arg0) in getObject(arg1);
return ret;
};
imports.wbg.__wbindgen_object_clone_ref = function(arg0) {
const ret = getObject(arg0);
return addHeapObject(ret);
};
imports.wbg.__wbindgen_jsval_loose_eq = function(arg0, arg1) {
const ret = getObject(arg0) == getObject(arg1);
return ret;
};
imports.wbg.__wbindgen_boolean_get = function(arg0) {
const v = getObject(arg0);
const ret = typeof(v) === 'boolean' ? (v ? 1 : 0) : 2;
return ret;
};
imports.wbg.__wbindgen_number_get = function(arg0, arg1) {
const obj = getObject(arg1);
const ret = typeof(obj) === 'number' ? obj : undefined;
getFloat64Memory0()[arg0 / 8 + 1] = isLikeNone(ret) ? 0 : ret;
getInt32Memory0()[arg0 / 4 + 0] = !isLikeNone(ret);
};
imports.wbg.__wbindgen_string_get = function(arg0, arg1) {
const obj = getObject(arg1);
const ret = typeof(obj) === 'string' ? obj : undefined;
var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
var len1 = WASM_VECTOR_LEN;
getInt32Memory0()[arg0 / 4 + 1] = len1;
getInt32Memory0()[arg0 / 4 + 0] = ptr1;
};
imports.wbg.__wbindgen_error_new = function(arg0, arg1) {
const ret = new Error(getStringFromWasm0(arg0, arg1));
return addHeapObject(ret);
};
imports.wbg.__wbindgen_number_new = function(arg0) {
const ret = arg0;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_string_new = function(arg0, arg1) {
const ret = getStringFromWasm0(arg0, arg1);
return addHeapObject(ret);
};
imports.wbg.__wbg_getwithrefkey_15c62c2b8546208d = function(arg0, arg1) {
const ret = getObject(arg0)[getObject(arg1)];
return addHeapObject(ret);
};
imports.wbg.__wbg_set_20cbc34131e76824 = function(arg0, arg1, arg2) {
getObject(arg0)[takeObject(arg1)] = takeObject(arg2);
};
imports.wbg.__wbg_crypto_d05b68a3572bb8ca = function(arg0) {
const ret = getObject(arg0).crypto;
return addHeapObject(ret);
};
imports.wbg.__wbg_process_b02b3570280d0366 = function(arg0) {
const ret = getObject(arg0).process;
return addHeapObject(ret);
@@ -136,9 +339,6 @@ function __wbg_get_imports() {
const ret = getObject(arg0).versions;
return addHeapObject(ret);
};
imports.wbg.__wbindgen_object_drop_ref = function(arg0) {
takeObject(arg0);
};
imports.wbg.__wbg_node_43b1089f407e4ec2 = function(arg0) {
const ret = getObject(arg0).node;
return addHeapObject(ret);
@@ -151,10 +351,6 @@ function __wbg_get_imports() {
const ret = module.require;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbindgen_string_new = function(arg0, arg1) {
const ret = getStringFromWasm0(arg0, arg1);
return addHeapObject(ret);
};
imports.wbg.__wbg_msCrypto_10fc94afee92bd76 = function(arg0) {
const ret = getObject(arg0).msCrypto;
return addHeapObject(ret);
@@ -165,6 +361,18 @@ function __wbg_get_imports() {
imports.wbg.__wbg_getRandomValues_7e42b4fb8779dc6d = function() { return handleError(function (arg0, arg1) {
getObject(arg0).getRandomValues(getObject(arg1));
}, arguments) };
imports.wbg.__wbg_get_c43534c00f382c8a = function(arg0, arg1) {
const ret = getObject(arg0)[arg1 >>> 0];
return addHeapObject(ret);
};
imports.wbg.__wbg_length_d99b680fd68bf71b = function(arg0) {
const ret = getObject(arg0).length;
return ret;
};
imports.wbg.__wbg_new_34c624469fb1d4fd = function() {
const ret = new Array();
return addHeapObject(ret);
};
imports.wbg.__wbindgen_is_function = function(arg0) {
const ret = typeof(getObject(arg0)) === 'function';
return ret;
@@ -173,10 +381,38 @@ function __wbg_get_imports() {
const ret = new Function(getStringFromWasm0(arg0, arg1));
return addHeapObject(ret);
};
imports.wbg.__wbg_next_1938cf110c9491d4 = function(arg0) {
const ret = getObject(arg0).next;
return addHeapObject(ret);
};
imports.wbg.__wbg_next_267398d0e0761bf9 = function() { return handleError(function (arg0) {
const ret = getObject(arg0).next();
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_done_506b44765ba84b9c = function(arg0) {
const ret = getObject(arg0).done;
return ret;
};
imports.wbg.__wbg_value_31485d8770eb06ab = function(arg0) {
const ret = getObject(arg0).value;
return addHeapObject(ret);
};
imports.wbg.__wbg_iterator_364187e1ee96b750 = function() {
const ret = Symbol.iterator;
return addHeapObject(ret);
};
imports.wbg.__wbg_get_5027b32da70f39b1 = function() { return handleError(function (arg0, arg1) {
const ret = Reflect.get(getObject(arg0), getObject(arg1));
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_call_a79f1973a4f07d5e = function() { return handleError(function (arg0, arg1) {
const ret = getObject(arg0).call(getObject(arg1));
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_new_87d841e70661f6e9 = function() {
const ret = new Object();
return addHeapObject(ret);
};
imports.wbg.__wbg_self_086b5302bcafb962 = function() { return handleError(function () {
const ret = self.self;
return addHeapObject(ret);
@@ -193,14 +429,31 @@ function __wbg_get_imports() {
const ret = global.global;
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbindgen_is_undefined = function(arg0) {
const ret = getObject(arg0) === undefined;
imports.wbg.__wbg_set_379b27f1d5f1bf9c = function(arg0, arg1, arg2) {
getObject(arg0)[arg1 >>> 0] = takeObject(arg2);
};
imports.wbg.__wbg_isArray_fbd24d447869b527 = function(arg0) {
const ret = Array.isArray(getObject(arg0));
return ret;
};
imports.wbg.__wbg_instanceof_ArrayBuffer_f4521cec1b99ee35 = function(arg0) {
let result;
try {
result = getObject(arg0) instanceof ArrayBuffer;
} catch (_) {
result = false;
}
const ret = result;
return ret;
};
imports.wbg.__wbg_call_f6a2bc58c19c53c6 = function() { return handleError(function (arg0, arg1, arg2) {
const ret = getObject(arg0).call(getObject(arg1), getObject(arg2));
return addHeapObject(ret);
}, arguments) };
imports.wbg.__wbg_isSafeInteger_d8c89788832a17bf = function(arg0) {
const ret = Number.isSafeInteger(getObject(arg0));
return ret;
};
imports.wbg.__wbg_buffer_5d1b598a01b41a42 = function(arg0) {
const ret = getObject(arg0).buffer;
return addHeapObject(ret);
@@ -216,6 +469,20 @@ function __wbg_get_imports() {
imports.wbg.__wbg_set_74906aa30864df5a = function(arg0, arg1, arg2) {
getObject(arg0).set(getObject(arg1), arg2 >>> 0);
};
imports.wbg.__wbg_length_f0764416ba5bb237 = function(arg0) {
const ret = getObject(arg0).length;
return ret;
};
imports.wbg.__wbg_instanceof_Uint8Array_4f5cffed7df34b2f = function(arg0) {
let result;
try {
result = getObject(arg0) instanceof Uint8Array;
} catch (_) {
result = false;
}
const ret = result;
return ret;
};
imports.wbg.__wbg_newwithlength_728575f3bba9959b = function(arg0) {
const ret = new Uint8Array(arg0 >>> 0);
return addHeapObject(ret);
@@ -224,9 +491,12 @@ function __wbg_get_imports() {
const ret = getObject(arg0).subarray(arg1 >>> 0, arg2 >>> 0);
return addHeapObject(ret);
};
imports.wbg.__wbindgen_object_clone_ref = function(arg0) {
const ret = getObject(arg0);
return addHeapObject(ret);
imports.wbg.__wbindgen_debug_string = function(arg0, arg1) {
const ret = debugString(getObject(arg1));
const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len1 = WASM_VECTOR_LEN;
getInt32Memory0()[arg0 / 4 + 1] = len1;
getInt32Memory0()[arg0 / 4 + 0] = ptr1;
};
imports.wbg.__wbindgen_throw = function(arg0, arg1) {
throw new Error(getStringFromWasm0(arg0, arg1));
@@ -246,7 +516,9 @@ function __wbg_init_memory(imports, maybe_memory) {
function __wbg_finalize_init(instance, module) {
wasm = instance.exports;
__wbg_init.__wbindgen_wasm_module = module;
cachedFloat64Memory0 = null;
cachedInt32Memory0 = null;
cachedUint32Memory0 = null;
cachedUint8Memory0 = null;

Binary file not shown.

View File

@@ -1,7 +1,8 @@
/* tslint:disable */
/* eslint-disable */
export const memory: WebAssembly.Memory;
export function state0_serialized(a: number): void;
export function __wbindgen_add_to_stack_pointer(a: number): number;
export function __wbindgen_free(a: number, b: number, c: number): void;
export function state0_bindgen(): number;
export function state1_bindgen(a: number, b: number, c: number): number;
export function __wbindgen_malloc(a: number, b: number): number;
export function __wbindgen_realloc(a: number, b: number, c: number, d: number): number;
export function __wbindgen_exn_store(a: number): void;

View File

@@ -1,13 +1,15 @@
use bfv::{
BfvParameters, Ciphertext, CiphertextProto, CollectiveDecryption, CollectiveDecryptionShare,
CollectivePublicKeyGenerator, CollectivePublicKeyShare, CollectiveRlkAggTrimmedShare1,
CollectiveRlkGenerator, CollectiveRlkGeneratorState, CollectiveRlkShare1, CollectiveRlkShare2,
Encoding, EvaluationKey, Evaluator, Plaintext, Poly, SecretKey, SecretKeyProto,
CollectivePublicKeyGenerator, CollectivePublicKeyShare, CollectivePublicKeyShareProto,
CollectiveRlkAggTrimmedShare1, CollectiveRlkAggTrimmedShare1Proto, CollectiveRlkGenerator,
CollectiveRlkGeneratorState, CollectiveRlkShare1, CollectiveRlkShare1Proto,
CollectiveRlkShare2, CollectiveRlkShare2Proto, Encoding, EvaluationKey, Evaluator, Plaintext,
Poly, SecretKey, SecretKeyProto,
};
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use traits::{TryDecodingWithParameters, TryEncodingWithParameters, TryFromWithParameters};
use wasm_bindgen::prelude::wasm_bindgen;
use traits::{TryEncodingWithParameters, TryFromWithParameters};
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
static CRS_PK: [u8; 32] = [0u8; 32];
static CRS_RLK: [u8; 32] = [0u8; 32];
@@ -19,42 +21,56 @@ fn params() -> BfvParameters {
params
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct PrivateOutputAPostState0 {
s_pk_a: SecretKey,
s_rlk_a: SecretKey,
s_pk_a: SecretKeyProto,
s_rlk_a: SecretKeyProto,
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct PublicOutputAPostState0 {
share_pk_a: CollectivePublicKeyShare,
share_rlk_a_round1: CollectiveRlkShare1,
share_pk_a: CollectivePublicKeyShareProto,
share_rlk_a_round1: CollectiveRlkShare1Proto,
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct MessageAToBPostState0 {
share_pk_a: CollectivePublicKeyShare,
share_rlk_a_round1: CollectiveRlkShare1,
share_pk_a: CollectivePublicKeyShareProto,
share_rlk_a_round1: CollectiveRlkShare1Proto,
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct OutputState0 {
private_output_a: PrivateOutputAPostState0,
public_output_a: PublicOutputAPostState0,
message_a_to_b: MessageAToBPostState0,
}
#[derive(Serialize, Deserialize)]
struct PrivateOutputBPostState1 {
s_pk_b: SecretKey,
s_pk_b: SecretKeyProto,
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct PublicOutputBPostState1 {
ciphertext_b: Ciphertext,
share_rlk_b_round2: CollectiveRlkShare2,
rlk_agg_round1_h1s: CollectiveRlkAggTrimmedShare1,
ciphertext_b: CiphertextProto,
share_rlk_b_round2: CollectiveRlkShare2Proto,
rlk_agg_round1_h1s: CollectiveRlkAggTrimmedShare1Proto,
}
// #[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct MessageBToAPostState1 {
share_pk_b: CollectivePublicKeyShare,
share_rlk_b_round1: CollectiveRlkShare1,
share_rlk_b_round2: CollectiveRlkShare2,
ciphertext_b: Ciphertext,
share_pk_b: CollectivePublicKeyShareProto,
share_rlk_b_round1: CollectiveRlkShare1Proto,
share_rlk_b_round2: CollectiveRlkShare2Proto,
ciphertext_b: CiphertextProto,
}
#[derive(Serialize, Deserialize)]
struct OutputState1 {
private_output_b: PrivateOutputBPostState1,
public_output_b: PublicOutputBPostState1,
message_b_to_a: MessageBToAPostState1,
}
// #[derive(Serialize, Deserialize)]
@@ -75,23 +91,18 @@ struct MessageBToAPostState3 {
decryption_share_b: CollectiveDecryptionShare,
}
// #[wasm_bindgen]
// pub fn state0_serialized() -> (Vec<u8>, Vec<u8>) {
// let (private_output_a_state_0, _, _) = state0();
#[wasm_bindgen]
pub fn state0_bindgen() -> JsValue {
let (private_output_a, public_output_a, message_a_to_b) = state0();
// let s_pk_a_serialized =
// SecretKeyProto::try_from_with_parameters(&private_output_a_state_0.s_pk_a, &params())
// .coefficients;
let output = OutputState0 {
private_output_a,
public_output_a,
message_a_to_b,
};
// let s_rlk_a_serialized =
// SecretKeyProto::try_from_with_parameters(&private_output_a_state_0.s_rlk_a, &params())
// .coefficients;
// let mut output_serialized = s_pk_a_serialized;
// output_serialized.extend(s_rlk_a_serialized);
// output_serialized
// }
serde_wasm_bindgen::to_value(&output).unwrap()
}
fn state0() -> (
PrivateOutputAPostState0,
@@ -111,20 +122,42 @@ fn state0() -> (
CollectiveRlkGenerator::generate_share_1(&params, &s_pk_a, &s_rlk_a, CRS_RLK, 0, &mut rng);
let message_a_to_b = MessageAToBPostState0 {
share_pk_a: share_pk_a.clone(),
share_rlk_a_round1: share_rlk_a_round1.clone(),
share_pk_a: CollectivePublicKeyShareProto::try_from_with_parameters(&share_pk_a, &params),
share_rlk_a_round1: CollectiveRlkShare1Proto::try_from_with_parameters(
&share_rlk_a_round1,
&params,
),
};
let private_state_a = PrivateOutputAPostState0 {
s_pk_a,
s_rlk_a: s_rlk_a,
let private_output_a = PrivateOutputAPostState0 {
s_pk_a: SecretKeyProto::try_from_with_parameters(&s_pk_a, &params),
s_rlk_a: SecretKeyProto::try_from_with_parameters(&s_rlk_a, &params),
};
let public_output_a = PublicOutputAPostState0 {
share_pk_a,
share_rlk_a_round1,
share_pk_a: CollectivePublicKeyShareProto::try_from_with_parameters(&share_pk_a, &params),
share_rlk_a_round1: CollectiveRlkShare1Proto::try_from_with_parameters(
&share_rlk_a_round1,
&params,
),
};
(private_state_a, public_output_a, message_a_to_b)
(private_output_a, public_output_a, message_a_to_b)
}
#[wasm_bindgen]
pub fn state1_bindgen(message_from_a: JsValue, bit_vector: &[u32]) -> JsValue {
let message_from_a: MessageAToBPostState0 = serde_wasm_bindgen::from_value(message_from_a)
.expect("failed to deserialize message_from_a");
let (private_output_b, public_output_b, message_to_a) = state1(message_from_a, bit_vector);
let output = OutputState1 {
private_output_b,
public_output_b,
message_b_to_a: message_to_a,
};
serde_wasm_bindgen::to_value(&output).unwrap()
}
fn state1(
@@ -146,11 +179,14 @@ fn state1(
let share_rlk_b_round1 =
CollectiveRlkGenerator::generate_share_1(&params, &s_pk_b, &s_rlk_b, CRS_RLK, 0, &mut rng);
let share_rlk_a_round1 =
CollectiveRlkShare1::try_from_with_parameters(&message_from_a.share_rlk_a_round1, &params);
let share_pk_a =
CollectivePublicKeyShare::try_from_with_parameters(&message_from_a.share_pk_a, &params);
// rlk key part 1
let rlk_shares_round1 = vec![
message_from_a.share_rlk_a_round1,
share_rlk_b_round1.clone(),
];
let rlk_shares_round1 = vec![share_rlk_a_round1, share_rlk_b_round1.clone()];
let rlk_agg_1 = CollectiveRlkGenerator::aggregate_shares_1(&params, &rlk_shares_round1, 0);
// B already has access to aggregate shares for rlk round 1 and can proceed with the second round of the protocol
@@ -159,7 +195,7 @@ fn state1(
);
// generate collective public key and encryt b's input
let collective_pk_shares = vec![share_pk_b.clone(), message_from_a.share_pk_a];
let collective_pk_shares = vec![share_pk_b.clone(), share_pk_a];
let collecitve_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise(
&params,
&collective_pk_shares,
@@ -169,188 +205,200 @@ fn state1(
let ciphertext_b = collecitve_pk.encrypt(&params, &pt, &mut rng);
let message_to_a = MessageBToAPostState1 {
share_pk_b,
share_rlk_b_round1,
share_rlk_b_round2: share_rlk_b_round2.clone(),
ciphertext_b: ciphertext_b.clone(),
share_pk_b: CollectivePublicKeyShareProto::try_from_with_parameters(&share_pk_b, &params),
share_rlk_b_round1: CollectiveRlkShare1Proto::try_from_with_parameters(
&share_rlk_b_round1,
&params,
),
share_rlk_b_round2: CollectiveRlkShare2Proto::try_from_with_parameters(
&share_rlk_b_round2,
&params,
),
ciphertext_b: CiphertextProto::try_from_with_parameters(&ciphertext_b, &params),
};
let private_output_b = PrivateOutputBPostState1 { s_pk_b };
let private_output_b = PrivateOutputBPostState1 {
s_pk_b: SecretKeyProto::try_from_with_parameters(&s_pk_b, &params),
};
let rlk_aggregated_shares1_trimmed = rlk_agg_1.trim();
let public_output_b = PublicOutputBPostState1 {
ciphertext_b,
share_rlk_b_round2,
rlk_agg_round1_h1s: rlk_aggregated_shares1_trimmed,
ciphertext_b: CiphertextProto::try_from_with_parameters(&ciphertext_b, &params),
share_rlk_b_round2: CollectiveRlkShare2Proto::try_from_with_parameters(
&share_rlk_b_round2,
&params,
),
rlk_agg_round1_h1s: CollectiveRlkAggTrimmedShare1Proto::try_from_with_parameters(
&rlk_aggregated_shares1_trimmed,
&params,
),
};
(private_output_b, public_output_b, message_to_a)
}
fn state2(
private_output_a_state0: PrivateOutputAPostState0,
public_output_a_state0: PublicOutputAPostState0,
message_from_b: MessageBToAPostState1,
bit_vector: &[u32],
) -> (PublicOutputAPostState2, MessageAToBPostState2) {
let params = params();
let mut rng = thread_rng();
// fn state2(
// private_output_a_state0: PrivateOutputAPostState0,
// public_output_a_state0: PublicOutputAPostState0,
// message_from_b: MessageBToAPostState1,
// bit_vector: &[u32],
// ) -> (PublicOutputAPostState2, MessageAToBPostState2) {
// let params = params();
// let mut rng = thread_rng();
// aggrgegate shares of rlk round 1
let rlk_shares_round1 = vec![
public_output_a_state0.share_rlk_a_round1,
message_from_b.share_rlk_b_round1,
];
let rlk_agg_1 = CollectiveRlkGenerator::aggregate_shares_1(&params, &rlk_shares_round1, 0);
// // aggrgegate shares of rlk round 1
// let rlk_shares_round1 = vec![
// public_output_a_state0.share_rlk_a_round1,
// message_from_b.share_rlk_b_round1,
// ];
// let rlk_agg_1 = CollectiveRlkGenerator::aggregate_shares_1(&params, &rlk_shares_round1, 0);
// generate share 2 for rlk round 2
let share_rlk_a_round2 = CollectiveRlkGenerator::generate_share_2(
&params,
&private_output_a_state0.s_pk_a,
&rlk_agg_1,
&private_output_a_state0.s_rlk_a,
0,
&mut rng,
);
// // generate share 2 for rlk round 2
// let share_rlk_a_round2 = CollectiveRlkGenerator::generate_share_2(
// &params,
// &private_output_a_state0.s_pk_a,
// &rlk_agg_1,
// &private_output_a_state0.s_rlk_a,
// 0,
// &mut rng,
// );
let rlk_agg_1_trimmed = rlk_agg_1.trim();
// aggregate rlk round 2 shares and generate rlk
let rlk_shares_round2 = vec![
share_rlk_a_round2.clone(),
message_from_b.share_rlk_b_round2,
];
let rlk = CollectiveRlkGenerator::aggregate_shares_2(
&params,
&rlk_shares_round2,
rlk_agg_1_trimmed,
0,
);
// let rlk_agg_1_trimmed = rlk_agg_1.trim();
// // aggregate rlk round 2 shares and generate rlk
// let rlk_shares_round2 = vec![
// share_rlk_a_round2.clone(),
// message_from_b.share_rlk_b_round2,
// ];
// let rlk = CollectiveRlkGenerator::aggregate_shares_2(
// &params,
// &rlk_shares_round2,
// rlk_agg_1_trimmed,
// 0,
// );
// create public key and encrypt A's bit vector'
let collective_pk_shares = vec![public_output_a_state0.share_pk_a, message_from_b.share_pk_b];
let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise(
&params,
&collective_pk_shares,
CRS_PK,
);
let pt = Plaintext::try_encoding_with_parameters(bit_vector, &params, Encoding::default());
let ciphertext_a = collective_pk.encrypt(&params, &pt, &mut rng);
// // create public key and encrypt A's bit vector'
// let collective_pk_shares = vec![public_output_a_state0.share_pk_a, message_from_b.share_pk_b];
// let collective_pk = CollectivePublicKeyGenerator::aggregate_shares_and_finalise(
// &params,
// &collective_pk_shares,
// CRS_PK,
// );
// let pt = Plaintext::try_encoding_with_parameters(bit_vector, &params, Encoding::default());
// let ciphertext_a = collective_pk.encrypt(&params, &pt, &mut rng);
// perform PSI
let evaluator = Evaluator::new(params);
let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
let ciphertext_res = evaluator.mul(&ciphertext_a, &message_from_b.ciphertext_b);
let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key);
// // perform PSI
// let evaluator = Evaluator::new(params);
// let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
// let ciphertext_res = evaluator.mul(&ciphertext_a, &message_from_b.ciphertext_b);
// let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key);
// generate decryption share of ciphertext_res
let decryption_share_a = CollectiveDecryption::generate_share(
evaluator.params(),
&ciphertext_res,
&private_output_a_state0.s_pk_a,
&mut rng,
);
// // generate decryption share of ciphertext_res
// let decryption_share_a = CollectiveDecryption::generate_share(
// evaluator.params(),
// &ciphertext_res,
// &private_output_a_state0.s_pk_a,
// &mut rng,
// );
let public_output_a = PublicOutputAPostState2 {
decryption_share_a: decryption_share_a.clone(),
ciphertext_res,
};
// let public_output_a = PublicOutputAPostState2 {
// decryption_share_a: decryption_share_a.clone(),
// ciphertext_res,
// };
let message_a_to_b = MessageAToBPostState2 {
decryption_share_a,
ciphertext_a,
share_rlk_a_round2,
};
// let message_a_to_b = MessageAToBPostState2 {
// decryption_share_a,
// ciphertext_a,
// share_rlk_a_round2,
// };
(public_output_a, message_a_to_b)
}
// (public_output_a, message_a_to_b)
// }
fn state3(
private_output_b_state1: PrivateOutputBPostState1,
public_output_b_state1: PublicOutputBPostState1,
message_from_a: MessageAToBPostState2,
) -> (MessageBToAPostState3, Vec<u32>) {
let params = params();
let mut rng = thread_rng();
// fn state3(
// private_output_b_state1: PrivateOutputBPostState1,
// public_output_b_state1: PublicOutputBPostState1,
// message_from_a: MessageAToBPostState2,
// ) -> (MessageBToAPostState3, Vec<u32>) {
// let params = params();
// let mut rng = thread_rng();
// create rlk
let rlk_shares_round2 = vec![
message_from_a.share_rlk_a_round2,
public_output_b_state1.share_rlk_b_round2,
];
let rlk = CollectiveRlkGenerator::aggregate_shares_2(
&params,
&rlk_shares_round2,
public_output_b_state1.rlk_agg_round1_h1s,
0,
);
// // create rlk
// let rlk_shares_round2 = vec![
// message_from_a.share_rlk_a_round2,
// public_output_b_state1.share_rlk_b_round2,
// ];
// let rlk = CollectiveRlkGenerator::aggregate_shares_2(
// &params,
// &rlk_shares_round2,
// public_output_b_state1.rlk_agg_round1_h1s,
// 0,
// );
// perform PSI
let evaluator = Evaluator::new(params);
let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
let ciphertext_res = evaluator.mul(
&message_from_a.ciphertext_a,
&public_output_b_state1.ciphertext_b,
);
let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key);
// // perform PSI
// let evaluator = Evaluator::new(params);
// let evaluation_key = EvaluationKey::new_raw(&[0], vec![rlk], &[], &[], vec![]);
// let ciphertext_res = evaluator.mul(
// &message_from_a.ciphertext_a,
// &public_output_b_state1.ciphertext_b,
// );
// let ciphertext_res = evaluator.relinearize(&ciphertext_res, &evaluation_key);
// generate B's decryption share
let decryption_share_b = CollectiveDecryption::generate_share(
evaluator.params(),
&ciphertext_res,
&private_output_b_state1.s_pk_b,
&mut rng,
);
// // generate B's decryption share
// let decryption_share_b = CollectiveDecryption::generate_share(
// evaluator.params(),
// &ciphertext_res,
// &private_output_b_state1.s_pk_b,
// &mut rng,
// );
// decrypt ciphertext res
let decryption_shares_vec = vec![
decryption_share_b.clone(),
message_from_a.decryption_share_a,
];
let psi_output = CollectiveDecryption::aggregate_share_and_decrypt(
evaluator.params(),
&ciphertext_res,
&decryption_shares_vec,
);
let psi_output = Vec::<u32>::try_decoding_with_parameters(
&psi_output,
evaluator.params(),
Encoding::default(),
);
// // decrypt ciphertext res
// let decryption_shares_vec = vec![
// decryption_share_b.clone(),
// message_from_a.decryption_share_a,
// ];
// let psi_output = CollectiveDecryption::aggregate_share_and_decrypt(
// evaluator.params(),
// &ciphertext_res,
// &decryption_shares_vec,
// );
// let psi_output = Vec::<u32>::try_decoding_with_parameters(
// &psi_output,
// evaluator.params(),
// Encoding::default(),
// );
let message_b_to_a = MessageBToAPostState3 { decryption_share_b };
// let message_b_to_a = MessageBToAPostState3 { decryption_share_b };
(message_b_to_a, psi_output)
}
// (message_b_to_a, psi_output)
// }
fn state4(
public_output_a_state2: PublicOutputAPostState2,
message_from_b: MessageBToAPostState3,
) -> Vec<u32> {
let params = params();
// fn state4(
// public_output_a_state2: PublicOutputAPostState2,
// message_from_b: MessageBToAPostState3,
// ) -> Vec<u32> {
// let params = params();
// decrypt ciphertext res
let decryption_shares_vec = vec![
public_output_a_state2.decryption_share_a,
message_from_b.decryption_share_b,
];
let psi_output = CollectiveDecryption::aggregate_share_and_decrypt(
&params,
&public_output_a_state2.ciphertext_res,
&decryption_shares_vec,
);
let psi_output =
Vec::<u32>::try_decoding_with_parameters(&psi_output, &params, Encoding::default());
// // decrypt ciphertext res
// let decryption_shares_vec = vec![
// public_output_a_state2.decryption_share_a,
// message_from_b.decryption_share_b,
// ];
// let psi_output = CollectiveDecryption::aggregate_share_and_decrypt(
// &params,
// &public_output_a_state2.ciphertext_res,
// &decryption_shares_vec,
// );
// let psi_output =
// Vec::<u32>::try_decoding_with_parameters(&psi_output, &params, Encoding::default());
psi_output
}
// psi_output
// }
#[cfg(test)]
mod tests {
use super::*;
use bfv::SecretKeyProto;
use itertools::{izip, Itertools};
use rand::{distributions::Uniform, Rng};
use traits::TryFromWithParameters;
fn random_bit_vector(hamming_weight: usize, size: usize) -> Vec<u32> {
let mut rng = thread_rng();
@@ -370,41 +418,41 @@ mod tests {
.collect_vec()
}
#[test]
fn psi_works() {
let hamming_weight = 10;
let vector_size = 10;
// #[test]
// fn psi_works() {
// let hamming_weight = 10;
// let vector_size = 10;
// A: state 0
let (private_output_a_state0, public_output_a_state0, message_a_to_b_state0) = state0();
// // A: state 0
// let (private_output_a_state0, public_output_a_state0, message_a_to_b_state0) = state0();
// B: state 1
let bit_vector_b = random_bit_vector(hamming_weight, vector_size);
let (private_output_b_state1, public_output_b_state1, message_b_to_a_state1) =
state1(message_a_to_b_state0, &bit_vector_b);
// // B: state 1
// let bit_vector_b = random_bit_vector(hamming_weight, vector_size);
// let (private_output_b_state1, public_output_b_state1, message_b_to_a_state1) =
// state1(message_a_to_b_state0, &bit_vector_b);
// A: state 2
let bit_vector_a = random_bit_vector(hamming_weight, vector_size);
let (public_output_a_state2, message_a_to_b_state2) = state2(
private_output_a_state0,
public_output_a_state0,
message_b_to_a_state1,
&bit_vector_a,
);
// // A: state 2
// let bit_vector_a = random_bit_vector(hamming_weight, vector_size);
// let (public_output_a_state2, message_a_to_b_state2) = state2(
// private_output_a_state0,
// public_output_a_state0,
// message_b_to_a_state1,
// &bit_vector_a,
// );
// B: state 3
let (message_b_to_a_state3, psi_output_b) = state3(
private_output_b_state1,
public_output_b_state1,
message_a_to_b_state2,
);
// // B: state 3
// let (message_b_to_a_state3, psi_output_b) = state3(
// private_output_b_state1,
// public_output_b_state1,
// message_a_to_b_state2,
// );
// A: state 4
let psi_output_a = state4(public_output_a_state2, message_b_to_a_state3);
// // A: state 4
// let psi_output_a = state4(public_output_a_state2, message_b_to_a_state3);
let expected_psi_output = plain_psi(&bit_vector_a, &bit_vector_b);
// let expected_psi_output = plain_psi(&bit_vector_a, &bit_vector_b);
assert_eq!(expected_psi_output, psi_output_a[..vector_size]);
assert_eq!(psi_output_a, psi_output_b);
}
// assert_eq!(expected_psi_output, psi_output_a[..vector_size]);
// assert_eq!(psi_output_a, psi_output_b);
// }
}