mirror of
https://github.com/privacy-scaling-explorations/emp-wasm.git
synced 2026-01-08 01:23:52 -05:00
Improve performance by producing a binary bristol format in JS so that we don't use emscripten's super slow istringstream
This commit is contained in:
@@ -187,31 +187,30 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
EM_JS(char*, get_circuit_raw, (int* lengthPtr), {
|
||||
if (!Module.emp?.circuit) {
|
||||
throw new Error("Module.emp.circuit is not defined in JavaScript.");
|
||||
EM_JS(uint8_t*, get_circuit_raw, (int* lengthPtr), {
|
||||
if (!Module.emp?.circuitBinary) {
|
||||
throw new Error("Module.emp.circuitBinary is not defined in JavaScript.");
|
||||
}
|
||||
|
||||
const circuitString = Module.emp.circuit; // Get the string from JavaScript
|
||||
const length = lengthBytesUTF8(circuitString) + 1; // Calculate length including the null terminator
|
||||
const circuitBinary = Module.emp.circuitBinary; // Get the string from JavaScript
|
||||
|
||||
// Allocate memory for the string
|
||||
const strPtr = Module._js_char_malloc(length);
|
||||
stringToUTF8(circuitString, strPtr, length);
|
||||
// Allocate memory
|
||||
const ptr = Module._js_malloc(circuitBinary.length);
|
||||
Module.HEAPU8.set(circuitBinary, ptr);
|
||||
|
||||
// Set the length at the provided pointer location
|
||||
setValue(lengthPtr, length, 'i32');
|
||||
setValue(lengthPtr, circuitBinary.length, 'i32');
|
||||
|
||||
// Return the pointer
|
||||
return strPtr;
|
||||
return ptr;
|
||||
});
|
||||
|
||||
emp::BristolFormat get_circuit() {
|
||||
int length = 0;
|
||||
char* circuit_raw = get_circuit_raw(&length);
|
||||
uint8_t* circuit_raw = get_circuit_raw(&length);
|
||||
|
||||
emp::BristolFormat circuit;
|
||||
circuit.from_str(circuit_raw);
|
||||
circuit.from_buffer(circuit_raw, length);
|
||||
free(circuit_raw);
|
||||
|
||||
return circuit;
|
||||
|
||||
@@ -89,6 +89,78 @@ public:
|
||||
fout.close();
|
||||
}
|
||||
|
||||
/* Consume the binary layout produced by bristolToBinary.
|
||||
* ┌────────────┬──────────────────────────────────────────────────┐
|
||||
* │ bytes 0-19 │ five uint32 (num_gate, num_wire, n1, n2, n3) │
|
||||
* │ … │ repeated records │
|
||||
* │ │ 1 byte opcode (0 INV, 1 XOR, 2 AND) │
|
||||
* │ │ INV : 2 × uint32 (in , out) ── 9 B │
|
||||
* │ │ XOR/AND: 3 × uint32 (in1,in2,out) ── 13 B │
|
||||
* └────────────┴──────────────────────────────────────────────────┘
|
||||
* Any deviation throws std::runtime_error.
|
||||
*/
|
||||
void from_buffer(const uint8_t* buf, int size) {
|
||||
auto need = [&](size_t n) {
|
||||
if (n > static_cast<size_t>(size))
|
||||
throw std::runtime_error("Buffer too small / truncated");
|
||||
};
|
||||
|
||||
auto read_u32 = [&](const uint8_t* p) -> uint32_t {
|
||||
return static_cast<uint32_t>(p[0]) |
|
||||
(static_cast<uint32_t>(p[1]) << 8) |
|
||||
(static_cast<uint32_t>(p[2]) << 16) |
|
||||
(static_cast<uint32_t>(p[3]) << 24);
|
||||
};
|
||||
|
||||
/* ---------- header ---------- */
|
||||
need(20);
|
||||
num_gate = static_cast<int>(read_u32(buf + 0));
|
||||
num_wire = static_cast<int>(read_u32(buf + 4));
|
||||
n1 = static_cast<int>(read_u32(buf + 8));
|
||||
n2 = static_cast<int>(read_u32(buf + 12));
|
||||
n3 = static_cast<int>(read_u32(buf + 16));
|
||||
|
||||
gates.resize(num_gate * 4);
|
||||
wires.resize(num_wire);
|
||||
|
||||
size_t offset = 20;
|
||||
for (int g = 0; g < num_gate; ++g) {
|
||||
need(offset + 1); // opcode byte
|
||||
uint8_t opcode = buf[offset++];
|
||||
switch (opcode) {
|
||||
case 0: { // INV
|
||||
need(offset + 8);
|
||||
int in = static_cast<int>(read_u32(buf + offset)); offset += 4;
|
||||
int out = static_cast<int>(read_u32(buf + offset)); offset += 4;
|
||||
|
||||
gates[4 * g] = in;
|
||||
gates[4 * g + 1] = 0; // unused
|
||||
gates[4 * g + 2] = out;
|
||||
gates[4 * g + 3] = NOT_GATE;
|
||||
break;
|
||||
}
|
||||
case 1: // XOR
|
||||
case 2: { // AND
|
||||
need(offset + 12);
|
||||
int in1 = static_cast<int>(read_u32(buf + offset)); offset += 4;
|
||||
int in2 = static_cast<int>(read_u32(buf + offset)); offset += 4;
|
||||
int out = static_cast<int>(read_u32(buf + offset)); offset += 4;
|
||||
|
||||
gates[4 * g] = in1;
|
||||
gates[4 * g + 1] = in2;
|
||||
gates[4 * g + 2] = out;
|
||||
gates[4 * g + 3] = (opcode == 1) ? XOR_GATE : AND_GATE;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown gate opcode");
|
||||
}
|
||||
}
|
||||
|
||||
if (offset != static_cast<size_t>(size))
|
||||
throw std::runtime_error("Extra bytes after final gate");
|
||||
}
|
||||
|
||||
void compute(Bit* out, const Bit* in1, const Bit* in2) {
|
||||
compute((block*)out, (block*)in1, (block*)in2);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { IO } from "./types";
|
||||
|
||||
type Module = {
|
||||
emp?: {
|
||||
circuit?: string;
|
||||
circuitBinary?: Uint8Array;
|
||||
inputBits?: Uint8Array;
|
||||
inputBitsPerParty?: number[];
|
||||
io?: IO;
|
||||
@@ -22,18 +22,18 @@ declare const createModule: () => Promise<Module>
|
||||
*
|
||||
* @param party - The party index joining the computation (0, 1, .. N-1).
|
||||
* @param size - The number of parties in the computation.
|
||||
* @param circuit - The circuit to run.
|
||||
* @param circuitBinary - The circuit to run.
|
||||
* @param inputBits - The input bits for the circuit, represented as one bit per byte.
|
||||
* @param inputBitsPerParty - The number of input bits for each party.
|
||||
* @param io - Input/output channels for communication between the two parties.
|
||||
* @returns A promise resolving with the output bits of the circuit.
|
||||
*/
|
||||
async function secureMPC({
|
||||
party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto',
|
||||
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto',
|
||||
}: {
|
||||
party: number,
|
||||
size: number,
|
||||
circuit: string,
|
||||
circuitBinary: Uint8Array,
|
||||
inputBits: Uint8Array,
|
||||
inputBitsPerParty: number[],
|
||||
io: IO,
|
||||
@@ -48,7 +48,7 @@ async function secureMPC({
|
||||
running = true;
|
||||
|
||||
const emp: {
|
||||
circuit?: string;
|
||||
circuitBinary?: Uint8Array;
|
||||
inputBits?: Uint8Array;
|
||||
inputBitsPerParty?: number[];
|
||||
io?: IO;
|
||||
@@ -58,7 +58,7 @@ async function secureMPC({
|
||||
|
||||
module.emp = emp;
|
||||
|
||||
emp.circuit = circuit;
|
||||
emp.circuitBinary = circuitBinary;
|
||||
emp.inputBits = inputBits;
|
||||
emp.inputBitsPerParty = inputBitsPerParty;
|
||||
|
||||
@@ -73,7 +73,7 @@ async function secureMPC({
|
||||
recv: useRejector(io.recv.bind(io), reject),
|
||||
};
|
||||
|
||||
const method = calculateMethod(mode, size, circuit);
|
||||
const method = calculateMethod(mode, size, circuitBinary);
|
||||
|
||||
const result = new Promise<Uint8Array>(async (resolve, reject) => {
|
||||
try {
|
||||
@@ -99,7 +99,7 @@ function calculateMethod(
|
||||
|
||||
// Currently unused, but some 2-party circuits might perform better with
|
||||
// _runMPC
|
||||
_circuit: string,
|
||||
_circuitBinary: Uint8Array,
|
||||
) {
|
||||
switch (mode) {
|
||||
case '2pc':
|
||||
@@ -133,7 +133,7 @@ onmessage = async (event) => {
|
||||
const message = event.data;
|
||||
|
||||
if (message.type === 'start') {
|
||||
const { party, size, circuit, inputBits, inputBitsPerParty, mode } = message;
|
||||
const { party, size, circuitBinary, inputBits, inputBitsPerParty, mode } = message;
|
||||
|
||||
// Create a proxy IO object to communicate with the main thread
|
||||
const io: IO = {
|
||||
@@ -153,7 +153,7 @@ onmessage = async (event) => {
|
||||
const result = await secureMPC({
|
||||
party,
|
||||
size,
|
||||
circuit,
|
||||
circuitBinary,
|
||||
inputBits,
|
||||
inputBitsPerParty,
|
||||
io,
|
||||
|
||||
46
src/ts/binaryToBristol.ts
Normal file
46
src/ts/binaryToBristol.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Decode the compact binary layout produced by `bristolToBinary`
|
||||
* back into a textual Bristol format string.
|
||||
*
|
||||
* Strict: any malformed input triggers an Error.
|
||||
*/
|
||||
export default function binaryToBristol(bytes: Uint8Array): string {
|
||||
if (bytes.byteLength < 20) throw new Error("Buffer shorter than 20-byte header");
|
||||
|
||||
const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
|
||||
let offset = 0;
|
||||
const u32 = () => { const v = view.getUint32(offset, true); offset += 4; return v; };
|
||||
const u8 = () => { const v = view.getUint8(offset); offset += 1; return v; };
|
||||
|
||||
/* ---------- header ---------- */
|
||||
const header = [u32(), u32(), u32(), u32(), u32()];
|
||||
const lines: string[] = [
|
||||
`${header[0]} ${header[1]}`,
|
||||
`${header[2]} ${header[3]} ${header[4]}`,
|
||||
'',
|
||||
];
|
||||
|
||||
/* ---------- gates ---------- */
|
||||
const CODE_TO_NAME = ["INV", "XOR", "AND"] as const;
|
||||
const INPUT_COUNT = [1, 2, 2] as const;
|
||||
const OUTPUT_COUNT = [1, 1, 1] as const;
|
||||
while (offset < bytes.byteLength) {
|
||||
const code = u8();
|
||||
if (code > 2) throw new Error(`Unknown gate code ${code} at byte ${offset - 1}`);
|
||||
|
||||
const inCount = INPUT_COUNT[code];
|
||||
const outCount = OUTPUT_COUNT[code];
|
||||
const wires: number[] = [];
|
||||
|
||||
for (let i = 0; i < inCount + outCount; i++) {
|
||||
if (offset + 4 > bytes.byteLength) throw new Error("Truncated wire index");
|
||||
wires.push(u32());
|
||||
}
|
||||
lines.push(
|
||||
`${inCount} ${outCount} ${wires.join(" ")} ${CODE_TO_NAME[code]}`
|
||||
);
|
||||
}
|
||||
|
||||
if (offset !== bytes.byteLength) throw new Error("Extra bytes after final gate");
|
||||
return lines.join("\n");
|
||||
}
|
||||
95
src/ts/bristolToBinary.ts
Normal file
95
src/ts/bristolToBinary.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
/**
|
||||
* **Strict** converter from a (restricted) Bristol-format string to a compact
|
||||
* binary representation.
|
||||
*
|
||||
* Layout (little-endian)
|
||||
* ┌────────────┬─────────────────────────────────────────────┐
|
||||
* │ bytes 0-19 │ five 32-bit unsigned ints (exactly as given)│
|
||||
* │ … │ repeated gate records │
|
||||
* │ │ 1 byte gate-type (0 INV, 1 XOR, 2 AND) │
|
||||
* │ │ INV : 2 × uint32 (in, out) │
|
||||
* │ │ XOR/AND: 3 × uint32 (in1,in2,out) │
|
||||
* └────────────┴─────────────────────────────────────────────┘
|
||||
*
|
||||
* Any deviation from the expected syntax throws an Error.
|
||||
*/
|
||||
export default function bristolToBinary(source: string): Uint8Array {
|
||||
/* ---------- helpers ---------- */
|
||||
const toInt = (tok: string, ctx: string) => {
|
||||
if (!/^-?\d+$/.test(tok)) throw new Error(`Expected integer for ${ctx}, got "${tok}"`);
|
||||
return Number(tok);
|
||||
};
|
||||
|
||||
/* ---------- split lines (keep blank lines for validation) ---------- */
|
||||
const rawLines = source.split(/\r?\n/);
|
||||
if (rawLines.length < 3) throw new Error("Input too short – missing header or gates");
|
||||
|
||||
/* ---------- header ---------- */
|
||||
const h1 = rawLines[0].trim().split(/\s+/);
|
||||
const h2 = rawLines[1].trim().split(/\s+/);
|
||||
if (h1.length !== 2) throw new Error("Header line 1: expected exactly 2 numbers");
|
||||
if (h2.length !== 3) throw new Error("Header line 2: expected exactly 3 numbers");
|
||||
const header = [...h1, ...h2].map((t, i) => toInt(t, `header[${i}]`));
|
||||
|
||||
/* ---------- gate parsing ---------- */
|
||||
const GATE_CODE = { INV: 0, XOR: 1, AND: 2 } as const;
|
||||
type Gate = { code: number; wires: number[] };
|
||||
|
||||
const gates: Gate[] = [];
|
||||
for (let ln = 2; ln < rawLines.length; ln++) {
|
||||
const line = rawLines[ln].trim();
|
||||
if (line === "") continue; // allow a single blank separator – still not “ignored”
|
||||
const parts = line.split(/\s+/);
|
||||
|
||||
if (parts.length < 5) throw new Error(`Line ${ln + 1}: too few tokens`);
|
||||
|
||||
const inCount = toInt(parts[0], "input-count");
|
||||
const outCount = toInt(parts[1], "output-count");
|
||||
const gateType = parts[parts.length - 1] as keyof typeof GATE_CODE;
|
||||
|
||||
if (!(gateType in GATE_CODE)) throw new Error(`Line ${ln + 1}: unknown gate type "${gateType}"`);
|
||||
|
||||
/* verify the (k inputs, l outputs) pair agrees with the opcode */
|
||||
const expected = gateType === "INV" ? [1, 1] : [2, 1];
|
||||
if (inCount !== expected[0] || outCount !== expected[1]) {
|
||||
throw new Error(
|
||||
`Line ${ln + 1}: counts ${inCount}-in/${outCount}-out contradict gate type ${gateType}`
|
||||
);
|
||||
}
|
||||
|
||||
const wireTokens = parts.slice(2, 2 + inCount + outCount);
|
||||
if (wireTokens.length !== inCount + outCount) {
|
||||
throw new Error(`Line ${ln + 1}: expected ${inCount + outCount} wire indices`);
|
||||
}
|
||||
|
||||
const wires = wireTokens.map((t, i) => toInt(t, `wire[${i}]`));
|
||||
|
||||
/* ensure no trailing garbage */
|
||||
if (parts.length !== 2 + wires.length + 1) {
|
||||
throw new Error(`Line ${ln + 1}: unexpected extra tokens`);
|
||||
}
|
||||
|
||||
gates.push({ code: GATE_CODE[gateType], wires });
|
||||
}
|
||||
|
||||
if (gates.length === 0) throw new Error("No gate definitions found");
|
||||
|
||||
/* ---------- allocate & write ---------- */
|
||||
const byteLength =
|
||||
5 * 4 +
|
||||
gates.reduce((sum, g) => sum + 1 + g.wires.length * 4, 0);
|
||||
|
||||
const buf = new ArrayBuffer(byteLength);
|
||||
const view = new DataView(buf);
|
||||
let off = 0;
|
||||
const w32 = (v: number) => { view.setUint32(off, v >>> 0, true); off += 4; };
|
||||
const w8 = (v: number) => { view.setUint8(off, v); off += 1; };
|
||||
|
||||
header.forEach(w32);
|
||||
gates.forEach(({ code, wires }) => {
|
||||
w8(code);
|
||||
wires.forEach(w32);
|
||||
});
|
||||
|
||||
return new Uint8Array(buf);
|
||||
}
|
||||
@@ -5,18 +5,18 @@ import type { IO } from "./types";
|
||||
*
|
||||
* @param party - The party index joining the computation (0, 1, .. N-1).
|
||||
* @param size - The number of parties in the computation.
|
||||
* @param circuit - The circuit to run.
|
||||
* @param circuitBinary - The circuit to run.
|
||||
* @param inputBits - The input to the circuit, represented as one bit per byte.
|
||||
* @param inputBitsPerParty - The number of input bits for each party.
|
||||
* @param io - Input/output channels for communication between the two parties.
|
||||
* @returns A promise resolving with the output bits of the circuit.
|
||||
*/
|
||||
export default async function nodeSecureMPC({
|
||||
party, size, circuit, inputBits, inputBitsPerParty, io, mode = 'auto',
|
||||
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode = 'auto',
|
||||
}: {
|
||||
party: number,
|
||||
size: number,
|
||||
circuit: string,
|
||||
circuitBinary: Uint8Array,
|
||||
inputBits: Uint8Array,
|
||||
inputBitsPerParty: number[],
|
||||
io: IO,
|
||||
@@ -29,7 +29,7 @@ export default async function nodeSecureMPC({
|
||||
let module = await ((await import('../../build/jslib.js')).default());
|
||||
|
||||
const emp: {
|
||||
circuit?: string;
|
||||
circuitBinary?: Uint8Array;
|
||||
inputBits?: Uint8Array;
|
||||
inputBitsPerParty?: number[];
|
||||
io?: IO;
|
||||
@@ -39,7 +39,7 @@ export default async function nodeSecureMPC({
|
||||
|
||||
module.emp = emp;
|
||||
|
||||
emp.circuit = circuit;
|
||||
emp.circuitBinary = circuitBinary;
|
||||
emp.inputBits = inputBits;
|
||||
emp.inputBitsPerParty = inputBitsPerParty;
|
||||
|
||||
@@ -54,7 +54,7 @@ export default async function nodeSecureMPC({
|
||||
recv: useRejector(io.recv.bind(io), reject),
|
||||
};
|
||||
|
||||
const method = calculateMethod(mode, size, circuit);
|
||||
const method = calculateMethod(mode, size, circuitBinary);
|
||||
|
||||
const result = await new Promise<Uint8Array>((resolve, reject) => {
|
||||
try {
|
||||
@@ -77,7 +77,7 @@ function calculateMethod(
|
||||
|
||||
// Currently unused, but some 2-party circuits might perform better with
|
||||
// _runMPC
|
||||
_circuit: string,
|
||||
_circuitBinary: Uint8Array,
|
||||
) {
|
||||
switch (mode) {
|
||||
case '2pc':
|
||||
|
||||
@@ -2,6 +2,7 @@ import { EventEmitter } from "ee-typed";
|
||||
import type { IO } from "./types";
|
||||
import workerCode from "./workerCode.js";
|
||||
import nodeSecureMPC from "./nodeSecureMPC.js";
|
||||
import bristolToBinary from "./bristolToBinary.js";
|
||||
|
||||
export type SecureMPC = typeof secureMPC;
|
||||
|
||||
@@ -29,9 +30,11 @@ export default function secureMPC({
|
||||
io: IO,
|
||||
mode?: '2pc' | 'mpc' | 'auto',
|
||||
}): Promise<Uint8Array> {
|
||||
const circuitBinary = bristolToBinary(circuit);
|
||||
|
||||
if (typeof Worker === 'undefined') {
|
||||
return nodeSecureMPC({
|
||||
party, size, circuit, inputBits, inputBitsPerParty, io, mode,
|
||||
party, size, circuitBinary, inputBits, inputBitsPerParty, io, mode,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -48,7 +51,7 @@ export default function secureMPC({
|
||||
type: 'start',
|
||||
party,
|
||||
size,
|
||||
circuit,
|
||||
circuitBinary,
|
||||
inputBits,
|
||||
inputBitsPerParty,
|
||||
mode,
|
||||
|
||||
49
tests/bristolBinary.test.ts
Normal file
49
tests/bristolBinary.test.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { expect } from "chai";
|
||||
|
||||
import bristolToBinary from '../src/ts/bristolToBinary';
|
||||
import binaryToBristol from '../src/ts/binaryToBristol';
|
||||
|
||||
const normalise = (s: string) =>
|
||||
s
|
||||
.trim()
|
||||
.replace(/\r?\n/g, "\n") // LF only
|
||||
.split("\n")
|
||||
.map(l => l.trimEnd()) // drop trailing spaces/tabs
|
||||
.join("\n");
|
||||
|
||||
describe("Bristol ⇆ Binary round-trip", () => {
|
||||
const samples: string[] = [
|
||||
`106601 107113
|
||||
512 0 160
|
||||
|
||||
1 1 177 749 INV
|
||||
2 1 30 31 3599 XOR
|
||||
1 1 55 4100 INV
|
||||
1 1 62 4246 INV
|
||||
1 1 83 3297 INV`,
|
||||
// a smaller second sample to be sure different sizes round-trip
|
||||
`2 3
|
||||
1 0 0
|
||||
|
||||
1 1 0 1 INV
|
||||
2 1 0 1 2 AND`,
|
||||
];
|
||||
|
||||
samples.forEach((src, i) => {
|
||||
it(`sample ${i + 1} should round-trip exactly`, () => {
|
||||
const bin = bristolToBinary(src);
|
||||
const text = binaryToBristol(bin);
|
||||
expect(normalise(text)).to.equal(normalise(src));
|
||||
});
|
||||
});
|
||||
|
||||
it("decoding an opcode-corrupted buffer should throw", () => {
|
||||
const good = bristolToBinary(samples[0]);
|
||||
|
||||
// Byte 20 (index 20) is the first gate's opcode. 0/1/2 are valid; 0xFF is not.
|
||||
const bad = new Uint8Array(good);
|
||||
bad[20] = 0xFF;
|
||||
|
||||
expect(() => binaryToBristol(bad)).to.throw();
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user