Draft new jslib.cpp

This commit is contained in:
Andrew Morris
2025-01-29 16:06:31 +11:00
parent c0a13b7632
commit 7f4f3d2ff9
5 changed files with 166 additions and 113 deletions

View File

@@ -5,12 +5,12 @@
#include <cstring>
#include "emp-tool/io/i_raw_io.h"
#include "emp-ag2pc/2pc.h"
#include "emp-agmpc/mpc.h"
void run_impl(int party);
void run_impl(int party, int nP);
// Implement send_js function to send data from C++ to JavaScript
EM_JS(void, send_js, (const void* data, size_t len), {
EM_JS(void, send_js, (int party2, char channel_label, const void* data, size_t len), {
if (!Module.emp?.io?.send) {
throw new Error("Module.emp.io.send is not defined in JavaScript.");
}
@@ -18,18 +18,18 @@ EM_JS(void, send_js, (const void* data, size_t len), {
// Copy data from WebAssembly memory to a JavaScript Uint8Array
const dataArray = HEAPU8.slice(data, data + len);
Module.emp.io.send(dataArray);
Module.emp.io.send(party2 - 1, channel_label, dataArray);
});
// Implement recv_js function to receive data from JavaScript to C++
EM_ASYNC_JS(void, recv_js, (void* data, size_t len), {
EM_ASYNC_JS(void, recv_js, (int party2, char channel_label, void* data, size_t len), {
if (!Module.emp?.io?.recv) {
reject(new Error("Module.emp.io.recv is not defined in JavaScript."));
return;
}
// Wait for data from JavaScript
const dataArray = await Module.emp.io.recv(arguments[1]);
const dataArray = await Module.emp.io.recv(party2 - 1, channel_label, arguments[1]);
// Copy data from JavaScript Uint8Array to WebAssembly memory
HEAPU8.set(dataArray, data);
@@ -37,12 +37,23 @@ EM_ASYNC_JS(void, recv_js, (void* data, size_t len), {
class RawIOJS : public IRawIO {
public:
int party2;
char channel_label;
RawIOJS(
int party2,
char channel_label
):
party2(party2),
channel_label(channel_label)
{}
void send(const void* data, size_t len) override {
send_js(data, len);
send_js(party2, channel_label, data, len);
}
void recv(void* data, size_t len) override {
recv_js(data, len);
recv_js(party2, channel_label, data, len);
}
void flush() override {
@@ -50,6 +61,53 @@ public:
}
};
class MultiIOJS : public IMultiIO {
public:
int mParty;
int nP;
std::vector<emp::IOChannel> a_channels;
std::vector<emp::IOChannel> b_channels;
MultiIOJS(int party, int nP) : mParty(party), nP(nP) {
for (int i = 1; i <= nP; i++) {
a_channels.emplace_back(std::make_shared<RawIOJS>(i, 'a'));
b_channels.emplace_back(std::make_shared<RawIOJS>(i, 'b'));
}
}
int size() override {
return nP;
}
int party() override {
return mParty;
}
emp::IOChannel& a_channel(int party2) override {
assert(party2 != 0);
assert(party2 != party());
return a_channels[party2];
}
emp::IOChannel& b_channel(int party2) override {
assert(party2 != 0);
assert(party2 != party());
return b_channels[party2];
}
void flush(int idx) override {
assert(idx != 0);
if (party() < idx)
a_channels[idx].flush();
else
b_channels[idx].flush();
}
};
EM_JS(char*, get_circuit_raw, (int* lengthPtr), {
if (!Module.emp?.circuit) {
throw new Error("Module.emp.circuit is not defined in JavaScript.");
@@ -113,6 +171,14 @@ std::vector<bool> get_input_bits() {
return input_bits;
}
EM_JS(int, get_input_bits_start, (), {
if (!Module.emp?.inputBitsStart) {
throw new Error("Module.emp.inputBitsStart is not defined in JavaScript.");
}
return Module.emp.inputBitsStart;
});
EM_JS(void, handle_output_bits_raw, (uint8_t* outputBits, int length), {
if (!Module.emp?.handleOutput) {
throw new Error("Module.emp.handleOutput is not defined in JavaScript.");
@@ -139,8 +205,8 @@ void handle_output_bits(const std::vector<bool>& output_bits) {
extern "C" {
EMSCRIPTEN_KEEPALIVE
void run(int party) {
run_impl(party);
void run(int party, int size) {
run_impl(party + 1, size);
}
EMSCRIPTEN_KEEPALIVE
@@ -154,18 +220,41 @@ extern "C" {
}
}
void run_impl(int party) {
auto io = emp::IOChannel(std::make_shared<RawIOJS>());
void run_impl(int party, int nP) {
std::shared_ptr<IMultiIO> io = std::make_shared<MultiIOJS>(party, nP);
auto circuit = get_circuit();
auto twopc = emp::C2PC(io, party, &circuit);
auto mpc = CMPC(io, &circuit);
twopc.function_independent();
twopc.function_dependent();
mpc.function_independent();
mpc.function_dependent();
std::vector<bool> input_bits = get_input_bits();
int input_bits_start = get_input_bits_start();
FlexIn input(nP, circuit.n1 + circuit.n2, party);
for (size_t i = 0; i < input_bits.size(); i++) {
size_t x = i + input_bits_start;
input.assign_party(x, party);
input.assign_plaintext_bit(x, input_bits[i]);
}
FlexOut output(nP, circuit.n3, party);
for (int i = 0; i < circuit.n3; i++) {
// All parties receive the output.
output.assign_party(i, 0);
}
mpc.online(&input, &output);
std::vector<bool> output_bits;
for (int i = 0; i < circuit.n3; i++) {
output_bits.push_back(output.get_plaintext_bit(i));
}
std::vector<bool> output_bits = twopc.online(input_bits, true);
handle_output_bits(output_bits);
}

View File

@@ -11,10 +11,10 @@
#include "vec.h"
class ABitMP { public:
std::shared_ptr<IMultiIO> io;
int nP;
Vec<std::optional<IKNP>> abit1;
Vec<std::optional<IKNP>> abit2;
std::shared_ptr<IMultiIO> io;
int party;
PRG prg;
block Delta;
@@ -105,11 +105,11 @@ class ABitMP { public:
block seed = sampleRandom(nP, *io, &prg, party);
PRG prg2(&seed);
uint8_t * tmp;
block * Ms[nP+1];
bool * bs[nP+1];
block * Ks[nP+1];
block * tMs[nP+1];
bool * tbs[nP+1];
NVec<block> Ms(nP+1, ssp);
NVec<bool> bs(nP+1, ssp);
NVec<block> Ks(nP+1, ssp);
NVec<block> tMs(nP+1, ssp);
NVec<bool> tbs(nP+1, ssp);
tmp = new uint8_t[ssp*length];
prg2.random_data(tmp, ssp*length);
@@ -120,16 +120,6 @@ class ABitMP { public:
// for(int k = 0; k < ssp; ++k)
// tmp[j][length - ssp + k] = (k == j);
// }
for(int i = 1; i <= nP; ++i) {
Ms[i] = new block[ssp];
Ks[i] = new block[ssp];
bs[i] = new bool[ssp];
memset(Ms[i], 0, ssp*sizeof(block));
memset(Ks[i], 0, ssp*sizeof(block));
memset(bs[i], false, ssp);
tMs[i] = new block[ssp];
tbs[i] = new bool[ssp];
}
const int chk = 1;
const int SIZE = 1024*2;
@@ -158,9 +148,9 @@ class ABitMP { public:
tKEY[(i-start)/chk][2] = KEY.at(k, i+1);
tKEY[(i-start)/chk][3] = KEY.at(k, i) ^ KEY.at(k, i+1);
for(int j = 0; j < ssp; ++j) {
Ms[k][j] = Ms[k][j] ^ tMAC[(i-start)/chk][*tmpptr];
Ks[k][j] = Ks[k][j] ^ tKEY[(i-start)/chk][*tmpptr];
bs[k][j] = bs[k][j] != tb[i/chk][*tmpptr];
Ms.at(k, j) = Ms.at(k, j) ^ tMAC[(i-start)/chk][*tmpptr];
Ks.at(k, j) = Ks.at(k, j) ^ tKEY[(i-start)/chk][*tmpptr];
bs.at(k, j) = bs.at(k, j) != tb[i/chk][*tmpptr];
++tmpptr;
}
}
@@ -172,56 +162,41 @@ class ABitMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
get_send_channel(*io, party2).send_data(Ms[party2], sizeof(block)*ssp);
get_send_channel(*io, party2).send_data(bs[party2], ssp);
get_send_channel(*io, party2).send_data(&Ms.at(party2, 0), sizeof(block)*ssp);
get_send_channel(*io, party2).send_data(&bs.at(party2, 0), ssp);
io->flush(party2);
res.push_back(false);
get_recv_channel(*io, party2).recv_data(tMs[party2], sizeof(block)*ssp);
get_recv_channel(*io, party2).recv_data(tbs[party2], ssp);
get_recv_channel(*io, party2).recv_data(&tMs.at(party2, 0), sizeof(block)*ssp);
get_recv_channel(*io, party2).recv_data(&tbs.at(party2, 0), ssp);
for(int k = 0; k < ssp; ++k) {
if(tbs[party2][k])
Ks[party2][k] = Ks[party2][k] ^ Delta;
if(tbs.at(party2, k))
Ks.at(party2, k) = Ks.at(party2, k) ^ Delta;
}
res.push_back(!cmpBlock(Ks[party2], tMs[party2], ssp));
res.push_back(!cmpBlock(&Ks.at(party2, 0), &tMs.at(party2, 0), ssp));
}
if(checkCheat(res)) error("cheat check1\n");
for(int i = 1; i <= nP; ++i) {
delete[] Ms[i];
delete[] Ks[i];
delete[] bs[i];
delete[] tMs[i];
delete[] tbs[i];
}
}
void check2(const NVec<block>& MAC, const NVec<block> KEY, bool* data, int length) {
//last 2*ssp are garbage already.
block * Ks[2], *Ms[nP+1][nP+1];
block * KK[nP+1];
bool * bs[nP+1];
Ks[0] = new block[ssp];
Ks[1] = new block[ssp];
for(int i = 1; i <= nP; ++i) {
bs[i] = new bool[ssp];
KK[i] = new block[ssp];
for(int j = 1; j <= nP; ++j)
Ms[i][j] = new block[ssp];
}
NVec<block> Ks(2, ssp);
NVec<block> Ms(nP+1, nP+1, ssp);
NVec<block> KK(nP+1, ssp);
NVec<bool> bs(nP+1, ssp);
char (*dgst)[Hash::DIGEST_SIZE] = new char[nP+1][Hash::DIGEST_SIZE];
char (*dgst0)[Hash::DIGEST_SIZE] = new char[ssp*(nP+1)][Hash::DIGEST_SIZE];
char (*dgst1)[Hash::DIGEST_SIZE] = new char[ssp*(nP+1)][Hash::DIGEST_SIZE];
for(int i = 0; i < ssp; ++i) {
Ks[0][i] = zero_block;
Ks.at(0, i) = zero_block;
for(int j = 1; j <= nP; ++j) if(j != party)
Ks[0][i] = Ks[0][i] ^ KEY.at(j, length-3*ssp+i);
Ks.at(0, i) = Ks.at(0, i) ^ KEY.at(j, length-3*ssp+i);
Ks[1][i] = Ks[0][i] ^ Delta;
Hash::hash_once(dgst0[party*ssp+i], &Ks[0][i], sizeof(block));
Hash::hash_once(dgst1[party*ssp+i], &Ks[1][i], sizeof(block));
Ks.at(1, i) = Ks.at(0, i) ^ Delta;
Hash::hash_once(dgst0[party*ssp+i], &Ks.at(0, i), sizeof(block));
Hash::hash_once(dgst1[party*ssp+i], &Ks.at(1, i), sizeof(block));
}
Hash h;
h.put(data+length-3*ssp, ssp);
@@ -242,7 +217,7 @@ class ABitMP { public:
vector<bool> res2;
for(int k = 1; k <= nP; ++k) if(k!= party)
memcpy(Ms[party][k], &MAC.at(k, length-3*ssp), sizeof(block)*ssp);
memcpy(&Ms.at(party, k, 0), &MAC.at(k, length-3*ssp), sizeof(block)*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
@@ -253,31 +228,31 @@ class ABitMP { public:
res2.push_back(false);
Hash h;
get_recv_channel(*io, party2).recv_data(bs[party2], ssp);
h.put(bs[party2], ssp);
get_recv_channel(*io, party2).recv_data(&bs.at(party2, 0), ssp);
h.put(&bs.at(party2, 0), ssp);
for(int k = 1; k <= nP; ++k) if(k != party2) {
get_recv_channel(*io, party2).recv_data(Ms[party2][k], sizeof(block)*ssp);
h.put(Ms[party2][k], sizeof(block)*ssp);
get_recv_channel(*io, party2).recv_data(&Ms.at(party2, k, 0), sizeof(block)*ssp);
h.put(&Ms.at(party2, k, 0), sizeof(block)*ssp);
}
char tmp[Hash::DIGEST_SIZE];h.digest(tmp);
res2.push_back(strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE) != 0);
}
if(checkCheat(res2)) error("commitment 1\n");
memset(bs[party], false, ssp);
memset(&bs.at(party, 0), false, ssp);
for(int i = 1; i <= nP; ++i) if(i != party) {
for(int j = 0; j < ssp; ++j)
bs[party][j] = bs[party][j] != bs[i][j];
bs.at(party, j) = bs.at(party, j) != bs.at(i, j);
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
get_send_channel(*io, party2).send_data(bs[party], ssp);
get_send_channel(*io, party2).send_data(&bs.at(party, 0), ssp);
for(int i = 0; i < ssp; ++i) {
if(bs[party][i])
get_send_channel(*io, party2).send_data(&Ks[1][i], sizeof(block));
if (bs.at(party, i))
get_send_channel(*io, party2).send_data(&Ks.at(1, i), sizeof(block));
else
get_send_channel(*io, party2).send_data(&Ks[0][i], sizeof(block));
get_send_channel(*io, party2).send_data(&Ks.at(0, i), sizeof(block));
}
io->flush(party2);
res2.push_back(false);
@@ -285,10 +260,10 @@ class ABitMP { public:
bool cheat = false;
bool *tmp_bool = new bool[ssp];
get_recv_channel(*io, party2).recv_data(tmp_bool, ssp);
get_recv_channel(*io, party2).recv_data(KK[party2], ssp*sizeof(block));
get_recv_channel(*io, party2).recv_data(&KK.at(party2, 0), ssp*sizeof(block));
for(int i = 0; i < ssp; ++i) {
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &KK[party2][i], sizeof(block));
Hash::hash_once(tmp, &KK.at(party2, i), sizeof(block));
if(tmp_bool[i])
cheat = cheat or (strncmp(tmp, dgst1[party2*ssp+i], Hash::DIGEST_SIZE)!=0);
else
@@ -305,24 +280,16 @@ class ABitMP { public:
memset(tmp_block, 0, sizeof(block)*ssp);
for(int j = 1; j <= nP; ++j) if(j != i) {
for(int k = 0; k < ssp; ++k)
tmp_block[k] = tmp_block[k] ^ Ms[j][i][k];
tmp_block[k] = tmp_block[k] ^ Ms.at(j, i, k);
}
cheat = cheat or !cmpBlock(tmp_block, KK[i], ssp);
cheat = cheat or !cmpBlock(tmp_block, &KK.at(i, 0), ssp);
}
if(cheat) error("cheat aShare\n");
delete[] Ks[0];
delete[] Ks[1];
delete[] dgst;
delete[] dgst0;
delete[] dgst1;
delete[] tmp_block;
for(int i = 1; i <= nP; ++i) {
delete[] bs[i];
delete[] KK[i];
for(int j = 1; j <= nP; ++j)
delete[] Ms[i][j];
}
}
};
#endif //ABIT_MP_H

View File

@@ -11,9 +11,9 @@
using namespace emp;
class FpreMP { public:
std::shared_ptr<IMultiIO> io;
int nP;
int party;
std::shared_ptr<IMultiIO> io;
ABitMP* abit;
block Delta;
CRH * prps;
@@ -243,9 +243,7 @@ class FpreMP { public:
int * ind = new int[length*bucket_size];
int *location = new int[length*bucket_size];
bool * d[nP+1];
for(int i = 1; i <= nP; ++i)
d[i] = new bool[length*(bucket_size-1)];
NVec<bool> d(nP+1, length*(bucket_size-1));
for(int i = 0; i < length*bucket_size; ++i)
location[i] = i;
PRG prg2(&S);
@@ -261,7 +259,7 @@ class FpreMP { public:
for(int i = 0; i < length; ++i) {
for(int j = 0; j < bucket_size-1; ++j)
d[party][(bucket_size-1)*i+j] = tr[3*location[i*bucket_size]+1] != tr[3*location[i*bucket_size+1+j]+1];
d.at(party, (bucket_size-1)*i+j) = tr[3*location[i*bucket_size]+1] != tr[3*location[i*bucket_size+1+j]+1];
for(int j = 1; j <= nP; ++j) if (j!= party) {
memcpy(&MAC.at(j, 3*i), &tMAC.at(j, 3*location[i*bucket_size]), 3*sizeof(block));
memcpy(&KEY.at(j, 3*i), &tKEY.at(j, 3*location[i*bucket_size]), 3*sizeof(block));
@@ -282,24 +280,24 @@ class FpreMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
get_send_channel(*io, party2).send_data(d[party], (bucket_size-1)*length);
get_send_channel(*io, party2).send_data(&d.at(party, 0), (bucket_size-1)*length);
io->flush(party2);
get_recv_channel(*io, party2).recv_data(d[party2], (bucket_size-1)*length);
get_recv_channel(*io, party2).recv_data(&d.at(party2, 0), (bucket_size-1)*length);
}
for(int i = 2; i <= nP; ++i)
for(int j = 0; j < (bucket_size-1)*length; ++j)
d[1][j] = d[1][j]!=d[i][j];
d.at(1, j) = d.at(1, j) != d.at(i, j);
for(int i = 0; i < length; ++i) {
for(int j = 1; j <= nP; ++j)if (j!= party) {
for(int k = 1; k < bucket_size; ++k)
if(d[1][(bucket_size-1)*i+k-1]) {
if(d.at(1, (bucket_size-1)*i+k-1)) {
MAC.at(j, 3*i+2) = MAC.at(j, 3*i+2) ^ tMAC.at(j, 3*location[i*bucket_size+k]);
KEY.at(j, 3*i+2) = KEY.at(j, 3*i+2) ^ tKEY.at(j, 3*location[i*bucket_size+k]);
}
}
for(int k = 1; k < bucket_size; ++k)
if(d[1][(bucket_size-1)*i+k-1]) {
if(d.at(1, (bucket_size-1)*i+k-1)) {
r[3*i+2] = r[3*i+2] != tr[3*location[i*bucket_size+k]];
}
}

View File

@@ -14,8 +14,6 @@ class CMPC { public:
const block MASK = makeBlock(0x0ULL, 0xFFFFFULL);
FpreMP* fpre = nullptr;
int nP;
NVec<block> mac; // dim: parties, wires
NVec<block> key; // dim: parties, wires
Vec<bool> value; // dim: wires
@@ -35,6 +33,7 @@ class CMPC { public:
Vec<block> labels; // dim: wires
BristolFormat * cf;
std::shared_ptr<IMultiIO> io;
int nP;
int num_ands = 0, num_in;
int party, total_pre, ssp;
block Delta;
@@ -356,21 +355,21 @@ class CMPC { public:
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
Vec<block> H(nP+1);
for(int j = 2; j <= nP; ++j)
eval_labels.at(j, cf->gates[4*i+2]) = GTM.at(ands, index, j);
mask_input[cf->gates[4*i+2]] = GTv.at(ands, index);
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels.at(j, cf->gates[4*i]), eval_labels.at(j, cf->gates[4*i+1]), ands, index);
xorBlocks_arr(H, H, &GT.at(ands, j, index, 0), nP+1);
Hash(&H.at(0), eval_labels.at(j, cf->gates[4*i]), eval_labels.at(j, cf->gates[4*i+1]), ands, index);
xorBlocks_arr(&H.at(0), &H.at(0), &GT.at(ands, j, index, 0), nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels.at(k, cf->gates[4*i+2]) = H[k] ^ eval_labels.at(k, cf->gates[4*i+2]);
eval_labels.at(k, cf->gates[4*i+2]) = H.at(k) ^ eval_labels.at(k, cf->gates[4*i+2]);
block t0 = GTK.at(ands, index, j) ^ Delta;
if(cmpBlock(&H[1], &GTK.at(ands, index, j), 1))
if(cmpBlock(&H.at(1), &GTK.at(ands, index, j), 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
else if(cmpBlock(&H.at(1), &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}

View File

@@ -64,28 +64,28 @@ public:
}
}
int party() {
int party() override {
return mParty;
}
int size() {
int size() override {
return nP;
}
IOChannel& a_channel(int party2) {
IOChannel& a_channel(int party2) override {
assert(party2 != 0);
assert(party2 != party());
return *a_channels[party2];
}
IOChannel& b_channel(int party2) {
IOChannel& b_channel(int party2) override {
assert(party2 != 0);
assert(party2 != party());
return *b_channels[party2];
}
void flush(int idx) {
void flush(int idx) override {
assert(idx != 0);
if(party() < idx)