mirror of
https://github.com/privacy-scaling-explorations/emp-wasm.git
synced 2026-01-09 10:07:54 -05:00
improvements ported from perf-testing: buffer and flush the io, accept bytes than needed into buffer when available, improve c++ to js error propagation, default to mpc mode
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
#include "emp-tool/io/i_raw_io.h"
|
||||
#include "emp-ag2pc/2pc.h"
|
||||
@@ -24,45 +25,121 @@ EM_JS(void, send_js, (int to_party, char channel_label, const void* data, size_t
|
||||
});
|
||||
|
||||
// Implement recv_js function to receive data from JavaScript to C++
|
||||
EM_ASYNC_JS(void, recv_js, (int from_party, char channel_label, void* data, size_t len), {
|
||||
EM_ASYNC_JS(size_t, recv_js, (int from_party, char channel_label, void* data, size_t min_len, size_t max_len), {
|
||||
if (!Module.emp?.io?.recv) {
|
||||
reject(new Error("Module.emp.io.recv is not defined in JavaScript."));
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for data from JavaScript
|
||||
const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), len);
|
||||
const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), min_len, max_len);
|
||||
|
||||
// Copy data from JavaScript Uint8Array to WebAssembly memory
|
||||
HEAPU8.set(dataArray, data);
|
||||
|
||||
// Return the length of the received data
|
||||
return dataArray.length;
|
||||
});
|
||||
|
||||
class RawIOJS;
|
||||
std::map<int, RawIOJS*> raw_io_map;
|
||||
int next_raw_io_id = 0;
|
||||
size_t MAX_SEND_BUFFER_SIZE = 64 * 1024;
|
||||
|
||||
void actual_flush_all();
|
||||
|
||||
class RawIOJS : public IRawIO {
|
||||
public:
|
||||
int other_party;
|
||||
char channel_label;
|
||||
std::vector<uint8_t> send_buffer; // TODO: Max buffer size?
|
||||
std::vector<uint8_t> recv_buffer;
|
||||
size_t recv_start = 0;
|
||||
size_t recv_end = 0;
|
||||
int id;
|
||||
|
||||
RawIOJS(
|
||||
int other_party,
|
||||
char channel_label
|
||||
):
|
||||
other_party(other_party),
|
||||
channel_label(channel_label)
|
||||
{}
|
||||
channel_label(channel_label),
|
||||
recv_buffer(64 * 1024)
|
||||
{
|
||||
id = next_raw_io_id++;
|
||||
raw_io_map[id] = this;
|
||||
}
|
||||
|
||||
~RawIOJS() {
|
||||
raw_io_map.erase(id);
|
||||
}
|
||||
|
||||
void send(const void* data, size_t len) override {
|
||||
send_js(other_party, channel_label, data, len);
|
||||
if (send_buffer.size() + len > MAX_SEND_BUFFER_SIZE) {
|
||||
actual_flush();
|
||||
}
|
||||
|
||||
// This will still exceed max size if len > MAX_SEND_BUFFER_SIZE, that's ok
|
||||
send_buffer.resize(send_buffer.size() + len);
|
||||
|
||||
std::memcpy(send_buffer.data() + send_buffer.size() - len, data, len);
|
||||
}
|
||||
|
||||
void recv(void* data, size_t len) override {
|
||||
recv_js(other_party, channel_label, data, len);
|
||||
if (recv_start + len > recv_end) {
|
||||
if (recv_start + len > recv_buffer.size()) {
|
||||
// copy within
|
||||
size_t recv_len = recv_end - recv_start;
|
||||
std::memmove(recv_buffer.data(), recv_buffer.data() + recv_start, recv_len);
|
||||
recv_start = 0;
|
||||
recv_end = recv_len;
|
||||
}
|
||||
|
||||
size_t bytes_needed = recv_start + len - recv_end;
|
||||
size_t room = recv_buffer.size() - recv_end;
|
||||
|
||||
if (bytes_needed > room) {
|
||||
size_t size = recv_buffer.size();
|
||||
|
||||
while (bytes_needed > room) {
|
||||
size *= 2;
|
||||
room = size - recv_end;
|
||||
}
|
||||
|
||||
recv_buffer.resize(size);
|
||||
}
|
||||
|
||||
actual_flush_all();
|
||||
size_t bytes_received = recv_js(other_party, channel_label, recv_buffer.data() + recv_end, bytes_needed, room);
|
||||
recv_end += bytes_received;
|
||||
|
||||
if (bytes_received < bytes_needed) {
|
||||
throw std::runtime_error("recv failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::memcpy(data, recv_buffer.data() + recv_start, len);
|
||||
recv_start += len;
|
||||
}
|
||||
|
||||
void flush() override {
|
||||
// Ignored for now
|
||||
// ignored for now
|
||||
}
|
||||
|
||||
void actual_flush() {
|
||||
if (send_buffer.size() > 0) {
|
||||
send_js(other_party, channel_label, send_buffer.data(), send_buffer.size());
|
||||
send_buffer.clear();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void actual_flush_all() {
|
||||
for (auto& [key, raw_io] : raw_io_map) {
|
||||
raw_io->actual_flush();
|
||||
}
|
||||
}
|
||||
|
||||
class MultiIOJS : public IMultiIO {
|
||||
public:
|
||||
int mParty;
|
||||
@@ -284,6 +361,9 @@ void run_2pc_impl(int party, int nP) {
|
||||
twopc.function_dependent();
|
||||
|
||||
std::vector<bool> output_bits = twopc.online(input_bits, true);
|
||||
|
||||
actual_flush_all();
|
||||
|
||||
handle_output_bits(output_bits);
|
||||
} catch (const std::exception& e) {
|
||||
handle_error(e.what());
|
||||
@@ -339,6 +419,8 @@ void run_mpc_impl(int party, int nP) {
|
||||
output_bits.push_back(output.get_plaintext_bit(i));
|
||||
}
|
||||
|
||||
actual_flush_all();
|
||||
|
||||
handle_output_bits(output_bits);
|
||||
} catch (const std::exception& e) {
|
||||
handle_error(e.what());
|
||||
|
||||
@@ -5,7 +5,7 @@ export default class BufferQueue {
|
||||
private buffer: Uint8Array;
|
||||
private bufferStart: number;
|
||||
private bufferEnd: number;
|
||||
private pendingPops: number[];
|
||||
private pendingPops: { min_len: number, max_len: number }[];
|
||||
private pendingPopsResolvers: {
|
||||
resolve: ((value: Uint8Array) => void),
|
||||
reject: (e: Error) => void,
|
||||
@@ -65,19 +65,21 @@ export default class BufferQueue {
|
||||
* @param len - The number of bytes to pop from the buffer.
|
||||
* @returns A promise resolving with the popped data as a Uint8Array.
|
||||
*/
|
||||
pop(len: number): Promise<Uint8Array> {
|
||||
if (typeof len !== 'number' || len < 0) {
|
||||
return Promise.reject(new Error('Length must be non-negative integer'));
|
||||
pop(min_len: number, max_len: number): Promise<Uint8Array> {
|
||||
if (typeof min_len !== 'number' || typeof max_len !== 'number' || min_len < 0 || max_len < 0 || min_len > max_len) {
|
||||
return Promise.reject(new Error('Invalid min/max lengths'));
|
||||
}
|
||||
|
||||
if (this.bufferEnd - this.bufferStart >= len) {
|
||||
const result = this.buffer.slice(this.bufferStart, this.bufferStart + len);
|
||||
this.bufferStart += len;
|
||||
if (this.bufferEnd - this.bufferStart >= min_len) {
|
||||
const available_len = this.bufferEnd - this.bufferStart;
|
||||
const provide_len = Math.min(max_len, available_len);
|
||||
const result = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len);
|
||||
this.bufferStart += result.length;
|
||||
this._compactBuffer();
|
||||
return Promise.resolve(result);
|
||||
} else if (!this.closed) {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.pendingPops.push(len);
|
||||
this.pendingPops.push({ min_len, max_len });
|
||||
this.pendingPopsResolvers.push({ resolve, reject });
|
||||
});
|
||||
} else {
|
||||
@@ -109,10 +111,12 @@ export default class BufferQueue {
|
||||
*/
|
||||
private _resolvePendingPops(): void {
|
||||
while (this.pendingPops.length > 0) {
|
||||
const len = this.pendingPops[0];
|
||||
if (this.bufferEnd - this.bufferStart >= len) {
|
||||
const data = this.buffer.slice(this.bufferStart, this.bufferStart + len);
|
||||
this.bufferStart += len;
|
||||
const { min_len, max_len } = this.pendingPops[0];
|
||||
if (this.bufferEnd - this.bufferStart >= min_len) {
|
||||
const available_len = this.bufferEnd - this.bufferStart;
|
||||
const provide_len = Math.min(max_len, available_len);
|
||||
const data = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len);
|
||||
this.bufferStart += data.length;
|
||||
this.pendingPops.shift();
|
||||
const { resolve } = this.pendingPopsResolvers.shift()!;
|
||||
resolve(data);
|
||||
|
||||
@@ -27,9 +27,9 @@ export default class BufferedIO
|
||||
this.closeOther?.();
|
||||
}
|
||||
|
||||
async recv(fromParty: number, channel: 'a' | 'b', len: number): Promise<Uint8Array> {
|
||||
async recv(fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number): Promise<Uint8Array> {
|
||||
assert(fromParty === this.otherParty, 'fromParty !== this.otherParty');
|
||||
return await this.bq[channel].pop(len);
|
||||
return await this.bq[channel].pop(min_len, max_len);
|
||||
}
|
||||
|
||||
accept(channel: 'a' | 'b', data: Uint8Array) {
|
||||
|
||||
@@ -61,15 +61,25 @@ async function secureMPC({
|
||||
emp.circuit = circuit;
|
||||
emp.inputBits = inputBits;
|
||||
emp.inputBitsPerParty = inputBitsPerParty;
|
||||
emp.io = io;
|
||||
|
||||
let reject: undefined | ((error: unknown) => void) = undefined;
|
||||
const callbackRejector = new Promise((_resolve, rej) => {
|
||||
reject = rej;
|
||||
});
|
||||
reject = reject!;
|
||||
|
||||
emp.io = {
|
||||
send: useRejector(io.send.bind(io), reject),
|
||||
recv: useRejector(io.recv.bind(io), reject),
|
||||
};
|
||||
|
||||
const method = calculateMethod(mode, size, circuit);
|
||||
|
||||
const result = new Promise<Uint8Array>((resolve, reject) => {
|
||||
const result = new Promise<Uint8Array>(async (resolve, reject) => {
|
||||
try {
|
||||
emp.handleOutput = resolve;
|
||||
emp.handleError = reject;
|
||||
|
||||
callbackRejector.catch(reject);
|
||||
module[method](party, size);
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
@@ -97,7 +107,12 @@ function calculateMethod(
|
||||
case 'mpc':
|
||||
return '_run_mpc';
|
||||
case 'auto':
|
||||
return size === 2 ? '_run_2pc' : '_run_mpc';
|
||||
// Advantage of 2PC specialization is small and contains "FEQ error" bug
|
||||
// for the large circuits, so the performance currently cannot be realized
|
||||
// where it matters.
|
||||
// Therefore, we default to the general N-party mpc mode, even when there
|
||||
// are only 2 parties.
|
||||
return '_run_mpc';
|
||||
|
||||
default:
|
||||
const _never: never = mode;
|
||||
@@ -125,11 +140,11 @@ onmessage = async (event) => {
|
||||
send: (toParty, channel, data) => {
|
||||
postMessage({ type: 'io_send', toParty, channel, data });
|
||||
},
|
||||
recv: (fromParty, channel, len) => {
|
||||
recv: (fromParty, channel, min_len, max_len) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const id = requestId++;
|
||||
pendingRequests[id] = { resolve, reject };
|
||||
postMessage({ type: 'io_recv', fromParty, channel, len, id });
|
||||
postMessage({ type: 'io_recv', fromParty, channel, min_len, max_len, id });
|
||||
});
|
||||
},
|
||||
};
|
||||
@@ -147,7 +162,7 @@ onmessage = async (event) => {
|
||||
|
||||
postMessage({ type: 'result', result });
|
||||
} catch (error) {
|
||||
postMessage({ type: 'error', error: (error as Error).stack });
|
||||
postMessage({ type: 'error', error: (error as Error).message });
|
||||
}
|
||||
} else if (message.type === 'io_recv_response') {
|
||||
const { id, data } = message;
|
||||
@@ -163,3 +178,17 @@ onmessage = async (event) => {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
function useRejector<F extends (...args: any[]) => any>(
|
||||
fn: F,
|
||||
reject: (error: unknown) => void,
|
||||
): F {
|
||||
return ((...args: Parameters<F>) => {
|
||||
try {
|
||||
return fn(...args);
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
throw error;
|
||||
}
|
||||
}) as F;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ windowAny.internalDemo = async function(
|
||||
inputBitsPerParty: [32, 32],
|
||||
io: {
|
||||
send: (toParty, channel, data) => bobBq[channel].push(data),
|
||||
recv: (fromParty, channel, len) => aliceBq[channel].pop(len),
|
||||
recv: (fromParty, channel, min_len, max_len) => aliceBq[channel].pop(min_len, max_len),
|
||||
},
|
||||
mode,
|
||||
}),
|
||||
@@ -40,7 +40,7 @@ windowAny.internalDemo = async function(
|
||||
inputBitsPerParty: [32, 32],
|
||||
io: {
|
||||
send: (toParty, channel, data) => aliceBq[channel].push(data),
|
||||
recv: (fromParty, channel, len) => bobBq[channel].pop(len),
|
||||
recv: (fromParty, channel, min_len, max_len) => bobBq[channel].pop(min_len, max_len),
|
||||
},
|
||||
mode,
|
||||
}),
|
||||
@@ -69,7 +69,7 @@ windowAny.internalDemo3 = async function(
|
||||
inputBitsPerParty: [32, 32, 32],
|
||||
io: {
|
||||
send: (toParty, channel, data) => bqs.get(party, toParty, channel).push(data),
|
||||
recv: (fromParty, channel, len) => bqs.get(fromParty, party, channel).pop(len),
|
||||
recv: (fromParty, channel, min_len, max_len) => bqs.get(fromParty, party, channel).pop(min_len, max_len),
|
||||
},
|
||||
mode,
|
||||
})));
|
||||
@@ -342,9 +342,9 @@ function makeCopyPasteIO(otherParty: number): IO {
|
||||
|
||||
return {
|
||||
send: makeConsoleSend(otherParty),
|
||||
recv: (fromParty, channel, len) => {
|
||||
recv: (fromParty, channel, min_len, max_len) => {
|
||||
assert(fromParty === otherParty, 'Unexpected party');
|
||||
return bq[channel].pop(len);
|
||||
return bq[channel].pop(min_len, max_len);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -42,7 +42,17 @@ export default async function nodeSecureMPC({
|
||||
emp.circuit = circuit;
|
||||
emp.inputBits = inputBits;
|
||||
emp.inputBitsPerParty = inputBitsPerParty;
|
||||
emp.io = io;
|
||||
|
||||
let reject: undefined | ((error: unknown) => void) = undefined;
|
||||
const callbackRejector = new Promise((_resolve, rej) => {
|
||||
reject = rej;
|
||||
});
|
||||
reject = reject!;
|
||||
|
||||
emp.io = {
|
||||
send: useRejector(io.send.bind(io), reject),
|
||||
recv: useRejector(io.recv.bind(io), reject),
|
||||
};
|
||||
|
||||
const method = calculateMethod(mode, size, circuit);
|
||||
|
||||
@@ -50,6 +60,7 @@ export default async function nodeSecureMPC({
|
||||
try {
|
||||
emp.handleOutput = resolve;
|
||||
emp.handleError = reject;
|
||||
callbackRejector.catch(reject);
|
||||
|
||||
module[method](party, size);
|
||||
} catch (error) {
|
||||
@@ -74,10 +85,29 @@ function calculateMethod(
|
||||
case 'mpc':
|
||||
return '_run_mpc';
|
||||
case 'auto':
|
||||
return size === 2 ? '_run_2pc' : '_run_mpc';
|
||||
// Advantage of 2PC specialization is small and contains "FEQ error" bug
|
||||
// for the large circuits, so the performance currently cannot be realized
|
||||
// where it matters.
|
||||
// Therefore, we default to the general N-party mpc mode, even when there
|
||||
// are only 2 parties.
|
||||
return '_run_mpc';
|
||||
|
||||
default:
|
||||
const _never: never = mode;
|
||||
throw new Error('Unexpected mode: ' + mode);
|
||||
}
|
||||
}
|
||||
|
||||
function useRejector<F extends (...args: any[]) => any>(
|
||||
fn: F,
|
||||
reject: (error: unknown) => void,
|
||||
): F {
|
||||
return ((...args: Parameters<F>) => {
|
||||
try {
|
||||
return fn(...args);
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
throw error;
|
||||
}
|
||||
}) as F;
|
||||
}
|
||||
|
||||
@@ -62,10 +62,10 @@ export default function secureMPC({
|
||||
const { toParty, channel, data } = message;
|
||||
io.send(toParty, channel, data);
|
||||
} else if (message.type === 'io_recv') {
|
||||
const { fromParty, channel, len } = message;
|
||||
const { fromParty, channel, min_len, max_len } = message;
|
||||
// Handle the recv request from the worker
|
||||
try {
|
||||
const data = await io.recv(fromParty, channel, len);
|
||||
const data = await io.recv(fromParty, channel, min_len, max_len);
|
||||
worker.postMessage({ type: 'io_recv_response', id: message.id, data });
|
||||
} catch (error) {
|
||||
worker.postMessage({
|
||||
@@ -80,6 +80,10 @@ export default function secureMPC({
|
||||
} else if (message.type === 'error') {
|
||||
// Reject the promise if an error occurred
|
||||
reject(new Error(message.error));
|
||||
} else if (message.type === 'log') {
|
||||
console.log('Worker log:', message.msg);
|
||||
} else {
|
||||
console.error('Unexpected message from worker:', message);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export type IO = {
|
||||
send: (toParty: number, channel: 'a' | 'b', data: Uint8Array) => void;
|
||||
recv: (fromParty: number, channel: 'a' | 'b', len: number) => Promise<Uint8Array>;
|
||||
recv: (fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number) => Promise<Uint8Array>;
|
||||
on?: (event: 'error', listener: (error: Error) => void) => void;
|
||||
off?: (event: 'error', listener: (error: Error) => void) => void;
|
||||
close?: () => void;
|
||||
|
||||
@@ -55,9 +55,9 @@ async function internalDemo(
|
||||
expect(toParty).to.equal(1);
|
||||
bqs.get('alice', 'bob', channel).push(data);
|
||||
},
|
||||
recv: async (fromParty, channel, len) => {
|
||||
recv: async (fromParty, channel, min_len, max_len) => {
|
||||
expect(fromParty).to.equal(1);
|
||||
return bqs.get('bob', 'alice', channel).pop(len);
|
||||
return bqs.get('bob', 'alice', channel).pop(min_len, max_len);
|
||||
},
|
||||
},
|
||||
mode,
|
||||
@@ -73,9 +73,9 @@ async function internalDemo(
|
||||
expect(toParty).to.equal(0);
|
||||
bqs.get('bob', 'alice', channel).push(data);
|
||||
},
|
||||
recv: async (fromParty, channel, len) => {
|
||||
recv: async (fromParty, channel, min_len, max_len) => {
|
||||
expect(fromParty).to.equal(0);
|
||||
return bqs.get('alice', 'bob', channel).pop(len);
|
||||
return bqs.get('alice', 'bob', channel).pop(min_len, max_len);
|
||||
},
|
||||
},
|
||||
mode,
|
||||
@@ -119,8 +119,8 @@ async function internalDemoN(
|
||||
send: (toParty, channel, data) => {
|
||||
bqs.get(party, toParty, channel).push(data);
|
||||
},
|
||||
recv: async (fromParty, channel, len) => {
|
||||
return bqs.get(fromParty, party, channel).pop(len);
|
||||
recv: async (fromParty, channel, min_len, max_len) => {
|
||||
return bqs.get(fromParty, party, channel).pop(min_len, max_len);
|
||||
},
|
||||
}
|
||||
})));
|
||||
|
||||
Reference in New Issue
Block a user