Add 2pc specialization

This commit is contained in:
Andrew Morris
2025-02-03 15:17:41 +11:00
parent 63ff747be4
commit a76bc3cb23
2 changed files with 51 additions and 6 deletions

View File

@@ -5,6 +5,7 @@
#include <cstring>
#include "emp-tool/io/i_raw_io.h"
#include "emp-ag2pc/2pc.h"
#include "emp-agmpc/mpc.h"
void run_2pc_impl(int party, int nP);
@@ -253,7 +254,40 @@ void run_2pc_impl(int party, int nP) {
throw std::runtime_error("Invalid party number");
}
throw std::runtime_error("TODO: 2PC specialization");
try {
int other_party = (party == 1) ? 2 : 1;
auto io = emp::IOChannel(std::make_shared<RawIOJS>(other_party, 'a'));
auto circuit = get_circuit();
std::vector<bool> input_bits = get_input_bits();
{
size_t circuit_input_count = (party == 1) ? circuit.n1 : circuit.n2;
if (input_bits.size() != circuit_input_count) {
throw std::runtime_error("Mismatch between circuit and inputBits");
}
}
for (int p = 0; p < 2; p++) {
size_t input_count = get_input_bits_per_party(p);
size_t circuit_input_count = (p == 0) ? circuit.n1 : circuit.n2;
if (input_count != circuit_input_count) {
throw std::runtime_error("Mismatch between circuit and inputBitsPerParty");
}
}
auto twopc = emp::C2PC(io, party, &circuit);
twopc.function_independent();
twopc.function_dependent();
std::vector<bool> output_bits = twopc.online(input_bits, true);
handle_output_bits(output_bits);
} catch (const std::exception& e) {
handle_error(e.what());
}
}
void run_mpc_impl(int party, int nP) {

View File

@@ -2,8 +2,18 @@ import { expect } from 'chai';
import { BufferQueue, secureMPC } from "../src/ts"
describe('Secure MPC', () => {
it('3 + 5 == 8', async function () {
expect(await internalDemo(3, 5)).to.deep.equal({ alice: 8, bob: 8 });
it('3 + 5 == 8 (2pc)', async function () {
// Note: This tends to run a bit slower than mpc mode, but that's because
// of the cold start. Running mpc first is slower than running 2pc first.
expect(await internalDemo(3, 5, '2pc')).to.deep.equal({ alice: 8, bob: 8 });
});
it('3 + 5 == 8 (mpc)', async function () {
expect(await internalDemo(3, 5, 'mpc')).to.deep.equal({ alice: 8, bob: 8 });
});
it('3 + 5 == 8 (auto)', async function () {
expect(await internalDemo(3, 5, 'auto')).to.deep.equal({ alice: 8, bob: 8 });
});
it('3 + 5 == 8 (5 parties)', async function () {
@@ -28,7 +38,8 @@ class BufferQueueStore {
async function internalDemo(
aliceInput: number,
bobInput: number
bobInput: number,
mode: '2pc' | 'mpc' | 'auto' = 'auto',
): Promise<{ alice: number, bob: number }> {
const bqs = new BufferQueueStore();
@@ -49,7 +60,7 @@ async function internalDemo(
return bqs.get('bob', 'alice', channel).pop(len);
},
},
mode: 'mpc', // TODO: omit this (default to 'auto' which does 2pc)
mode,
}),
secureMPC({
party: 1,
@@ -67,7 +78,7 @@ async function internalDemo(
return bqs.get('alice', 'bob', channel).pop(len);
},
},
mode: 'mpc',
mode,
}),
]);