improvements ported from perf-testing: buffer and flush the io, accept bytes than needed into buffer when available, improve c++ to js error propagation, default to mpc mode

This commit is contained in:
Andrew Morris
2025-06-12 09:57:37 +10:00
parent 888bc95a0c
commit 5bf8be6426
9 changed files with 193 additions and 44 deletions

View File

@@ -3,6 +3,7 @@
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <map>
#include "emp-tool/io/i_raw_io.h"
#include "emp-ag2pc/2pc.h"
@@ -24,45 +25,121 @@ EM_JS(void, send_js, (int to_party, char channel_label, const void* data, size_t
});
// Implement recv_js function to receive data from JavaScript to C++
EM_ASYNC_JS(void, recv_js, (int from_party, char channel_label, void* data, size_t len), {
EM_ASYNC_JS(size_t, recv_js, (int from_party, char channel_label, void* data, size_t min_len, size_t max_len), {
if (!Module.emp?.io?.recv) {
reject(new Error("Module.emp.io.recv is not defined in JavaScript."));
return;
}
// Wait for data from JavaScript
const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), len);
const dataArray = await Module.emp.io.recv(from_party - 1, String.fromCharCode(channel_label), min_len, max_len);
// Copy data from JavaScript Uint8Array to WebAssembly memory
HEAPU8.set(dataArray, data);
// Return the length of the received data
return dataArray.length;
});
class RawIOJS;
std::map<int, RawIOJS*> raw_io_map;
int next_raw_io_id = 0;
size_t MAX_SEND_BUFFER_SIZE = 64 * 1024;
void actual_flush_all();
class RawIOJS : public IRawIO {
public:
int other_party;
char channel_label;
std::vector<uint8_t> send_buffer; // TODO: Max buffer size?
std::vector<uint8_t> recv_buffer;
size_t recv_start = 0;
size_t recv_end = 0;
int id;
RawIOJS(
int other_party,
char channel_label
):
other_party(other_party),
channel_label(channel_label)
{}
channel_label(channel_label),
recv_buffer(64 * 1024)
{
id = next_raw_io_id++;
raw_io_map[id] = this;
}
~RawIOJS() {
raw_io_map.erase(id);
}
void send(const void* data, size_t len) override {
send_js(other_party, channel_label, data, len);
if (send_buffer.size() + len > MAX_SEND_BUFFER_SIZE) {
actual_flush();
}
// This will still exceed max size if len > MAX_SEND_BUFFER_SIZE, that's ok
send_buffer.resize(send_buffer.size() + len);
std::memcpy(send_buffer.data() + send_buffer.size() - len, data, len);
}
void recv(void* data, size_t len) override {
recv_js(other_party, channel_label, data, len);
if (recv_start + len > recv_end) {
if (recv_start + len > recv_buffer.size()) {
// copy within
size_t recv_len = recv_end - recv_start;
std::memmove(recv_buffer.data(), recv_buffer.data() + recv_start, recv_len);
recv_start = 0;
recv_end = recv_len;
}
size_t bytes_needed = recv_start + len - recv_end;
size_t room = recv_buffer.size() - recv_end;
if (bytes_needed > room) {
size_t size = recv_buffer.size();
while (bytes_needed > room) {
size *= 2;
room = size - recv_end;
}
recv_buffer.resize(size);
}
actual_flush_all();
size_t bytes_received = recv_js(other_party, channel_label, recv_buffer.data() + recv_end, bytes_needed, room);
recv_end += bytes_received;
if (bytes_received < bytes_needed) {
throw std::runtime_error("recv failed");
}
}
std::memcpy(data, recv_buffer.data() + recv_start, len);
recv_start += len;
}
void flush() override {
// Ignored for now
// ignored for now
}
void actual_flush() {
if (send_buffer.size() > 0) {
send_js(other_party, channel_label, send_buffer.data(), send_buffer.size());
send_buffer.clear();
}
}
};
void actual_flush_all() {
for (auto& [key, raw_io] : raw_io_map) {
raw_io->actual_flush();
}
}
class MultiIOJS : public IMultiIO {
public:
int mParty;
@@ -284,6 +361,9 @@ void run_2pc_impl(int party, int nP) {
twopc.function_dependent();
std::vector<bool> output_bits = twopc.online(input_bits, true);
actual_flush_all();
handle_output_bits(output_bits);
} catch (const std::exception& e) {
handle_error(e.what());
@@ -339,6 +419,8 @@ void run_mpc_impl(int party, int nP) {
output_bits.push_back(output.get_plaintext_bit(i));
}
actual_flush_all();
handle_output_bits(output_bits);
} catch (const std::exception& e) {
handle_error(e.what());

View File

@@ -5,7 +5,7 @@ export default class BufferQueue {
private buffer: Uint8Array;
private bufferStart: number;
private bufferEnd: number;
private pendingPops: number[];
private pendingPops: { min_len: number, max_len: number }[];
private pendingPopsResolvers: {
resolve: ((value: Uint8Array) => void),
reject: (e: Error) => void,
@@ -65,19 +65,21 @@ export default class BufferQueue {
* @param len - The number of bytes to pop from the buffer.
* @returns A promise resolving with the popped data as a Uint8Array.
*/
pop(len: number): Promise<Uint8Array> {
if (typeof len !== 'number' || len < 0) {
return Promise.reject(new Error('Length must be non-negative integer'));
pop(min_len: number, max_len: number): Promise<Uint8Array> {
if (typeof min_len !== 'number' || typeof max_len !== 'number' || min_len < 0 || max_len < 0 || min_len > max_len) {
return Promise.reject(new Error('Invalid min/max lengths'));
}
if (this.bufferEnd - this.bufferStart >= len) {
const result = this.buffer.slice(this.bufferStart, this.bufferStart + len);
this.bufferStart += len;
if (this.bufferEnd - this.bufferStart >= min_len) {
const available_len = this.bufferEnd - this.bufferStart;
const provide_len = Math.min(max_len, available_len);
const result = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len);
this.bufferStart += result.length;
this._compactBuffer();
return Promise.resolve(result);
} else if (!this.closed) {
return new Promise((resolve, reject) => {
this.pendingPops.push(len);
this.pendingPops.push({ min_len, max_len });
this.pendingPopsResolvers.push({ resolve, reject });
});
} else {
@@ -109,10 +111,12 @@ export default class BufferQueue {
*/
private _resolvePendingPops(): void {
while (this.pendingPops.length > 0) {
const len = this.pendingPops[0];
if (this.bufferEnd - this.bufferStart >= len) {
const data = this.buffer.slice(this.bufferStart, this.bufferStart + len);
this.bufferStart += len;
const { min_len, max_len } = this.pendingPops[0];
if (this.bufferEnd - this.bufferStart >= min_len) {
const available_len = this.bufferEnd - this.bufferStart;
const provide_len = Math.min(max_len, available_len);
const data = this.buffer.slice(this.bufferStart, this.bufferStart + provide_len);
this.bufferStart += data.length;
this.pendingPops.shift();
const { resolve } = this.pendingPopsResolvers.shift()!;
resolve(data);

View File

@@ -27,9 +27,9 @@ export default class BufferedIO
this.closeOther?.();
}
async recv(fromParty: number, channel: 'a' | 'b', len: number): Promise<Uint8Array> {
async recv(fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number): Promise<Uint8Array> {
assert(fromParty === this.otherParty, 'fromParty !== this.otherParty');
return await this.bq[channel].pop(len);
return await this.bq[channel].pop(min_len, max_len);
}
accept(channel: 'a' | 'b', data: Uint8Array) {

View File

@@ -61,15 +61,25 @@ async function secureMPC({
emp.circuit = circuit;
emp.inputBits = inputBits;
emp.inputBitsPerParty = inputBitsPerParty;
emp.io = io;
let reject: undefined | ((error: unknown) => void) = undefined;
const callbackRejector = new Promise((_resolve, rej) => {
reject = rej;
});
reject = reject!;
emp.io = {
send: useRejector(io.send.bind(io), reject),
recv: useRejector(io.recv.bind(io), reject),
};
const method = calculateMethod(mode, size, circuit);
const result = new Promise<Uint8Array>((resolve, reject) => {
const result = new Promise<Uint8Array>(async (resolve, reject) => {
try {
emp.handleOutput = resolve;
emp.handleError = reject;
callbackRejector.catch(reject);
module[method](party, size);
} catch (error) {
reject(error);
@@ -97,7 +107,12 @@ function calculateMethod(
case 'mpc':
return '_run_mpc';
case 'auto':
return size === 2 ? '_run_2pc' : '_run_mpc';
// Advantage of 2PC specialization is small and contains "FEQ error" bug
// for the large circuits, so the performance currently cannot be realized
// where it matters.
// Therefore, we default to the general N-party mpc mode, even when there
// are only 2 parties.
return '_run_mpc';
default:
const _never: never = mode;
@@ -125,11 +140,11 @@ onmessage = async (event) => {
send: (toParty, channel, data) => {
postMessage({ type: 'io_send', toParty, channel, data });
},
recv: (fromParty, channel, len) => {
recv: (fromParty, channel, min_len, max_len) => {
return new Promise((resolve, reject) => {
const id = requestId++;
pendingRequests[id] = { resolve, reject };
postMessage({ type: 'io_recv', fromParty, channel, len, id });
postMessage({ type: 'io_recv', fromParty, channel, min_len, max_len, id });
});
},
};
@@ -147,7 +162,7 @@ onmessage = async (event) => {
postMessage({ type: 'result', result });
} catch (error) {
postMessage({ type: 'error', error: (error as Error).stack });
postMessage({ type: 'error', error: (error as Error).message });
}
} else if (message.type === 'io_recv_response') {
const { id, data } = message;
@@ -163,3 +178,17 @@ onmessage = async (event) => {
}
}
};
function useRejector<F extends (...args: any[]) => any>(
fn: F,
reject: (error: unknown) => void,
): F {
return ((...args: Parameters<F>) => {
try {
return fn(...args);
} catch (error) {
reject(error);
throw error;
}
}) as F;
}

View File

@@ -28,7 +28,7 @@ windowAny.internalDemo = async function(
inputBitsPerParty: [32, 32],
io: {
send: (toParty, channel, data) => bobBq[channel].push(data),
recv: (fromParty, channel, len) => aliceBq[channel].pop(len),
recv: (fromParty, channel, min_len, max_len) => aliceBq[channel].pop(min_len, max_len),
},
mode,
}),
@@ -40,7 +40,7 @@ windowAny.internalDemo = async function(
inputBitsPerParty: [32, 32],
io: {
send: (toParty, channel, data) => aliceBq[channel].push(data),
recv: (fromParty, channel, len) => bobBq[channel].pop(len),
recv: (fromParty, channel, min_len, max_len) => bobBq[channel].pop(min_len, max_len),
},
mode,
}),
@@ -69,7 +69,7 @@ windowAny.internalDemo3 = async function(
inputBitsPerParty: [32, 32, 32],
io: {
send: (toParty, channel, data) => bqs.get(party, toParty, channel).push(data),
recv: (fromParty, channel, len) => bqs.get(fromParty, party, channel).pop(len),
recv: (fromParty, channel, min_len, max_len) => bqs.get(fromParty, party, channel).pop(min_len, max_len),
},
mode,
})));
@@ -342,9 +342,9 @@ function makeCopyPasteIO(otherParty: number): IO {
return {
send: makeConsoleSend(otherParty),
recv: (fromParty, channel, len) => {
recv: (fromParty, channel, min_len, max_len) => {
assert(fromParty === otherParty, 'Unexpected party');
return bq[channel].pop(len);
return bq[channel].pop(min_len, max_len);
},
};
}

View File

@@ -42,7 +42,17 @@ export default async function nodeSecureMPC({
emp.circuit = circuit;
emp.inputBits = inputBits;
emp.inputBitsPerParty = inputBitsPerParty;
emp.io = io;
let reject: undefined | ((error: unknown) => void) = undefined;
const callbackRejector = new Promise((_resolve, rej) => {
reject = rej;
});
reject = reject!;
emp.io = {
send: useRejector(io.send.bind(io), reject),
recv: useRejector(io.recv.bind(io), reject),
};
const method = calculateMethod(mode, size, circuit);
@@ -50,6 +60,7 @@ export default async function nodeSecureMPC({
try {
emp.handleOutput = resolve;
emp.handleError = reject;
callbackRejector.catch(reject);
module[method](party, size);
} catch (error) {
@@ -74,10 +85,29 @@ function calculateMethod(
case 'mpc':
return '_run_mpc';
case 'auto':
return size === 2 ? '_run_2pc' : '_run_mpc';
// Advantage of 2PC specialization is small and contains "FEQ error" bug
// for the large circuits, so the performance currently cannot be realized
// where it matters.
// Therefore, we default to the general N-party mpc mode, even when there
// are only 2 parties.
return '_run_mpc';
default:
const _never: never = mode;
throw new Error('Unexpected mode: ' + mode);
}
}
function useRejector<F extends (...args: any[]) => any>(
fn: F,
reject: (error: unknown) => void,
): F {
return ((...args: Parameters<F>) => {
try {
return fn(...args);
} catch (error) {
reject(error);
throw error;
}
}) as F;
}

View File

@@ -62,10 +62,10 @@ export default function secureMPC({
const { toParty, channel, data } = message;
io.send(toParty, channel, data);
} else if (message.type === 'io_recv') {
const { fromParty, channel, len } = message;
const { fromParty, channel, min_len, max_len } = message;
// Handle the recv request from the worker
try {
const data = await io.recv(fromParty, channel, len);
const data = await io.recv(fromParty, channel, min_len, max_len);
worker.postMessage({ type: 'io_recv_response', id: message.id, data });
} catch (error) {
worker.postMessage({
@@ -80,6 +80,10 @@ export default function secureMPC({
} else if (message.type === 'error') {
// Reject the promise if an error occurred
reject(new Error(message.error));
} else if (message.type === 'log') {
console.log('Worker log:', message.msg);
} else {
console.error('Unexpected message from worker:', message);
}
};

View File

@@ -1,6 +1,6 @@
export type IO = {
send: (toParty: number, channel: 'a' | 'b', data: Uint8Array) => void;
recv: (fromParty: number, channel: 'a' | 'b', len: number) => Promise<Uint8Array>;
recv: (fromParty: number, channel: 'a' | 'b', min_len: number, max_len: number) => Promise<Uint8Array>;
on?: (event: 'error', listener: (error: Error) => void) => void;
off?: (event: 'error', listener: (error: Error) => void) => void;
close?: () => void;

View File

@@ -55,9 +55,9 @@ async function internalDemo(
expect(toParty).to.equal(1);
bqs.get('alice', 'bob', channel).push(data);
},
recv: async (fromParty, channel, len) => {
recv: async (fromParty, channel, min_len, max_len) => {
expect(fromParty).to.equal(1);
return bqs.get('bob', 'alice', channel).pop(len);
return bqs.get('bob', 'alice', channel).pop(min_len, max_len);
},
},
mode,
@@ -73,9 +73,9 @@ async function internalDemo(
expect(toParty).to.equal(0);
bqs.get('bob', 'alice', channel).push(data);
},
recv: async (fromParty, channel, len) => {
recv: async (fromParty, channel, min_len, max_len) => {
expect(fromParty).to.equal(0);
return bqs.get('alice', 'bob', channel).pop(len);
return bqs.get('alice', 'bob', channel).pop(min_len, max_len);
},
},
mode,
@@ -119,8 +119,8 @@ async function internalDemoN(
send: (toParty, channel, data) => {
bqs.get(party, toParty, channel).push(data);
},
recv: async (fromParty, channel, len) => {
return bqs.get(fromParty, party, channel).pop(len);
recv: async (fromParty, channel, min_len, max_len) => {
return bqs.get(fromParty, party, channel).pop(min_len, max_len);
},
}
})));