diff --git a/programs/jslib.cpp b/programs/jslib.cpp index 13d577a..d8ff81a 100644 --- a/programs/jslib.cpp +++ b/programs/jslib.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "emp-tool/io/i_raw_io.h" #include "emp-ag2pc/2pc.h" @@ -24,45 +25,121 @@ EM_JS(void, send_js, (int to_party, char channel_label, const void* data, size_t }); // Implement recv_js function to receive data from JavaScript to C++ -EM_ASYNC_JS(void, recv_js, (int from_party, char channel_label, void* data, size_t len), { +EM_ASYNC_JS(size_t, recv_js, (int from_party, char channel_label, void* data, size_t min_len, size_t max_len), { if (!Module.emp?.io?.recv) { reject(new Error("Module.emp.io.recv is not defined in JavaScript.")); return; } // Wait for data from JavaScript - const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), len); + const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), min_len, max_len); // Copy data from JavaScript Uint8Array to WebAssembly memory HEAPU8.set(dataArray, data); + + // Return the length of the received data + return dataArray.length; }); +class RawIOJS; +std::map raw_io_map; +int next_raw_io_id = 0; +size_t MAX_SEND_BUFFER_SIZE = 64 * 1024; + +void actual_flush_all(); + class RawIOJS : public IRawIO { public: int other_party; char channel_label; + std::vector send_buffer; // TODO: Max buffer size? + std::vector recv_buffer; + size_t recv_start = 0; + size_t recv_end = 0; + int id; RawIOJS( int other_party, char channel_label ): other_party(other_party), - channel_label(channel_label) - {} + channel_label(channel_label), + recv_buffer(64 * 1024) + { + id = next_raw_io_id++; + raw_io_map[id] = this; + } + + ~RawIOJS() { + raw_io_map.erase(id); + } void send(const void* data, size_t len) override { - send_js(other_party, channel_label, data, len); + if (send_buffer.size() + len > MAX_SEND_BUFFER_SIZE) { + actual_flush(); + } + + // This will still exceed max size if len > MAX_SEND_BUFFER_SIZE, that's ok + send_buffer.resize(send_buffer.size() + len); + + std::memcpy(send_buffer.data() + send_buffer.size() - len, data, len); } void recv(void* data, size_t len) override { - recv_js(other_party, channel_label, data, len); + if (recv_start + len > recv_end) { + if (recv_start + len > recv_buffer.size()) { + // copy within + size_t recv_len = recv_end - recv_start; + std::memmove(recv_buffer.data(), recv_buffer.data() + recv_start, recv_len); + recv_start = 0; + recv_end = recv_len; + } + + size_t bytes_needed = recv_start + len - recv_end; + size_t room = recv_buffer.size() - recv_end; + + if (bytes_needed > room) { + size_t size = recv_buffer.size(); + + while (bytes_needed > room) { + size *= 2; + room = size - recv_end; + } + + recv_buffer.resize(size); + } + + actual_flush_all(); + size_t bytes_received = recv_js(other_party, channel_label, recv_buffer.data() + recv_end, bytes_needed, room); + recv_end += bytes_received; + + if (bytes_received < bytes_needed) { + throw std::runtime_error("recv failed"); + } + } + + std::memcpy(data, recv_buffer.data() + recv_start, len); + recv_start += len; } void flush() override { - // Ignored for now + // ignored for now + } + + void actual_flush() { + if (send_buffer.size() > 0) { + send_js(other_party, channel_label, send_buffer.data(), send_buffer.size()); + send_buffer.clear(); + } } }; +void actual_flush_all() { + for (auto& [key, raw_io] : raw_io_map) { + raw_io->actual_flush(); + } +} + class MultiIOJS : public IMultiIO { public: int mParty; @@ -284,6 +361,9 @@ void run_2pc_impl(int party, int nP) { twopc.function_dependent(); std::vector output_bits = twopc.online(input_bits, true); + + actual_flush_all(); + handle_output_bits(output_bits); } catch (const std::exception& e) { handle_error(e.what()); @@ -339,6 +419,8 @@ void run_mpc_impl(int party, int nP) { output_bits.push_back(output.get_plaintext_bit(i)); } + actual_flush_all(); + handle_output_bits(output_bits); } catch (const std::exception& e) { handle_error(e.what()); diff --git a/src/ts/BufferQueue.ts b/src/ts/BufferQueue.ts index ecc209b..d3b4c1f 100644 --- a/src/ts/BufferQueue.ts +++ b/src/ts/BufferQueue.ts @@ -5,7 +5,7 @@ export default class BufferQueue { private buffer: Uint8Array; private bufferStart: number; private bufferEnd: number; - private pendingPops: number[]; + private pendingPops: { min_len: number, max_len: number }[]; private pendingPopsResolvers: { resolve: ((value: Uint8Array) => void), reject: (e: Error) => void, @@ -65,19 +65,21 @@ export default class BufferQueue { * @param len - The number of bytes to pop from the buffer. * @returns A promise resolving with the popped data as a Uint8Array. */ - pop(len: number): Promise { - if (typeof len !== 'number' || len < 0) { - return Promise.reject(new Error('Length must be non-negative integer')); + pop(min_len: number, max_len: number): Promise { + if (typeof min_len !== 'number' || typeof max_len !== 'number' || min_len < 0 || max_len < 0 || min_len > max_len) { + return Promise.reject(new Error('Invalid min/max lengths')); } - if (this.bufferEnd - this.bufferStart >= len) { - const result = this.buffer.slice(this.bufferStart, this.bufferStart + len); - this.bufferStart += len; + if (this.bufferEnd - this.bufferStart >= min_len) { + const available_len = this.bufferEnd - this.bufferStart; + const provide_len = Math.min(max_len, available_len); + const result = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len); + this.bufferStart += result.length; this._compactBuffer(); return Promise.resolve(result); } else if (!this.closed) { return new Promise((resolve, reject) => { - this.pendingPops.push(len); + this.pendingPops.push({ min_len, max_len }); this.pendingPopsResolvers.push({ resolve, reject }); }); } else { @@ -109,10 +111,12 @@ export default class BufferQueue { */ private _resolvePendingPops(): void { while (this.pendingPops.length > 0) { - const len = this.pendingPops[0]; - if (this.bufferEnd - this.bufferStart >= len) { - const data = this.buffer.slice(this.bufferStart, this.bufferStart + len); - this.bufferStart += len; + const { min_len, max_len } = this.pendingPops[0]; + if (this.bufferEnd - this.bufferStart >= min_len) { + const available_len = this.bufferEnd - this.bufferStart; + const provide_len = Math.min(max_len, available_len); + const data = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len); + this.bufferStart += data.length; this.pendingPops.shift(); const { resolve } = this.pendingPopsResolvers.shift()!; resolve(data); diff --git a/src/ts/BufferedIO.ts b/src/ts/BufferedIO.ts index 52041ae..d81bd2f 100644 --- a/src/ts/BufferedIO.ts +++ b/src/ts/BufferedIO.ts @@ -27,9 +27,9 @@ export default class BufferedIO this.closeOther?.(); } - async recv(fromParty: number, channel: 'a' | 'b', len: number): Promise { + async recv(fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number): Promise { assert(fromParty === this.otherParty, 'fromParty !== this.otherParty'); - return await this.bq[channel].pop(len); + return await this.bq[channel].pop(min_len, max_len); } accept(channel: 'a' | 'b', data: Uint8Array) { diff --git a/src/ts/appendWorker.ts b/src/ts/appendWorker.ts index 817726f..92de62c 100644 --- a/src/ts/appendWorker.ts +++ b/src/ts/appendWorker.ts @@ -61,15 +61,25 @@ async function secureMPC({ emp.circuit = circuit; emp.inputBits = inputBits; emp.inputBitsPerParty = inputBitsPerParty; - emp.io = io; + + let reject: undefined | ((error: unknown) => void) = undefined; + const callbackRejector = new Promise((_resolve, rej) => { + reject = rej; + }); + reject = reject!; + + emp.io = { + send: useRejector(io.send.bind(io), reject), + recv: useRejector(io.recv.bind(io), reject), + }; const method = calculateMethod(mode, size, circuit); - const result = new Promise((resolve, reject) => { + const result = new Promise(async (resolve, reject) => { try { emp.handleOutput = resolve; emp.handleError = reject; - + callbackRejector.catch(reject); module[method](party, size); } catch (error) { reject(error); @@ -97,7 +107,12 @@ function calculateMethod( case 'mpc': return '_run_mpc'; case 'auto': - return size === 2 ? '_run_2pc' : '_run_mpc'; + // Advantage of 2PC specialization is small and contains "FEQ error" bug + // for the large circuits, so the performance currently cannot be realized + // where it matters. + // Therefore, we default to the general N-party mpc mode, even when there + // are only 2 parties. + return '_run_mpc'; default: const _never: never = mode; @@ -125,11 +140,11 @@ onmessage = async (event) => { send: (toParty, channel, data) => { postMessage({ type: 'io_send', toParty, channel, data }); }, - recv: (fromParty, channel, len) => { + recv: (fromParty, channel, min_len, max_len) => { return new Promise((resolve, reject) => { const id = requestId++; pendingRequests[id] = { resolve, reject }; - postMessage({ type: 'io_recv', fromParty, channel, len, id }); + postMessage({ type: 'io_recv', fromParty, channel, min_len, max_len, id }); }); }, }; @@ -147,7 +162,7 @@ onmessage = async (event) => { postMessage({ type: 'result', result }); } catch (error) { - postMessage({ type: 'error', error: (error as Error).stack }); + postMessage({ type: 'error', error: (error as Error).message }); } } else if (message.type === 'io_recv_response') { const { id, data } = message; @@ -163,3 +178,17 @@ onmessage = async (event) => { } } }; + +function useRejector any>( + fn: F, + reject: (error: unknown) => void, +): F { + return ((...args: Parameters) => { + try { + return fn(...args); + } catch (error) { + reject(error); + throw error; + } + }) as F; +} diff --git a/src/ts/demo.ts b/src/ts/demo.ts index c1c7a6c..a63e79a 100644 --- a/src/ts/demo.ts +++ b/src/ts/demo.ts @@ -28,7 +28,7 @@ windowAny.internalDemo = async function( inputBitsPerParty: [32, 32], io: { send: (toParty, channel, data) => bobBq[channel].push(data), - recv: (fromParty, channel, len) => aliceBq[channel].pop(len), + recv: (fromParty, channel, min_len, max_len) => aliceBq[channel].pop(min_len, max_len), }, mode, }), @@ -40,7 +40,7 @@ windowAny.internalDemo = async function( inputBitsPerParty: [32, 32], io: { send: (toParty, channel, data) => aliceBq[channel].push(data), - recv: (fromParty, channel, len) => bobBq[channel].pop(len), + recv: (fromParty, channel, min_len, max_len) => bobBq[channel].pop(min_len, max_len), }, mode, }), @@ -69,7 +69,7 @@ windowAny.internalDemo3 = async function( inputBitsPerParty: [32, 32, 32], io: { send: (toParty, channel, data) => bqs.get(party, toParty, channel).push(data), - recv: (fromParty, channel, len) => bqs.get(fromParty, party, channel).pop(len), + recv: (fromParty, channel, min_len, max_len) => bqs.get(fromParty, party, channel).pop(min_len, max_len), }, mode, }))); @@ -342,9 +342,9 @@ function makeCopyPasteIO(otherParty: number): IO { return { send: makeConsoleSend(otherParty), - recv: (fromParty, channel, len) => { + recv: (fromParty, channel, min_len, max_len) => { assert(fromParty === otherParty, 'Unexpected party'); - return bq[channel].pop(len); + return bq[channel].pop(min_len, max_len); }, }; } diff --git a/src/ts/nodeSecureMPC.ts b/src/ts/nodeSecureMPC.ts index a68c009..1ec58f5 100644 --- a/src/ts/nodeSecureMPC.ts +++ b/src/ts/nodeSecureMPC.ts @@ -42,7 +42,17 @@ export default async function nodeSecureMPC({ emp.circuit = circuit; emp.inputBits = inputBits; emp.inputBitsPerParty = inputBitsPerParty; - emp.io = io; + + let reject: undefined | ((error: unknown) => void) = undefined; + const callbackRejector = new Promise((_resolve, rej) => { + reject = rej; + }); + reject = reject!; + + emp.io = { + send: useRejector(io.send.bind(io), reject), + recv: useRejector(io.recv.bind(io), reject), + }; const method = calculateMethod(mode, size, circuit); @@ -50,6 +60,7 @@ export default async function nodeSecureMPC({ try { emp.handleOutput = resolve; emp.handleError = reject; + callbackRejector.catch(reject); module[method](party, size); } catch (error) { @@ -74,10 +85,29 @@ function calculateMethod( case 'mpc': return '_run_mpc'; case 'auto': - return size === 2 ? '_run_2pc' : '_run_mpc'; + // Advantage of 2PC specialization is small and contains "FEQ error" bug + // for the large circuits, so the performance currently cannot be realized + // where it matters. + // Therefore, we default to the general N-party mpc mode, even when there + // are only 2 parties. + return '_run_mpc'; default: const _never: never = mode; throw new Error('Unexpected mode: ' + mode); } } + +function useRejector any>( + fn: F, + reject: (error: unknown) => void, +): F { + return ((...args: Parameters) => { + try { + return fn(...args); + } catch (error) { + reject(error); + throw error; + } + }) as F; +} diff --git a/src/ts/secureMPC.ts b/src/ts/secureMPC.ts index 5780365..58c9c69 100644 --- a/src/ts/secureMPC.ts +++ b/src/ts/secureMPC.ts @@ -62,10 +62,10 @@ export default function secureMPC({ const { toParty, channel, data } = message; io.send(toParty, channel, data); } else if (message.type === 'io_recv') { - const { fromParty, channel, len } = message; + const { fromParty, channel, min_len, max_len } = message; // Handle the recv request from the worker try { - const data = await io.recv(fromParty, channel, len); + const data = await io.recv(fromParty, channel, min_len, max_len); worker.postMessage({ type: 'io_recv_response', id: message.id, data }); } catch (error) { worker.postMessage({ @@ -80,6 +80,10 @@ export default function secureMPC({ } else if (message.type === 'error') { // Reject the promise if an error occurred reject(new Error(message.error)); + } else if (message.type === 'log') { + console.log('Worker log:', message.msg); + } else { + console.error('Unexpected message from worker:', message); } }; diff --git a/src/ts/types.ts b/src/ts/types.ts index f846a2f..f691337 100644 --- a/src/ts/types.ts +++ b/src/ts/types.ts @@ -1,6 +1,6 @@ export type IO = { send: (toParty: number, channel: 'a' | 'b', data: Uint8Array) => void; - recv: (fromParty: number, channel: 'a' | 'b', len: number) => Promise; + recv: (fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number) => Promise; on?: (event: 'error', listener: (error: Error) => void) => void; off?: (event: 'error', listener: (error: Error) => void) => void; close?: () => void; diff --git a/tests/secureMPC.test.ts b/tests/secureMPC.test.ts index 6c40971..c28d612 100644 --- a/tests/secureMPC.test.ts +++ b/tests/secureMPC.test.ts @@ -55,9 +55,9 @@ async function internalDemo( expect(toParty).to.equal(1); bqs.get('alice', 'bob', channel).push(data); }, - recv: async (fromParty, channel, len) => { + recv: async (fromParty, channel, min_len, max_len) => { expect(fromParty).to.equal(1); - return bqs.get('bob', 'alice', channel).pop(len); + return bqs.get('bob', 'alice', channel).pop(min_len, max_len); }, }, mode, @@ -73,9 +73,9 @@ async function internalDemo( expect(toParty).to.equal(0); bqs.get('bob', 'alice', channel).push(data); }, - recv: async (fromParty, channel, len) => { + recv: async (fromParty, channel, min_len, max_len) => { expect(fromParty).to.equal(0); - return bqs.get('alice', 'bob', channel).pop(len); + return bqs.get('alice', 'bob', channel).pop(min_len, max_len); }, }, mode, @@ -119,8 +119,8 @@ async function internalDemoN( send: (toParty, channel, data) => { bqs.get(party, toParty, channel).push(data); }, - recv: async (fromParty, channel, len) => { - return bqs.get(fromParty, party, channel).pop(len); + recv: async (fromParty, channel, min_len, max_len) => { + return bqs.get(fromParty, party, channel).pop(min_len, max_len); }, } })));