Improve performance by producing a binary bristol format in JS so that we don't use emscripten's super slow istringstream

This commit is contained in:
Andrew Morris
2025-06-12 13:00:26 +10:00
parent 725d239731
commit 6ebafc07a8
8 changed files with 295 additions and 31 deletions

View File

@@ -187,31 +187,30 @@ public:
}
};
EM_JS(char*, get_circuit_raw, (int* lengthPtr), {
if (!Module.emp?.circuit) {
throw new Error("Module.emp.circuit is not defined in JavaScript.");
EM_JS(uint8_t*, get_circuit_raw, (int* lengthPtr), {
if (!Module.emp?.circuitBinary) {
throw new Error("Module.emp.circuitBinary is not defined in JavaScript.");
}
const circuitString = Module.emp.circuit; // Get the string from JavaScript
const length = lengthBytesUTF8(circuitString) + 1; // Calculate length including the null terminator
const circuitBinary = Module.emp.circuitBinary; // Get the string from JavaScript
// Allocate memory for the string
const strPtr = Module._js_char_malloc(length);
stringToUTF8(circuitString, strPtr, length);
// Allocate memory
const ptr = Module._js_malloc(circuitBinary.length);
Module.HEAPU8.set(circuitBinary, ptr);
// Set the length at the provided pointer location
setValue(lengthPtr, length, 'i32');
setValue(lengthPtr, circuitBinary.length, 'i32');
// Return the pointer
return strPtr;
return ptr;
});
emp::BristolFormat get_circuit() {
int length = 0;
char* circuit_raw = get_circuit_raw(&length);
uint8_t* circuit_raw = get_circuit_raw(&length);
emp::BristolFormat circuit;
circuit.from_str(circuit_raw);
circuit.from_buffer(circuit_raw, length);
free(circuit_raw);
return circuit;

View File

@@ -89,6 +89,78 @@ public:
fout.close();
}
/* Consume the binary layout produced by bristolToBinary.
* ┌────────────┬──────────────────────────────────────────────────┐
* │ bytes 0-19 │ five uint32 (num_gate, num_wire, n1, n2, n3) │
* │ … │ repeated records │
* │ │ 1 byte opcode (0 INV, 1 XOR, 2 AND) │
* │ │ INV : 2 × uint32 (in , out) ── 9 B │
* │ │ XOR/AND: 3 × uint32 (in1,in2,out) ── 13 B │
* └────────────┴──────────────────────────────────────────────────┘
* Any deviation throws std::runtime_error.
*/
void from_buffer(const uint8_t* buf, int size) {
auto need = [&](size_t n) {
if (n > static_cast<size_t>(size))
throw std::runtime_error("Buffer too small / truncated");
};
auto read_u32 = [&](const uint8_t* p) -> uint32_t {
return static_cast<uint32_t>(p[0]) |
(static_cast<uint32_t>(p[1]) << 8) |
(static_cast<uint32_t>(p[2]) << 16) |
(static_cast<uint32_t>(p[3]) << 24);
};
/* ---------- header ---------- */
need(20);
num_gate = static_cast<int>(read_u32(buf + 0));
num_wire = static_cast<int>(read_u32(buf + 4));
n1 = static_cast<int>(read_u32(buf + 8));
n2 = static_cast<int>(read_u32(buf + 12));
n3 = static_cast<int>(read_u32(buf + 16));
gates.resize(num_gate * 4);
wires.resize(num_wire);
size_t offset = 20;
for (int g = 0; g < num_gate; ++g) {
need(offset + 1); // opcode byte
uint8_t opcode = buf[offset++];
switch (opcode) {
case 0: { // INV
need(offset + 8);
int in = static_cast<int>(read_u32(buf + offset)); offset += 4;
int out = static_cast<int>(read_u32(buf + offset)); offset += 4;
gates[4 * g] = in;
gates[4 * g + 1] = 0; // unused
gates[4 * g + 2] = out;
gates[4 * g + 3] = NOT_GATE;
break;
}
case 1: // XOR
case 2: { // AND
need(offset + 12);
int in1 = static_cast<int>(read_u32(buf + offset)); offset += 4;
int in2 = static_cast<int>(read_u32(buf + offset)); offset += 4;
int out = static_cast<int>(read_u32(buf + offset)); offset += 4;
gates[4 * g] = in1;
gates[4 * g + 1] = in2;
gates[4 * g + 2] = out;
gates[4 * g + 3] = (opcode == 1) ? XOR_GATE : AND_GATE;
break;
}
default:
throw std::runtime_error("Unknown gate opcode");
}
}
if (offset != static_cast<size_t>(size))
throw std::runtime_error("Extra bytes after final gate");
}
void compute(Bit* out, const Bit* in1, const Bit* in2) {
compute((block*)out, (block*)in1, (block*)in2);
}

View File

@@ -2,7 +2,7 @@ import type { IO } from "./types";
type Module = {
emp?: {
circuit?: string;
circuitBinary?: Uint8Array;
inputBits?: Uint8Array;
inputBitsPerParty?: number[];
io?: IO;
@@ -22,18 +22,18 @@ declare const createModule: () => Promise<Module>
*
* @param party - The party index joining the computation (0, 1, .. N-1).
* @param size - The number of parties in the computation.
* @param circuit - The circuit to run.
* @param circuitBinary - The circuit to run.
* @param inputBits - The input bits for the circuit, represented as one bit per byte.
* @param inputBitsPerParty - The number of input bits for each party.
* @param io - Input/output channels for communication between the two parties.
* @returns A promise resolving with the output bits of the circuit.
*/
async function secureMPC({
party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto',
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto',
}: {
party: number,
size: number,
circuit: string,
circuitBinary: Uint8Array,
inputBits: Uint8Array,
inputBitsPerParty: number[],
io: IO,
@@ -48,7 +48,7 @@ async function secureMPC({
running = true;
const emp: {
circuit?: string;
circuitBinary?: Uint8Array;
inputBits?: Uint8Array;
inputBitsPerParty?: number[];
io?: IO;
@@ -58,7 +58,7 @@ async function secureMPC({
module.emp = emp;
emp.circuit = circuit;
emp.circuitBinary = circuitBinary;
emp.inputBits = inputBits;
emp.inputBitsPerParty = inputBitsPerParty;
@@ -73,7 +73,7 @@ async function secureMPC({
recv: useRejector(io.recv.bind(io), reject),
};
const method = calculateMethod(mode, size, circuit);
const method = calculateMethod(mode, size, circuitBinary);
const result = new Promise<Uint8Array>(async (resolve, reject) => {
try {
@@ -99,7 +99,7 @@ function calculateMethod(
// Currently unused, but some 2-party circuits might perform better with
// _runMPC
_circuit: string,
_circuitBinary: Uint8Array,
) {
switch (mode) {
case '2pc':
@@ -133,7 +133,7 @@ onmessage = async (event) => {
const message = event.data;
if (message.type === 'start') {
const { party, size, circuit, inputBits, inputBitsPerParty, mode } = message;
const { party, size, circuitBinary, inputBits, inputBitsPerParty, mode } = message;
// Create a proxy IO object to communicate with the main thread
const io: IO = {
@@ -153,7 +153,7 @@ onmessage = async (event) => {
const result = await secureMPC({
party,
size,
circuit,
circuitBinary,
inputBits,
inputBitsPerParty,
io,

46
src/ts/binaryToBristol.ts Normal file
View File

@@ -0,0 +1,46 @@
/**
* Decode the compact binary layout produced by `bristolToBinary`
* back into a textual Bristol format string.
*
* Strict: any malformed input triggers an Error.
*/
export default function binaryToBristol(bytes: Uint8Array): string {
if (bytes.byteLength < 20) throw new Error("Buffer shorter than 20-byte header");
const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
let offset = 0;
const u32 = () => { const v = view.getUint32(offset, true); offset += 4; return v; };
const u8 = () => { const v = view.getUint8(offset); offset += 1; return v; };
/* ---------- header ---------- */
const header = [u32(), u32(), u32(), u32(), u32()];
const lines: string[] = [
`${header[0]} ${header[1]}`,
`${header[2]} ${header[3]} ${header[4]}`,
'',
];
/* ---------- gates ---------- */
const CODE_TO_NAME = ["INV", "XOR", "AND"] as const;
const INPUT_COUNT = [1, 2, 2] as const;
const OUTPUT_COUNT = [1, 1, 1] as const;
while (offset < bytes.byteLength) {
const code = u8();
if (code > 2) throw new Error(`Unknown gate code ${code} at byte ${offset - 1}`);
const inCount = INPUT_COUNT[code];
const outCount = OUTPUT_COUNT[code];
const wires: number[] = [];
for (let i = 0; i < inCount + outCount; i++) {
if (offset + 4 > bytes.byteLength) throw new Error("Truncated wire index");
wires.push(u32());
}
lines.push(
`${inCount} ${outCount} ${wires.join(" ")} ${CODE_TO_NAME[code]}`
);
}
if (offset !== bytes.byteLength) throw new Error("Extra bytes after final gate");
return lines.join("\n");
}

95
src/ts/bristolToBinary.ts Normal file
View File

@@ -0,0 +1,95 @@
/**
* **Strict** converter from a (restricted) Bristol-format string to a compact
* binary representation.
*
* Layout (little-endian)
* ┌────────────┬─────────────────────────────────────────────┐
* │ bytes 0-19 │ five 32-bit unsigned ints (exactly as given)│
* │ … │ repeated gate records │
* │ │ 1 byte gate-type (0 INV, 1 XOR, 2 AND) │
* │ │ INV : 2 × uint32 (in, out) │
* │ │ XOR/AND: 3 × uint32 (in1,in2,out) │
* └────────────┴─────────────────────────────────────────────┘
*
* Any deviation from the expected syntax throws an Error.
*/
export default function bristolToBinary(source: string): Uint8Array {
/* ---------- helpers ---------- */
const toInt = (tok: string, ctx: string) => {
if (!/^-?\d+$/.test(tok)) throw new Error(`Expected integer for ${ctx}, got "${tok}"`);
return Number(tok);
};
/* ---------- split lines (keep blank lines for validation) ---------- */
const rawLines = source.split(/\r?\n/);
if (rawLines.length < 3) throw new Error("Input too short missing header or gates");
/* ---------- header ---------- */
const h1 = rawLines[0].trim().split(/\s+/);
const h2 = rawLines[1].trim().split(/\s+/);
if (h1.length !== 2) throw new Error("Header line 1: expected exactly 2 numbers");
if (h2.length !== 3) throw new Error("Header line 2: expected exactly 3 numbers");
const header = [...h1, ...h2].map((t, i) => toInt(t, `header[${i}]`));
/* ---------- gate parsing ---------- */
const GATE_CODE = { INV: 0, XOR: 1, AND: 2 } as const;
type Gate = { code: number; wires: number[] };
const gates: Gate[] = [];
for (let ln = 2; ln < rawLines.length; ln++) {
const line = rawLines[ln].trim();
if (line === "") continue; // allow a single blank separator still not “ignored”
const parts = line.split(/\s+/);
if (parts.length < 5) throw new Error(`Line ${ln + 1}: too few tokens`);
const inCount = toInt(parts[0], "input-count");
const outCount = toInt(parts[1], "output-count");
const gateType = parts[parts.length - 1] as keyof typeof GATE_CODE;
if (!(gateType in GATE_CODE)) throw new Error(`Line ${ln + 1}: unknown gate type "${gateType}"`);
/* verify the (k inputs, l outputs) pair agrees with the opcode */
const expected = gateType === "INV" ? [1, 1] : [2, 1];
if (inCount !== expected[0] || outCount !== expected[1]) {
throw new Error(
`Line ${ln + 1}: counts ${inCount}-in/${outCount}-out contradict gate type ${gateType}`
);
}
const wireTokens = parts.slice(2, 2 + inCount + outCount);
if (wireTokens.length !== inCount + outCount) {
throw new Error(`Line ${ln + 1}: expected ${inCount + outCount} wire indices`);
}
const wires = wireTokens.map((t, i) => toInt(t, `wire[${i}]`));
/* ensure no trailing garbage */
if (parts.length !== 2 + wires.length + 1) {
throw new Error(`Line ${ln + 1}: unexpected extra tokens`);
}
gates.push({ code: GATE_CODE[gateType], wires });
}
if (gates.length === 0) throw new Error("No gate definitions found");
/* ---------- allocate & write ---------- */
const byteLength =
5 * 4 +
gates.reduce((sum, g) => sum + 1 + g.wires.length * 4, 0);
const buf = new ArrayBuffer(byteLength);
const view = new DataView(buf);
let off = 0;
const w32 = (v: number) => { view.setUint32(off, v >>> 0, true); off += 4; };
const w8 = (v: number) => { view.setUint8(off, v); off += 1; };
header.forEach(w32);
gates.forEach(({ code, wires }) => {
w8(code);
wires.forEach(w32);
});
return new Uint8Array(buf);
}

View File

@@ -5,18 +5,18 @@ import type { IO } from "./types";
*
* @param party - The party index joining the computation (0, 1, .. N-1).
* @param size - The number of parties in the computation.
* @param circuit - The circuit to run.
* @param circuitBinary - The circuit to run.
* @param inputBits - The input to the circuit, represented as one bit per byte.
* @param inputBitsPerParty - The number of input bits for each party.
* @param io - Input/output channels for communication between the two parties.
* @returns A promise resolving with the output bits of the circuit.
*/
export default async function nodeSecureMPC({
party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto',
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto',
}: {
party: number,
size: number,
circuit: string,
circuitBinary: Uint8Array,
inputBits: Uint8Array,
inputBitsPerParty: number[],
io: IO,
@@ -29,7 +29,7 @@ export default async function nodeSecureMPC({
let module = await ((await import('../../build/jslib.js')).default());
const emp: {
circuit?: string;
circuitBinary?: Uint8Array;
inputBits?: Uint8Array;
inputBitsPerParty?: number[];
io?: IO;
@@ -39,7 +39,7 @@ export default async function nodeSecureMPC({
module.emp = emp;
emp.circuit = circuit;
emp.circuitBinary = circuitBinary;
emp.inputBits = inputBits;
emp.inputBitsPerParty = inputBitsPerParty;
@@ -54,7 +54,7 @@ export default async function nodeSecureMPC({
recv: useRejector(io.recv.bind(io), reject),
};
const method = calculateMethod(mode, size, circuit);
const method = calculateMethod(mode, size, circuitBinary);
const result = await new Promise<Uint8Array>((resolve, reject) => {
try {
@@ -77,7 +77,7 @@ function calculateMethod(
// Currently unused, but some 2-party circuits might perform better with
// _runMPC
_circuit: string,
_circuitBinary: Uint8Array,
) {
switch (mode) {
case '2pc':

View File

@@ -2,6 +2,7 @@ import { EventEmitter } from "ee-typed";
import type { IO } from "./types";
import workerCode from "./workerCode.js";
import nodeSecureMPC from "./nodeSecureMPC.js";
import bristolToBinary from "./bristolToBinary.js";
export type SecureMPC = typeof secureMPC;
@@ -29,9 +30,11 @@ export default function secureMPC({
io: IO,
mode?: '2pc' | 'mpc' | 'auto',
}): Promise<Uint8Array> {
const circuitBinary = bristolToBinary(circuit);
if (typeof Worker === 'undefined') {
return nodeSecureMPC({
party, size, circuit, inputBits, inputBitsPerParty, io, mode,
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode,
});
}
@@ -48,7 +51,7 @@ export default function secureMPC({
type: 'start',
party,
size,
circuit,
circuitBinary,
inputBits,
inputBitsPerParty,
mode,

View File

@@ -0,0 +1,49 @@
import { expect } from "chai";
import bristolToBinary from '../src/ts/bristolToBinary';
import binaryToBristol from '../src/ts/binaryToBristol';
const normalise = (s: string) =>
s
.trim()
.replace(/\r?\n/g, "\n") // LF only
.split("\n")
.map(l => l.trimEnd()) // drop trailing spaces/tabs
.join("\n");
describe("Bristol ⇆ Binary round-trip", () => {
const samples: string[] = [
`106601 107113
512 0 160
1 1 177 749 INV
2 1 30 31 3599 XOR
1 1 55 4100 INV
1 1 62 4246 INV
1 1 83 3297 INV`,
// a smaller second sample to be sure different sizes round-trip
`2 3
1 0 0
1 1 0 1 INV
2 1 0 1 2 AND`,
];
samples.forEach((src, i) => {
it(`sample ${i + 1} should round-trip exactly`, () => {
const bin = bristolToBinary(src);
const text = binaryToBristol(bin);
expect(normalise(text)).to.equal(normalise(src));
});
});
it("decoding an opcode-corrupted buffer should throw", () => {
const good = bristolToBinary(samples[0]);
// Byte 20 (index 20) is the first gate's opcode. 0/1/2 are valid; 0xFF is not.
const bad = new Uint8Array(good);
bad[20] = 0xFF;
expect(() => binaryToBristol(bad)).to.throw();
});
});