From 6ebafc07a85e1286204f2bce2ca403af44bb2063 Mon Sep 17 00:00:00 2001 From: Andrew Morris Date: Thu, 12 Jun 2025 13:00:26 +1000 Subject: [PATCH] Improve performance by producing a binary bristol format in JS so that we don't use emscripten's super slow istringstream --- programs/jslib.cpp | 23 +++--- src/cpp/emp-tool/circuits/circuit_file.h | 72 ++++++++++++++++++ src/ts/appendWorker.ts | 20 ++--- src/ts/binaryToBristol.ts | 46 ++++++++++++ src/ts/bristolToBinary.ts | 95 ++++++++++++++++++++++++ src/ts/nodeSecureMPC.ts | 14 ++-- src/ts/secureMPC.ts | 7 +- tests/bristolBinary.test.ts | 49 ++++++++++++ 8 files changed, 295 insertions(+), 31 deletions(-) create mode 100644 src/ts/binaryToBristol.ts create mode 100644 src/ts/bristolToBinary.ts create mode 100644 tests/bristolBinary.test.ts diff --git a/programs/jslib.cpp b/programs/jslib.cpp index d8ff81a..964715e 100644 --- a/programs/jslib.cpp +++ b/programs/jslib.cpp @@ -187,31 +187,30 @@ public: } }; -EM_JS(char*, get_circuit_raw, (int* lengthPtr), { - if (!Module.emp?.circuit) { - throw new Error("Module.emp.circuit is not defined in JavaScript."); +EM_JS(uint8_t*, get_circuit_raw, (int* lengthPtr), { + if (!Module.emp?.circuitBinary) { + throw new Error("Module.emp.circuitBinary is not defined in JavaScript."); } - const circuitString = Module.emp.circuit; // Get the string from JavaScript - const length = lengthBytesUTF8(circuitString) + 1; // Calculate length including the null terminator + const circuitBinary = Module.emp.circuitBinary; // Get the string from JavaScript - // Allocate memory for the string - const strPtr = Module._js_char_malloc(length); - stringToUTF8(circuitString, strPtr, length); + // Allocate memory + const ptr = Module._js_malloc(circuitBinary.length); + Module.HEAPU8.set(circuitBinary, ptr); // Set the length at the provided pointer location - setValue(lengthPtr, length, 'i32'); + setValue(lengthPtr, circuitBinary.length, 'i32'); // Return the pointer - return strPtr; + return ptr; }); emp::BristolFormat get_circuit() { int length = 0; - char* circuit_raw = get_circuit_raw(&length); + uint8_t* circuit_raw = get_circuit_raw(&length); emp::BristolFormat circuit; - circuit.from_str(circuit_raw); + circuit.from_buffer(circuit_raw, length); free(circuit_raw); return circuit; diff --git a/src/cpp/emp-tool/circuits/circuit_file.h b/src/cpp/emp-tool/circuits/circuit_file.h index 558bec0..05e8052 100644 --- a/src/cpp/emp-tool/circuits/circuit_file.h +++ b/src/cpp/emp-tool/circuits/circuit_file.h @@ -89,6 +89,78 @@ public: fout.close(); } + /* Consume the binary layout produced by bristolToBinary. + * ┌────────────┬──────────────────────────────────────────────────┐ + * │ bytes 0-19 │ five uint32 (num_gate, num_wire, n1, n2, n3) │ + * │ … │ repeated records │ + * │ │ 1 byte opcode (0 INV, 1 XOR, 2 AND) │ + * │ │ INV : 2 × uint32 (in , out) ── 9 B │ + * │ │ XOR/AND: 3 × uint32 (in1,in2,out) ── 13 B │ + * └────────────┴──────────────────────────────────────────────────┘ + * Any deviation throws std::runtime_error. + */ + void from_buffer(const uint8_t* buf, int size) { + auto need = [&](size_t n) { + if (n > static_cast(size)) + throw std::runtime_error("Buffer too small / truncated"); + }; + + auto read_u32 = [&](const uint8_t* p) -> uint32_t { + return static_cast(p[0]) | + (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | + (static_cast(p[3]) << 24); + }; + + /* ---------- header ---------- */ + need(20); + num_gate = static_cast(read_u32(buf + 0)); + num_wire = static_cast(read_u32(buf + 4)); + n1 = static_cast(read_u32(buf + 8)); + n2 = static_cast(read_u32(buf + 12)); + n3 = static_cast(read_u32(buf + 16)); + + gates.resize(num_gate * 4); + wires.resize(num_wire); + + size_t offset = 20; + for (int g = 0; g < num_gate; ++g) { + need(offset + 1); // opcode byte + uint8_t opcode = buf[offset++]; + switch (opcode) { + case 0: { // INV + need(offset + 8); + int in = static_cast(read_u32(buf + offset)); offset += 4; + int out = static_cast(read_u32(buf + offset)); offset += 4; + + gates[4 * g] = in; + gates[4 * g + 1] = 0; // unused + gates[4 * g + 2] = out; + gates[4 * g + 3] = NOT_GATE; + break; + } + case 1: // XOR + case 2: { // AND + need(offset + 12); + int in1 = static_cast(read_u32(buf + offset)); offset += 4; + int in2 = static_cast(read_u32(buf + offset)); offset += 4; + int out = static_cast(read_u32(buf + offset)); offset += 4; + + gates[4 * g] = in1; + gates[4 * g + 1] = in2; + gates[4 * g + 2] = out; + gates[4 * g + 3] = (opcode == 1) ? XOR_GATE : AND_GATE; + break; + } + default: + throw std::runtime_error("Unknown gate opcode"); + } + } + + if (offset != static_cast(size)) + throw std::runtime_error("Extra bytes after final gate"); + } + void compute(Bit* out, const Bit* in1, const Bit* in2) { compute((block*)out, (block*)in1, (block*)in2); } diff --git a/src/ts/appendWorker.ts b/src/ts/appendWorker.ts index 92de62c..baa13d1 100644 --- a/src/ts/appendWorker.ts +++ b/src/ts/appendWorker.ts @@ -2,7 +2,7 @@ import type { IO } from "./types"; type Module = { emp?: { - circuit?: string; + circuitBinary?: Uint8Array; inputBits?: Uint8Array; inputBitsPerParty?: number[]; io?: IO; @@ -22,18 +22,18 @@ declare const createModule: () => Promise * * @param party - The party index joining the computation (0, 1, .. N-1). * @param size - The number of parties in the computation. - * @param circuit - The circuit to run. + * @param circuitBinary - The circuit to run. * @param inputBits - The input bits for the circuit, represented as one bit per byte. * @param inputBitsPerParty - The number of input bits for each party. * @param io - Input/output channels for communication between the two parties. * @returns A promise resolving with the output bits of the circuit. */ async function secureMPC({ - party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto', + party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto', }: { party: number, size: number, - circuit: string, + circuitBinary: Uint8Array, inputBits: Uint8Array, inputBitsPerParty: number[], io: IO, @@ -48,7 +48,7 @@ async function secureMPC({ running = true; const emp: { - circuit?: string; + circuitBinary?: Uint8Array; inputBits?: Uint8Array; inputBitsPerParty?: number[]; io?: IO; @@ -58,7 +58,7 @@ async function secureMPC({ module.emp = emp; - emp.circuit = circuit; + emp.circuitBinary = circuitBinary; emp.inputBits = inputBits; emp.inputBitsPerParty = inputBitsPerParty; @@ -73,7 +73,7 @@ async function secureMPC({ recv: useRejector(io.recv.bind(io), reject), }; - const method = calculateMethod(mode, size, circuit); + const method = calculateMethod(mode, size, circuitBinary); const result = new Promise(async (resolve, reject) => { try { @@ -99,7 +99,7 @@ function calculateMethod( // Currently unused, but some 2-party circuits might perform better with // _runMPC - _circuit: string, + _circuitBinary: Uint8Array, ) { switch (mode) { case '2pc': @@ -133,7 +133,7 @@ onmessage = async (event) => { const message = event.data; if (message.type === 'start') { - const { party, size, circuit, inputBits, inputBitsPerParty, mode } = message; + const { party, size, circuitBinary, inputBits, inputBitsPerParty, mode } = message; // Create a proxy IO object to communicate with the main thread const io: IO = { @@ -153,7 +153,7 @@ onmessage = async (event) => { const result = await secureMPC({ party, size, - circuit, + circuitBinary, inputBits, inputBitsPerParty, io, diff --git a/src/ts/binaryToBristol.ts b/src/ts/binaryToBristol.ts new file mode 100644 index 0000000..05dd909 --- /dev/null +++ b/src/ts/binaryToBristol.ts @@ -0,0 +1,46 @@ +/** + * Decode the compact binary layout produced by `bristolToBinary` + * back into a textual Bristol format string. + * + * Strict: any malformed input triggers an Error. + */ +export default function binaryToBristol(bytes: Uint8Array): string { + if (bytes.byteLength < 20) throw new Error("Buffer shorter than 20-byte header"); + + const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength); + let offset = 0; + const u32 = () => { const v = view.getUint32(offset, true); offset += 4; return v; }; + const u8 = () => { const v = view.getUint8(offset); offset += 1; return v; }; + + /* ---------- header ---------- */ + const header = [u32(), u32(), u32(), u32(), u32()]; + const lines: string[] = [ + `${header[0]} ${header[1]}`, + `${header[2]} ${header[3]} ${header[4]}`, + '', + ]; + + /* ---------- gates ---------- */ + const CODE_TO_NAME = ["INV", "XOR", "AND"] as const; + const INPUT_COUNT = [1, 2, 2] as const; + const OUTPUT_COUNT = [1, 1, 1] as const; + while (offset < bytes.byteLength) { + const code = u8(); + if (code > 2) throw new Error(`Unknown gate code ${code} at byte ${offset - 1}`); + + const inCount = INPUT_COUNT[code]; + const outCount = OUTPUT_COUNT[code]; + const wires: number[] = []; + + for (let i = 0; i < inCount + outCount; i++) { + if (offset + 4 > bytes.byteLength) throw new Error("Truncated wire index"); + wires.push(u32()); + } + lines.push( + `${inCount} ${outCount} ${wires.join(" ")} ${CODE_TO_NAME[code]}` + ); + } + + if (offset !== bytes.byteLength) throw new Error("Extra bytes after final gate"); + return lines.join("\n"); +} diff --git a/src/ts/bristolToBinary.ts b/src/ts/bristolToBinary.ts new file mode 100644 index 0000000..28b6710 --- /dev/null +++ b/src/ts/bristolToBinary.ts @@ -0,0 +1,95 @@ +/** + * **Strict** converter from a (restricted) Bristol-format string to a compact + * binary representation. + * + * Layout (little-endian) + * ┌────────────┬─────────────────────────────────────────────┐ + * │ bytes 0-19 │ five 32-bit unsigned ints (exactly as given)│ + * │ … │ repeated gate records │ + * │ │ 1 byte gate-type (0 INV, 1 XOR, 2 AND) │ + * │ │ INV : 2 × uint32 (in, out) │ + * │ │ XOR/AND: 3 × uint32 (in1,in2,out) │ + * └────────────┴─────────────────────────────────────────────┘ + * + * Any deviation from the expected syntax throws an Error. + */ +export default function bristolToBinary(source: string): Uint8Array { + /* ---------- helpers ---------- */ + const toInt = (tok: string, ctx: string) => { + if (!/^-?\d+$/.test(tok)) throw new Error(`Expected integer for ${ctx}, got "${tok}"`); + return Number(tok); + }; + + /* ---------- split lines (keep blank lines for validation) ---------- */ + const rawLines = source.split(/\r?\n/); + if (rawLines.length < 3) throw new Error("Input too short – missing header or gates"); + + /* ---------- header ---------- */ + const h1 = rawLines[0].trim().split(/\s+/); + const h2 = rawLines[1].trim().split(/\s+/); + if (h1.length !== 2) throw new Error("Header line 1: expected exactly 2 numbers"); + if (h2.length !== 3) throw new Error("Header line 2: expected exactly 3 numbers"); + const header = [...h1, ...h2].map((t, i) => toInt(t, `header[${i}]`)); + + /* ---------- gate parsing ---------- */ + const GATE_CODE = { INV: 0, XOR: 1, AND: 2 } as const; + type Gate = { code: number; wires: number[] }; + + const gates: Gate[] = []; + for (let ln = 2; ln < rawLines.length; ln++) { + const line = rawLines[ln].trim(); + if (line === "") continue; // allow a single blank separator – still not “ignored” + const parts = line.split(/\s+/); + + if (parts.length < 5) throw new Error(`Line ${ln + 1}: too few tokens`); + + const inCount = toInt(parts[0], "input-count"); + const outCount = toInt(parts[1], "output-count"); + const gateType = parts[parts.length - 1] as keyof typeof GATE_CODE; + + if (!(gateType in GATE_CODE)) throw new Error(`Line ${ln + 1}: unknown gate type "${gateType}"`); + + /* verify the (k inputs, l outputs) pair agrees with the opcode */ + const expected = gateType === "INV" ? [1, 1] : [2, 1]; + if (inCount !== expected[0] || outCount !== expected[1]) { + throw new Error( + `Line ${ln + 1}: counts ${inCount}-in/${outCount}-out contradict gate type ${gateType}` + ); + } + + const wireTokens = parts.slice(2, 2 + inCount + outCount); + if (wireTokens.length !== inCount + outCount) { + throw new Error(`Line ${ln + 1}: expected ${inCount + outCount} wire indices`); + } + + const wires = wireTokens.map((t, i) => toInt(t, `wire[${i}]`)); + + /* ensure no trailing garbage */ + if (parts.length !== 2 + wires.length + 1) { + throw new Error(`Line ${ln + 1}: unexpected extra tokens`); + } + + gates.push({ code: GATE_CODE[gateType], wires }); + } + + if (gates.length === 0) throw new Error("No gate definitions found"); + + /* ---------- allocate & write ---------- */ + const byteLength = + 5 * 4 + + gates.reduce((sum, g) => sum + 1 + g.wires.length * 4, 0); + + const buf = new ArrayBuffer(byteLength); + const view = new DataView(buf); + let off = 0; + const w32 = (v: number) => { view.setUint32(off, v >>> 0, true); off += 4; }; + const w8 = (v: number) => { view.setUint8(off, v); off += 1; }; + + header.forEach(w32); + gates.forEach(({ code, wires }) => { + w8(code); + wires.forEach(w32); + }); + + return new Uint8Array(buf); +} diff --git a/src/ts/nodeSecureMPC.ts b/src/ts/nodeSecureMPC.ts index 1ec58f5..25f3247 100644 --- a/src/ts/nodeSecureMPC.ts +++ b/src/ts/nodeSecureMPC.ts @@ -5,18 +5,18 @@ import type { IO } from "./types"; * * @param party - The party index joining the computation (0, 1, .. N-1). * @param size - The number of parties in the computation. - * @param circuit - The circuit to run. + * @param circuitBinary - The circuit to run. * @param inputBits - The input to the circuit, represented as one bit per byte. * @param inputBitsPerParty - The number of input bits for each party. * @param io - Input/output channels for communication between the two parties. * @returns A promise resolving with the output bits of the circuit. */ export default async function nodeSecureMPC({ - party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto', + party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto', }: { party: number, size: number, - circuit: string, + circuitBinary: Uint8Array, inputBits: Uint8Array, inputBitsPerParty: number[], io: IO, @@ -29,7 +29,7 @@ export default async function nodeSecureMPC({ let module = await ((await import('../../build/jslib.js')).default()); const emp: { - circuit?: string; + circuitBinary?: Uint8Array; inputBits?: Uint8Array; inputBitsPerParty?: number[]; io?: IO; @@ -39,7 +39,7 @@ export default async function nodeSecureMPC({ module.emp = emp; - emp.circuit = circuit; + emp.circuitBinary = circuitBinary; emp.inputBits = inputBits; emp.inputBitsPerParty = inputBitsPerParty; @@ -54,7 +54,7 @@ export default async function nodeSecureMPC({ recv: useRejector(io.recv.bind(io), reject), }; - const method = calculateMethod(mode, size, circuit); + const method = calculateMethod(mode, size, circuitBinary); const result = await new Promise((resolve, reject) => { try { @@ -77,7 +77,7 @@ function calculateMethod( // Currently unused, but some 2-party circuits might perform better with // _runMPC - _circuit: string, + _circuitBinary: Uint8Array, ) { switch (mode) { case '2pc': diff --git a/src/ts/secureMPC.ts b/src/ts/secureMPC.ts index 58c9c69..9a48589 100644 --- a/src/ts/secureMPC.ts +++ b/src/ts/secureMPC.ts @@ -2,6 +2,7 @@ import { EventEmitter } from "ee-typed"; import type { IO } from "./types"; import workerCode from "./workerCode.js"; import nodeSecureMPC from "./nodeSecureMPC.js"; +import bristolToBinary from "./bristolToBinary.js"; export type SecureMPC = typeof secureMPC; @@ -29,9 +30,11 @@ export default function secureMPC({ io: IO, mode?: '2pc' | 'mpc' | 'auto', }): Promise { + const circuitBinary = bristolToBinary(circuit); + if (typeof Worker === 'undefined') { return nodeSecureMPC({ - party, size, circuit, inputBits, inputBitsPerParty, io, mode, + party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode, }); } @@ -48,7 +51,7 @@ export default function secureMPC({ type: 'start', party, size, - circuit, + circuitBinary, inputBits, inputBitsPerParty, mode, diff --git a/tests/bristolBinary.test.ts b/tests/bristolBinary.test.ts new file mode 100644 index 0000000..f31992b --- /dev/null +++ b/tests/bristolBinary.test.ts @@ -0,0 +1,49 @@ +import { expect } from "chai"; + +import bristolToBinary from '../src/ts/bristolToBinary'; +import binaryToBristol from '../src/ts/binaryToBristol'; + +const normalise = (s: string) => + s + .trim() + .replace(/\r?\n/g, "\n") // LF only + .split("\n") + .map(l => l.trimEnd()) // drop trailing spaces/tabs + .join("\n"); + +describe("Bristol ⇆ Binary round-trip", () => { + const samples: string[] = [ + `106601 107113 +512 0 160 + +1 1 177 749 INV +2 1 30 31 3599 XOR +1 1 55 4100 INV +1 1 62 4246 INV +1 1 83 3297 INV`, + // a smaller second sample to be sure different sizes round-trip + `2 3 +1 0 0 + +1 1 0 1 INV +2 1 0 1 2 AND`, + ]; + + samples.forEach((src, i) => { + it(`sample ${i + 1} should round-trip exactly`, () => { + const bin = bristolToBinary(src); + const text = binaryToBristol(bin); + expect(normalise(text)).to.equal(normalise(src)); + }); + }); + + it("decoding an opcode-corrupted buffer should throw", () => { + const good = bristolToBinary(samples[0]); + + // Byte 20 (index 20) is the first gate's opcode. 0/1/2 are valid; 0xFF is not. + const bad = new Uint8Array(good); + bad[20] = 0xFF; + + expect(() => binaryToBristol(bad)).to.throw(); + }); +});