copy emp-toolkit/emp-agmpc files unmodified

This commit is contained in:
Andrew Morris
2025-01-15 13:46:10 +11:00
parent f461d9b5f9
commit 7bac01e196
9 changed files with 2708 additions and 0 deletions

47
programs/test_mpc.cpp Normal file
View File

@@ -0,0 +1,47 @@
#include <emp-tool/emp-tool.h>
#include "emp-agmpc/emp-agmpc.h"
using namespace std;
using namespace emp;
const string circuit_file_location = macro_xstr(EMP_CIRCUIT_PATH) + string("bristol_format/");
static char out3[] = "92b404e556588ced6c1acd4ebf053f6809f73a93";//bafbc2c87c33322603f38e06c3e0f79c1f1b1475";
int main(int argc, char** argv) {
int port, party;
parse_party_and_port(argv, &party, &port);
const static int nP = 3;
NetIOMP<nP> io(party, port);
NetIOMP<nP> io2(party, port+2*(nP+1)*(nP+1)+1);
NetIOMP<nP> *ios[2] = {&io, &io2};
ThreadPool pool(4);
string file = circuit_file_location+"/AES-non-expanded.txt";
file = circuit_file_location+"/sha-1.txt";
BristolFormat cf(file.c_str());
CMPC<nP>* mpc = new CMPC<nP>(ios, &pool, party, &cf);
cout <<"Setup:\t"<<party<<"\n";
mpc->function_independent();
cout <<"FUNC_IND:\t"<<party<<"\n";
mpc->function_dependent();
cout <<"FUNC_DEP:\t"<<party<<"\n";
bool in[512]; bool out[160];
memset(in, false, 512);
mpc->online(in, out);
uint64_t band2 = io.count();
cout <<"bandwidth\t"<<party<<"\t"<<band2<<endl;
cout <<"ONLINE:\t"<<party<<"\n";
if(party == 1) {
string res = "";
for(int i = 0; i < cf.n3; ++i)
res += (out[i]?"1":"0");
cout << hex_to_binary(string(out3))<<endl;
cout << res<<endl;
cout << (res == hex_to_binary(string(out3))? "GOOD!":"BAD!")<<endl<<flush;
}
delete mpc;
return 0;
}

338
src/cpp/emp-agmpc/abitmp.h Normal file
View File

@@ -0,0 +1,338 @@
#ifndef ABIT_MP_H__
#define ABIT_MP_H__
#include <emp-tool/emp-tool.h>
#include <emp-ot/emp-ot.h>
#include "netmp.h"
#include "helper.h"
template<int nP>
class ABitMP { public:
IKNP<NetIO> *abit1[nP+1];
IKNP<NetIO> *abit2[nP+1];
NetIOMP<nP> *io;
ThreadPool * pool;
int party;
PRG prg;
block Delta;
Hash hash;
int ssp;
block * pretable;
ABitMP(NetIOMP<nP>* io, ThreadPool * pool, int party, bool * _tmp = nullptr, int ssp = 40) {
this->ssp = ssp;
this->io = io;
this->pool = pool;
this->party = party;
bool tmp[128];
if(_tmp == nullptr) {
prg.random_bool(tmp, 128);
} else {
memcpy(tmp, _tmp, 128);
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if(i < j) {
if(i == party) {
abit1[j] = new IKNP<NetIO>(io->get(j, false));
abit2[j] = new IKNP<NetIO>(io->get(j, true));
} else if (j == party) {
abit2[i] = new IKNP<NetIO>(io->get(i, false));
abit1[i] = new IKNP<NetIO>(io->get(i, true));
}
}
vector<future<void>> res;//relic multi-thread problems...
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if(i < j) {
if(i == party) {
res.push_back(pool->enqueue([this, io, tmp, j]() {
abit1[j]->setup_send(tmp);
io->flush(j);
}));
res.push_back(pool->enqueue([this, io, j]() {
abit2[j]->setup_recv();
io->flush(j);
}));
} else if (j == party) {
res.push_back(pool->enqueue([this, io, i]() {
abit2[i]->setup_recv();
io->flush(i);
}));
res.push_back(pool->enqueue([this, io, tmp, i]() {
abit1[i]->setup_send(tmp);
io->flush(i);
}));
}
}
joinNclean(res);
if(party == 1)
Delta = abit1[2]->Delta;
else
Delta = abit1[1]->Delta;
}
~ABitMP() {
for(int i = 1; i <= nP; ++i) if( i!= party ) {
delete abit1[i];
delete abit2[i];
}
}
void compute(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
vector<future<void>> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, KEY, length, party2]() {
abit1[party2]->send_cot(KEY[party2], length);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, MAC, data, length, party2]() {
abit2[party2]->recv_cot(MAC[party2], data, length);
io->flush(party2);
}));
}
joinNclean(res);
#ifdef __debug
check_MAC(io, MAC, KEY, data, Delta, length, party);
#endif
}
future<void> check(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
future<void> ret = pool->enqueue([this, MAC, KEY, data, length](){
check1(MAC, KEY, data, length);
check2(MAC, KEY, data, length);
});
return ret;
}
void check1(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
block seed = sampleRandom(io, &prg, pool, party);
PRG prg2(&seed);
uint8_t * tmp;
block * Ms[nP+1];
bool * bs[nP+1];
block * Ks[nP+1];
block * tMs[nP+1];
bool * tbs[nP+1];
tmp = new uint8_t[ssp*length];
prg2.random_data(tmp, ssp*length);
for(int i = 0; i < ssp*length; ++i)
tmp[i] = tmp[i] % 4;
// for(int j = 0; j < ssp; ++j) {
// tmp[j] = new bool[length];
// for(int k = 0; k < ssp; ++k)
// tmp[j][length - ssp + k] = (k == j);
// }
for(int i = 1; i <= nP; ++i) {
Ms[i] = new block[ssp];
Ks[i] = new block[ssp];
bs[i] = new bool[ssp];
memset(Ms[i], 0, ssp*sizeof(block));
memset(Ks[i], 0, ssp*sizeof(block));
memset(bs[i], false, ssp);
tMs[i] = new block[ssp];
tbs[i] = new bool[ssp];
}
const int chk = 1;
const int SIZE = 1024*2;
block (* tMAC)[4] = new block[SIZE/chk][4];
block (* tKEY)[4] = new block[SIZE/chk][4];
bool (* tb)[4] = new bool[length/chk][4];
memset(tMAC, 0, sizeof(block)*4*SIZE/chk);
memset(tKEY, 0, sizeof(block)*4*SIZE/chk);
memset(tb, false, 4*length/chk);
for(int i = 0; i < length; i+=chk) {
tb[i/chk][1] = data[i];
tb[i/chk][2] = data[i+1];
tb[i/chk][3] = data[i] != data[i+1];
}
for(int k = 1; k <= nP; ++k) if(k != party) {
uint8_t * tmpptr = tmp;
for(int tt = 0; tt < length/SIZE; tt++) {
int start = SIZE*tt;
for(int i = SIZE*tt; i < SIZE*(tt+1) and i < length; i+=chk) {
tMAC[(i-start)/chk][1] = MAC[k][i];
tMAC[(i-start)/chk][2] = MAC[k][i+1];
tMAC[(i-start)/chk][3] = MAC[k][i] ^ MAC[k][i+1];
tKEY[(i-start)/chk][1] = KEY[k][i];
tKEY[(i-start)/chk][2] = KEY[k][i+1];
tKEY[(i-start)/chk][3] = KEY[k][i] ^ KEY[k][i+1];
for(int j = 0; j < ssp; ++j) {
Ms[k][j] = Ms[k][j] ^ tMAC[(i-start)/chk][*tmpptr];
Ks[k][j] = Ks[k][j] ^ tKEY[(i-start)/chk][*tmpptr];
bs[k][j] = bs[k][j] != tb[i/chk][*tmpptr];
++tmpptr;
}
}
}
}
delete[] tmp;
vector<future<bool>> res;
//TODO: they should not need to send MACs.
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, Ms, bs, party2]()->bool {
io->send_data(party2, Ms[party2], sizeof(block)*ssp);
io->send_data(party2, bs[party2], ssp);
io->flush(party2);
return false;
}));
res.push_back(pool->enqueue([this, tMs, tbs, Ks, party2]()->bool {
io->recv_data(party2, tMs[party2], sizeof(block)*ssp);
io->recv_data(party2, tbs[party2], ssp);
for(int k = 0; k < ssp; ++k) {
if(tbs[party2][k])
Ks[party2][k] = Ks[party2][k] ^ Delta;
}
return !cmpBlock(Ks[party2], tMs[party2], ssp);
}));
}
if(joinNcleanCheat(res)) error("cheat check1\n");
for(int i = 1; i <= nP; ++i) {
delete[] Ms[i];
delete[] Ks[i];
delete[] bs[i];
delete[] tMs[i];
delete[] tbs[i];
}
}
void check2(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
//last 2*ssp are garbage already.
block * Ks[2], *Ms[nP+1][nP+1];
block * KK[nP+1];
bool * bs[nP+1];
Ks[0] = new block[ssp];
Ks[1] = new block[ssp];
for(int i = 1; i <= nP; ++i) {
bs[i] = new bool[ssp];
KK[i] = new block[ssp];
for(int j = 1; j <= nP; ++j)
Ms[i][j] = new block[ssp];
}
char (*dgst)[Hash::DIGEST_SIZE] = new char[nP+1][Hash::DIGEST_SIZE];
char (*dgst0)[Hash::DIGEST_SIZE] = new char[ssp*(nP+1)][Hash::DIGEST_SIZE];
char (*dgst1)[Hash::DIGEST_SIZE] = new char[ssp*(nP+1)][Hash::DIGEST_SIZE];
for(int i = 0; i < ssp; ++i) {
Ks[0][i] = zero_block;
for(int j = 1; j <= nP; ++j) if(j != party)
Ks[0][i] = Ks[0][i] ^ KEY[j][length-3*ssp+i];
Ks[1][i] = Ks[0][i] ^ Delta;
Hash::hash_once(dgst0[party*ssp+i], &Ks[0][i], sizeof(block));
Hash::hash_once(dgst1[party*ssp+i], &Ks[1][i], sizeof(block));
}
Hash h;
h.put(data+length-3*ssp, ssp);
for(int j = 1; j <= nP; ++j) if(j != party) {
h.put(&MAC[j][length-3*ssp], ssp*sizeof(block));
}
h.digest(dgst[party]);
vector<future<void>> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, dgst, dgst0, dgst1, party2](){
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->send_data(party2, dgst0[party*ssp], Hash::DIGEST_SIZE*ssp);
io->send_data(party2, dgst1[party*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst0[party2*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_data(party2, dgst1[party2*ssp], Hash::DIGEST_SIZE*ssp);
}));
}
joinNclean(res);
vector<future<bool>> res2;
for(int k = 1; k <= nP; ++k) if(k!= party)
memcpy(Ms[party][k], MAC[k]+length-3*ssp, sizeof(block)*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res2.push_back(pool->enqueue([this, data, MAC, length, party2]() -> bool {
io->send_data(party2, data + length - 3*ssp, ssp);
for(int k = 1; k <= nP; ++k) if(k != party)
io->send_data(party2, MAC[k] + length - 3*ssp, sizeof(block)*ssp);
return false;
}));
res2.push_back(pool->enqueue([this, dgst, bs, Ms, party2]() -> bool {
Hash h;
io->recv_data(party2, bs[party2], ssp);
h.put(bs[party2], ssp);
for(int k = 1; k <= nP; ++k) if(k != party2) {
io->recv_data(party2, Ms[party2][k], sizeof(block)*ssp);
h.put(Ms[party2][k], sizeof(block)*ssp);
}
char tmp[Hash::DIGEST_SIZE];h.digest(tmp);
return strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE) != 0;
}));
}
if(joinNcleanCheat(res2)) error("commitment 1\n");
memset(bs[party], false, ssp);
for(int i = 1; i <= nP; ++i) if(i != party) {
for(int j = 0; j < ssp; ++j)
bs[party][j] = bs[party][j] != bs[i][j];
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res2.push_back(pool->enqueue([this, bs, Ks, party2]() -> bool {
io->send_data(party2, bs[party], ssp);
for(int i = 0; i < ssp; ++i) {
if(bs[party][i])
io->send_data(party2, &Ks[1][i], sizeof(block));
else
io->send_data(party2, &Ks[0][i], sizeof(block));
}
io->flush(party2);
return false;
}));
res2.push_back(pool->enqueue([this, KK, dgst0, dgst1, party2]() -> bool {
bool cheat = false;
bool *tmp_bool = new bool[ssp];
io->recv_data(party2, tmp_bool, ssp);
io->recv_data(party2, KK[party2], ssp*sizeof(block));
for(int i = 0; i < ssp; ++i) {
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &KK[party2][i], sizeof(block));
if(tmp_bool[i])
cheat = cheat or (strncmp(tmp, dgst1[party2*ssp+i], Hash::DIGEST_SIZE)!=0);
else
cheat = cheat or (strncmp(tmp, dgst0[party2*ssp+i], Hash::DIGEST_SIZE)!=0);
}
delete[] tmp_bool;
return cheat;
}));
}
if(joinNcleanCheat(res2)) error("commitments 2\n");
bool cheat = false;
block *tmp_block = new block[ssp];
for(int i = 1; i <= nP; ++i) if (i != party) {
memset(tmp_block, 0, sizeof(block)*ssp);
for(int j = 1; j <= nP; ++j) if(j != i) {
for(int k = 0; k < ssp; ++k)
tmp_block[k] = tmp_block[k] ^ Ms[j][i][k];
}
cheat = cheat or !cmpBlock(tmp_block, KK[i], ssp);
}
if(cheat) error("cheat aShare\n");
delete[] Ks[0];
delete[] Ks[1];
delete[] dgst;
delete[] dgst0;
delete[] dgst1;
delete[] tmp_block;
for(int i = 1; i <= nP; ++i) {
delete[] bs[i];
delete[] KK[i];
for(int j = 1; j <= nP; ++j)
delete[] Ms[i][j];
}
}
};
#endif //ABIT_MP_H__

View File

@@ -0,0 +1,18 @@
#ifndef __CMPC_CONFIG
#define __CMPC_CONFIG
const static int abit_block_size = 1024;
const static int fpre_threads = 1;
#define LOCALHOST
#ifdef __clang__
#define __MORE_FLUSH
#endif
//#define __debug
const static char *IP[] = {""
, "127.0.0.1"
, "127.0.0.1"
, "127.0.0.1"};
const static bool lan_network = false;
#endif// __C2PC_CONFIG

View File

@@ -0,0 +1,9 @@
#ifndef EMP_AGMPC_H__
#define EMP_AGMPC_H__
#include "emp-agmpc/abitmp.h"
#include "emp-agmpc/cmpc_config.h"
#include "emp-agmpc/fpremp.h"
#include "emp-agmpc/helper.h"
#include "emp-agmpc/mpc.h"
#include "emp-agmpc/netmp.h"
#endif// EMP_AGMPC_H__

View File

@@ -0,0 +1,826 @@
#ifndef EMP_AGMPC_FLEXIBLE_INPUT_OUTPUT_H
#define EMP_AGMPC_FLEXIBLE_INPUT_OUTPUT_H
using namespace std;
template<int nP>
struct AuthBitShare
{
bool bit_share;
block key[nP + 1];
block mac[nP + 1];
};
struct BitWithMac
{
bool bit_share;
block mac;
};
template<int nP>
class FlexIn
{
public:
int len{};
int party{};
bool cmpc_associated = false;
ThreadPool* pool;
bool *value;
block *key[nP + 1];
block *mac[nP + 1];
NetIOMP<nP> * io;
block Delta;
vector<int> party_assignment;
// -2 represents an un-authenticated share (e.g., for random tape),
// -1 represents an authenticated share,
// 0 represents public input/output,
vector<bool> plaintext_assignment; // if `party` provides the value for this bit, the plaintext value is here
vector<AuthBitShare<nP>> authenticated_share_assignment; // if this bit is from authenticated shares, the authenticated share is stored here
FlexIn(int len, int party) {
this->len = len;
this->party = party;
this->pool = pool;
AuthBitShare<nP> empty_abit;
memset(&empty_abit, 0, sizeof(AuthBitShare<nP>));
party_assignment.resize(len, 0);
plaintext_assignment.resize(len, false);
authenticated_share_assignment.resize(len, empty_abit);
}
~FlexIn() {
party_assignment.clear();
plaintext_assignment.clear();
authenticated_share_assignment.clear();
}
void associate_cmpc(ThreadPool *associated_pool, bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], NetIOMP<nP> *associated_io, block associated_Delta) {
this->cmpc_associated = true;
this->pool = associated_pool;
this->value = associated_value;
for(int j = 1; j <= nP; j++) {
this->mac[j] = associated_mac[j];
this->key[j] = associated_key[j];
}
this->io = associated_io;
this->Delta = associated_Delta;
}
void assign_party(int pos, int which_party) {
party_assignment[pos] = which_party;
}
void assign_plaintext_bit(int pos, bool cur_bit) {
assert(party_assignment[pos] == party || party_assignment[pos] == -2 || party_assignment[pos] == 0);
plaintext_assignment[pos] = cur_bit;
}
void assign_authenticated_bitshare(int pos, AuthBitShare<nP> *abit) {
assert(party_assignment[pos] == -1);
memcpy(&authenticated_share_assignment[pos], abit, sizeof(AuthBitShare<nP>));
}
void input(bool *masked_input_ret) {
assert(cmpc_associated);
/* assemble an array of the input masks, their macs, and their keys */
/*
* Then,
* for a plaintext bit, the input mask, as well as its MAC, is sent to the input party, who uses the KEY for verification;
* for an un-authenticated bit, the input mask XOR with the input share is broadcast;
* for an authenticated bit share, they are used to masked the previously data (and then checking its opening)
*/
vector<AuthBitShare<nP>> input_mask;
for(int i = 0; i < len; i++) {
AuthBitShare<nP> abit;
abit.bit_share = value[i];
for(int j = 1; j <= nP; j++) {
if(j != party) {
abit.key[j] = key[j][i];
abit.mac[j] = mac[j][i];
}
}
input_mask.emplace_back(abit);
}
/*
* first of all, handle the case party_assignment[] > 0
*/
/* prepare the bit shares to open for the corresponding party */
vector<vector<BitWithMac>> open_bit_shares_for_plaintext_input_send;
open_bit_shares_for_plaintext_input_send.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
if(j != party) {
for(int i = 0; i < len; i++) {
BitWithMac mbit{};
if(party_assignment[i] == j) {
mbit.bit_share = input_mask[i].bit_share;
mbit.mac = input_mask[i].mac[j];
}
open_bit_shares_for_plaintext_input_send[j].push_back(mbit);
}
}
}
vector<vector<BitWithMac>> open_bit_shares_for_plaintext_input_recv;
open_bit_shares_for_plaintext_input_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
open_bit_shares_for_plaintext_input_recv[j].resize(len);
}
/*
* exchange the opening of the input mask
*/
vector<future<void>> res;
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &open_bit_shares_for_plaintext_input_recv, party2]() {
io->recv_data(party2, open_bit_shares_for_plaintext_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &open_bit_shares_for_plaintext_input_send, party2]() {
io->send_data(party2, open_bit_shares_for_plaintext_input_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
/*
* verify the input mask
*/
vector<future<bool>> res_check;
for (int j = 1; j <= nP; ++j) {
if(j != party) {
res_check.push_back(pool->enqueue([this, &input_mask, &open_bit_shares_for_plaintext_input_recv, j]() {
for (int i = 0; i < len; i++) {
if (party_assignment[i] == party) {
block supposed_mac = Delta & select_mask[open_bit_shares_for_plaintext_input_recv[j][i].bit_share? 1 : 0];
supposed_mac ^= input_mask[i].key[j];
block provided_mac = open_bit_shares_for_plaintext_input_recv[j][i].mac;
if(!cmpBlock(&supposed_mac, &provided_mac, 1)) {
return true;
}
}
}
return false;
}));
}
}
if(joinNcleanCheat(res_check)) error("cheat in FlexIn's plaintext input mask!");
/*
* broadcast the masked input
*/
vector<char> masked_input_sent; // use char instead of bool because bools seem to fail for "data()"
vector<vector<char>> masked_input_recv;
masked_input_sent.resize(len);
masked_input_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
masked_input_recv[j].resize(len);
}
for(int i = 0; i < len; i++) {
if(party_assignment[i] == party) {
masked_input_sent[i] = plaintext_assignment[i] ^ input_mask[i].bit_share;
for(int j = 1; j <= nP; j++) {
if(j != party) {
masked_input_sent[i] = masked_input_sent[i] ^ open_bit_shares_for_plaintext_input_recv[j][i].bit_share;
}
}
}
}
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &masked_input_recv, party2]() {
io->recv_data(party2, masked_input_recv[party2].data(), sizeof(char) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &masked_input_sent, party2]() {
io->send_data(party2, masked_input_sent.data(), sizeof(char) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
vector<bool> masked_input;
masked_input.resize(len);
for(int i = 0; i < len; i++) {
if(party_assignment[i] > 0) {
int this_party = party_assignment[i];
if(this_party == party) {
masked_input[i] = masked_input_sent[i];
} else {
masked_input[i] = masked_input_recv[this_party][i];
}
}
}
/*
* secondly, handle the case party_assignment[] == -1
*/
/*
* Compute the authenticated bit share to the new circuit
* by XOR-ing with the input mask
*/
vector<AuthBitShare<nP>> authenticated_bitshares_new_circuit;
for(int i = 0; i < len; i++) {
AuthBitShare<nP> new_entry;
memset(&new_entry, 0, sizeof(AuthBitShare<nP>));
if(party_assignment[i] == -1) {
new_entry.bit_share = authenticated_share_assignment[i].bit_share ^ input_mask[i].bit_share;
for (int j = 1; j <= nP; j++) {
new_entry.key[j] = authenticated_share_assignment[i].key[j] ^ input_mask[i].key[j];
new_entry.mac[j] = authenticated_share_assignment[i].mac[j] ^ input_mask[i].mac[j];
}
}
authenticated_bitshares_new_circuit.emplace_back(new_entry);
}
//print_block(Delta);
/*
cout << "Debug the authenticated input" << endl;
for(int i = 0; i < 10; i++){
if(party_assignment[i] == -1) {
cout << "index: " << i << ", value: " << authenticated_share_assignment[i].bit_share << endl;
cout << "mac: " << endl;
for(int j = 1; j <= nP; j++) {
if(j != party) {
//cout << j << ": ";
print_block(authenticated_share_assignment[i].mac[j]);
}
}
cout << "key: " << endl;
for(int j = 1; j <= nP; j++) {
if(j != party) {
//cout << j << ": ";
print_block(authenticated_share_assignment[i].key[j]);
print_block(authenticated_share_assignment[i].key[j] ^ Delta);
}
}
cout << "=============" << endl;
}
}
*/
/*
* Opening the authenticated shares
*/
vector<vector<BitWithMac>> open_bit_shares_for_authenticated_bits_send;
open_bit_shares_for_authenticated_bits_send.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
if(j != party) {
for(int i = 0; i < len; i++) {
BitWithMac mbit{};
if(party_assignment[i] == -1) {
mbit.bit_share = authenticated_bitshares_new_circuit[i].bit_share;
mbit.mac = authenticated_bitshares_new_circuit[i].mac[j];
}
open_bit_shares_for_authenticated_bits_send[j].push_back(mbit);
}
}
}
vector<vector<BitWithMac>> open_bit_shares_for_authenticated_bits_recv;
open_bit_shares_for_authenticated_bits_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
open_bit_shares_for_authenticated_bits_recv[j].resize(len);
}
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &open_bit_shares_for_authenticated_bits_recv, party2]() {
io->recv_data(party2, open_bit_shares_for_authenticated_bits_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &open_bit_shares_for_authenticated_bits_send, party2]() {
io->send_data(party2, open_bit_shares_for_authenticated_bits_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
/*
* verify the input mask shares
*/
for (int j = 1; j <= nP; ++j) {
if(j != party) {
res_check.push_back(pool->enqueue([this, &authenticated_bitshares_new_circuit, &open_bit_shares_for_authenticated_bits_recv, j]() {
for (int i = 0; i < len; i++) {
if (party_assignment[i] == -1) {
block supposed_mac = Delta & select_mask[open_bit_shares_for_authenticated_bits_recv[j][i].bit_share? 1 : 0];
supposed_mac ^= authenticated_bitshares_new_circuit[i].key[j];
block provided_mac = open_bit_shares_for_authenticated_bits_recv[j][i].mac;
if(!cmpBlock(&supposed_mac, &provided_mac, 1)) {
return true;
}
}
}
return false;
}));
}
}
if(joinNcleanCheat(res_check)) error("cheat in FlexIn's authenticated share input mask!");
/*
* Reconstruct the authenticated shares
*/
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -1) {
masked_input[i] = authenticated_bitshares_new_circuit[i].bit_share;
for(int j = 1; j <= nP; j++) {
if(j != party) {
masked_input[i] = masked_input[i] ^ open_bit_shares_for_authenticated_bits_recv[j][i].bit_share;
}
}
}
}
/*
* thirdly, handle the case party_assignment[] = -2
*/
/*
* Collect the masked input shares for un-authenticated bits
*/
vector<char> open_bit_shares_for_unauthenticated_bits_send;
open_bit_shares_for_unauthenticated_bits_send.resize(len);
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -2) {
open_bit_shares_for_unauthenticated_bits_send[i] = (plaintext_assignment[i] ^ input_mask[i].bit_share)? 1 : 0;
}
}
vector<vector<char>> open_bit_shares_for_unauthenticated_bits_recv;
open_bit_shares_for_unauthenticated_bits_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
open_bit_shares_for_unauthenticated_bits_recv[j].resize(len);
}
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &open_bit_shares_for_unauthenticated_bits_recv, party2]() {
io->recv_data(party2, open_bit_shares_for_unauthenticated_bits_recv[party2].data(), sizeof(char) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &open_bit_shares_for_unauthenticated_bits_send, party2]() {
io->send_data(party2, open_bit_shares_for_unauthenticated_bits_send.data(), sizeof(char) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
/*
* update the array of masked_input accordingly
*/
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -2) {
masked_input[i] = open_bit_shares_for_unauthenticated_bits_send[i];
for(int j = 1; j <= nP; j++) {
if(j != party) {
masked_input[i] = masked_input[i] ^ (open_bit_shares_for_unauthenticated_bits_recv[j][i] == 1);
}
}
}
}
/*
* lastly, handle the case party_assignment[] = 0
*/
/*
* broadcast the input mask and its MAC
*/
vector<vector<BitWithMac>> open_bit_shares_for_public_input_send;
open_bit_shares_for_public_input_send.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
if(j != party) {
for(int i = 0; i < len; i++) {
BitWithMac mbit{};
if(party_assignment[i] == 0) {
mbit.bit_share = input_mask[i].bit_share;
mbit.mac = input_mask[i].mac[j];
}
open_bit_shares_for_public_input_send[j].push_back(mbit);
}
}
}
vector<vector<BitWithMac>> open_bit_shares_for_public_input_recv;
open_bit_shares_for_public_input_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
open_bit_shares_for_public_input_recv[j].resize(len);
}
/*
* exchange the opening of the input mask
*/
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &open_bit_shares_for_public_input_recv, party2]() {
io->recv_data(party2, open_bit_shares_for_public_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &open_bit_shares_for_public_input_send, party2]() {
io->send_data(party2, open_bit_shares_for_public_input_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
/*
* verify the input mask
*/
for (int j = 1; j <= nP; ++j) {
if(j != party) {
res_check.push_back(pool->enqueue([this, &input_mask, &open_bit_shares_for_public_input_recv, j]() {
for (int i = 0; i < len; i++) {
if (party_assignment[i] == 0) {
block supposed_mac = Delta & select_mask[open_bit_shares_for_public_input_recv[j][i].bit_share? 1 : 0];
supposed_mac ^= input_mask[i].key[j];
block provided_mac = open_bit_shares_for_public_input_recv[j][i].mac;
if(!cmpBlock(&supposed_mac, &provided_mac, 1)) {
return true;
}
}
}
return false;
}));
}
}
if(joinNcleanCheat(res_check)) error("cheat in FlexIn's public input mask!");
/*
* update the masked input
*/
for(int i = 0; i < len; i++) {
if(party_assignment[i] == 0) {
masked_input[i] = plaintext_assignment[i] ^ input_mask[i].bit_share;
for(int j = 1; j <= nP; j++) {
if(j != party) {
masked_input[i] = masked_input[i] ^ open_bit_shares_for_public_input_recv[j][i].bit_share;
}
}
}
}
/*
cout << "masked_input" << endl;
for(int i = 0; i < len; i++) {
cout << masked_input[i] << " ";
}
cout << endl;
*/
for(int i = 0; i < len; i++) {
masked_input_ret[i] = masked_input[i];
}
}
int get_length() {
return len;
}
};
template<int nP>
class FlexOut
{
public:
int len{};
int party{};
bool cmpc_associated = false;
ThreadPool* pool;
bool *value;
block *key[nP + 1];
block *mac[nP + 1];
block *eval_labels[nP + 1];
NetIOMP<nP> * io;
block Delta;
block *labels;
vector<int> party_assignment;
// -1 represents an authenticated share,
// 0 represents public output
vector<bool> plaintext_results; // if `party` provides the value for this bit, the plaintext value is here
vector<AuthBitShare<nP>> authenticated_share_results; // if this bit is from authenticated shares, the authenticated share is stored here
FlexOut(int len, int party) {
this->len = len;
this->party = party;
this->pool = pool;
AuthBitShare<nP> empty_abit;
memset(&empty_abit, 0, sizeof(AuthBitShare<nP>));
party_assignment.resize(len, 0);
plaintext_results.resize(len, false);
authenticated_share_results.resize(len, empty_abit);
}
~FlexOut() {
party_assignment.clear();
plaintext_results.clear();
authenticated_share_results.clear();
}
void associate_cmpc(ThreadPool *associated_pool, bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], block *associated_eval_labels[nP + 1], block *associated_labels, NetIOMP<nP> *associated_io, block associated_Delta) {
this->cmpc_associated = true;
this->pool = associated_pool;
this->value = associated_value;
this->labels = associated_labels;
for (int j = 1; j <= nP; j++) {
this->mac[j] = associated_mac[j];
this->key[j] = associated_key[j];
}
if (party == ALICE){
for (int j = 2; j <= nP; j++) {
this->eval_labels[j] = associated_eval_labels[j];
}
}
this->io = associated_io;
this->Delta = associated_Delta;
}
void assign_party(int pos, int which_party) {
party_assignment[pos] = which_party;
}
bool get_plaintext_bit(int pos) {
assert(party_assignment[pos] == party || party_assignment[pos] == 0);
return plaintext_results[pos];
}
AuthBitShare<nP> get_authenticated_bitshare(int pos) {
assert(party_assignment[pos] == -1);
return authenticated_share_results[pos];
}
int get_length() {
return len;
}
void output(bool *masked_input_ret, int output_shift) {
assert(cmpc_associated);
/*
* Party 1 sends the labels of all the output wires out.
*/
vector<block> output_wire_label_recv;
output_wire_label_recv.resize(len);
vector<future<void>> res;
if(party == ALICE) {
vector<vector<block>> output_wire_label_send;
output_wire_label_send.resize(nP + 1);
for (int j = 2; j <= nP; j++) {
output_wire_label_send[j].resize(len);
for(int i = 0; i < len; i++) {
output_wire_label_send[j][i] = eval_labels[j][output_shift + i];
}
}
for(int j = 2; j <= nP; j++) {
res.push_back(pool->enqueue([this, &output_wire_label_send, j]() {
io->send_data(j, output_wire_label_send[j].data(), sizeof(block) * len);
io->flush(j);
}));
}
joinNclean(res);
}else {
io->recv_data(ALICE, output_wire_label_recv.data(), sizeof(block) * len);
io->flush(ALICE);
}
/*
* Each party extracts x ^ r of each output wire
*/
vector<bool> masked_output;
masked_output.resize(len);
if(party == ALICE) {
for(int i = 0; i < len; i++) {
masked_output[i] = masked_input_ret[output_shift + i];
}
} else {
for(int i = 0; i < len; i++) {
block cur_label = output_wire_label_recv[i];
block zero_label = labels[i + output_shift];
block one_label = zero_label ^ Delta;
if(cmpBlock(&cur_label, &zero_label, 1)) {
masked_output[i] = false;
} else if(cmpBlock(&cur_label, &one_label, 1)) {
masked_output[i] = true;
} else {
error("Output label mismatched.\n");
}
}
}
/*
* Decide the broadcasting of the shares of r, as well as their MAC
*/
vector<vector<BitWithMac>> output_mask_send;
output_mask_send.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
BitWithMac empty_mbit{};
output_mask_send[j].resize(len, empty_mbit);
}
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -1) {
// do nothing, just update the share (for the first party) later
} else if(party_assignment[i] == 0) {
// public output, all parties receive the mbit
for(int j = 1; j <= nP; j++){
output_mask_send[j][i].bit_share = value[output_shift + i];
output_mask_send[j][i].mac = mac[j][output_shift + i];
}
} else {
// only one party is supposed to receive the mbit
int cur_party = party_assignment[i];
output_mask_send[cur_party][i].bit_share = value[output_shift + i];
output_mask_send[cur_party][i].mac = mac[cur_party][output_shift + i];
}
}
/*
* Exchange the output mask
*/
vector<vector<BitWithMac>> output_mask_recv;
output_mask_recv.resize(nP + 1);
for(int j = 1; j <= nP; j++) {
output_mask_recv[j].resize(len);
}
for (int i = 1; i <= nP; ++i) {
for (int j = 1; j <= nP; ++j) {
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, &output_mask_recv, party2]() {
io->recv_data(party2, output_mask_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, &output_mask_send, party2]() {
io->send_data(party2, output_mask_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}));
}
}
}
joinNclean(res);
/*
* Verify the output mask
*/
vector<future<bool>> res_check;
for (int j = 1; j <= nP; ++j) {
if(j != party) {
res_check.push_back(pool->enqueue([this, &output_mask_recv, j, output_shift]() {
for (int i = 0; i < len; i++) {
if (party_assignment[i] == party || party_assignment[i] == 0) {
block supposed_mac = Delta & select_mask[output_mask_recv[j][i].bit_share? 1 : 0];
supposed_mac ^= key[j][output_shift + i];
block provided_mac = output_mask_recv[j][i].mac;
if(!cmpBlock(&supposed_mac, &provided_mac, 1)) {
return true;
}
}
}
return false;
}));
}
}
if(joinNcleanCheat(res_check)) error("cheat in FlexOut's output mask!");
/*
* Handle the case party_assignment[] = -1
*/
if(party == ALICE) {
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -1) {
authenticated_share_results[i].bit_share = value[output_shift + i] ^ masked_output[i];
for(int j = 1; j <= nP; j++) {
if(j != party) {
authenticated_share_results[i].mac[j] = mac[j][output_shift + i];
authenticated_share_results[i].key[j] = key[j][output_shift + i];
}
}
}
}
} else {
for(int i = 0; i < len; i++) {
if(party_assignment[i] == -1) {
authenticated_share_results[i].bit_share = value[output_shift + i];
for(int j = 1; j <= nP; j++) {
if(j != party) {
authenticated_share_results[i].mac[j] = mac[j][output_shift + i];
if(j == ALICE) {
authenticated_share_results[i].key[j] =
key[j][output_shift + i] ^ (Delta & select_mask[masked_output[i] ? 1 : 0]);
// change the MAC key for the first party
} else {
authenticated_share_results[i].key[j] = key[j][output_shift + i];
}
}
}
}
}
}
//print_block(Delta);
/*
cout << "Debug the authenticated output" << endl;
for(int i = 0; i < 10; i++){
if(party_assignment[i] == -1) {
cout << "index: " << i << ", value: " << authenticated_share_results[i].bit_share << endl;
cout << "mac: " << endl;
for(int j = 1; j <= nP; j++) {
if(j != party) {
//cout << j << ": ";
print_block(authenticated_share_results[i].mac[j]);
}
}
cout << "key: " << endl;
for(int j = 1; j <= nP; j++) {
if(j != party) {
//cout << j << ": ";
print_block(authenticated_share_results[i].key[j]);
print_block(authenticated_share_results[i].key[j] ^ Delta);
}
}
cout << "=============" << endl;
}
}
*/
/*
* Handle the case party_assignment[] = 0 or == party
*/
for(int i = 0; i < len; i++) {
if(party_assignment[i] == 0 || party_assignment[i] == party) {
plaintext_results[i] = value[output_shift + i] ^ masked_output[i];
for(int j = 1; j <= nP; j++) {
if(j != party) {
plaintext_results[i] = plaintext_results[i] ^ output_mask_recv[j][i].bit_share;
}
}
}
}
}
};
#endif //EMP_AGMPC_FLEXIBLE_INPUT_OUTPUT_H

429
src/cpp/emp-agmpc/fpremp.h Normal file
View File

@@ -0,0 +1,429 @@
//TODO: check MACs
#ifndef FPRE_MP_H__
#define FPRE_MP_H__
#include <emp-tool/emp-tool.h>
#include <thread>
#include "abitmp.h"
#include "netmp.h"
#include "cmpc_config.h"
using namespace emp;
template<int nP>
class FpreMP { public:
ThreadPool *pool;
int party;
NetIOMP<nP> * io;
ABitMP<nP>* abit;
block Delta;
CRH * prps;
CRH * prps2;
PRG * prgs;
PRG prg;
int ssp;
FpreMP(NetIOMP<nP> * io[2], ThreadPool * pool, int party, bool * _delta = nullptr, int ssp = 40) {
this->party = party;
this->pool = pool;
this->io = io[0];
this ->ssp = ssp;
abit = new ABitMP<nP>(io[1], pool, party, _delta, ssp);
Delta = abit->Delta;
prps = new CRH[nP+1];
prps2 = new CRH[nP+1];
prgs = new PRG[nP+1];
}
~FpreMP(){
delete[] prps;
delete[] prps2;
delete[] prgs;
delete abit;
}
int get_bucket_size(int size) {
size = max(size, 320);
int batch_size = ((size+2-1)/2)*2;
if(batch_size >= 280*1000)
return 3;
else if(batch_size >= 3100)
return 4;
else return 5;
}
void compute(block * MAC[nP+1], block * KEY[nP+1], bool * r, int length) {
int64_t bucket_size = get_bucket_size(length);
block * tMAC[nP+1];
block * tKEY[nP+1];
block * tKEYphi[nP+1];
block * tMACphi[nP+1];
block * phi;
block *X [nP+1];
bool *tr = new bool[length*bucket_size*3+3*ssp];
phi = new block[length*bucket_size];
bool *s[nP+1], *e = new bool[length*bucket_size];
for(int i = 1; i <= nP; ++i) {
tMAC[i] = new block[length*bucket_size*3+3*ssp];
tKEY[i] = new block[length*bucket_size*3+3*ssp];
tKEYphi[i] = new block[length*bucket_size+3*ssp];
tMACphi[i] = new block[length*bucket_size+3*ssp];
X[i] = new block[ssp];
}
for(int i = 0; i <= nP; ++i) {
s[i] = new bool[length*bucket_size];
memset(s[i], 0, length*bucket_size);
}
prg.random_bool(tr, length*bucket_size*3+3*ssp);
// memset(tr, false, length*bucket_size*3+3*ssp);
abit->compute(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp);
vector<future<void>> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j ) {
if(i == party) {
res.push_back(pool->enqueue([this, tKEY, tr, s, length, bucket_size, j]() {
prgs[j].random_bool(s[j], length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = garble(tKEY[j], tr, s[j], k, j);
io->send_data(j, &data, 1);
s[j][k] = (s[j][k] != (tr[3*k] and tr[3*k+1]));
}
io->flush(j);
}));
} else if (j == party) {
res.push_back(pool->enqueue([this, tMAC, tr, s, length, bucket_size, i]() {
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = 0;
io->recv_data(i, &data, 1);
bool tmp = evaluate(data, tMAC[i], tr, k, i);
s[i][k] = (tmp != (tr[3*k] and tr[3*k+1]));
}
}));
}
}
joinNclean(res);
for(int k = 0; k < length*bucket_size; ++k) {
s[0][k] = (tr[3*k] and tr[3*k+1]);
for(int i = 1; i <= nP; ++i)
if (i != party) {
s[0][k] = (s[0][k] != s[i][k]);
}
e[k] = (s[0][k] != tr[3*k+2]);
tr[3*k+2] = s[0][k];
}
#ifdef __debug
check_correctness(io, tr, length*bucket_size, party);
#endif
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, e, length, bucket_size, party2]() {
io->send_data(party2, e, length*bucket_size);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, tKEY, length, bucket_size, party2]() {
bool * tmp = new bool[length*bucket_size];
io->recv_data(party2, tmp, length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
if(tmp[k])
tKEY[party2][3*k+2] = tKEY[party2][3*k+2] ^ Delta;
}
delete[] tmp;
}));
}
joinNclean(res);
#ifdef __debug
check_MAC(io, tMAC, tKEY, tr, Delta, length*bucket_size*3, party);
#endif
auto ret = abit->check(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp);
ret.get();
//check compute phi
for(int k = 0; k < length*bucket_size; ++k) {
phi[k] = zero_block;
for(int i = 1; i <= nP; ++i) if (i != party) {
phi[k] = phi[k] ^ tKEY[i][3*k+1];
phi[k] = phi[k] ^ tMAC[i][3*k+1];
}
if(tr[3*k+1])phi[k] = phi[k] ^ Delta;
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, tKEY, tKEYphi, phi, length, bucket_size, party2]() {
block bH[2], tmpH[2];
for(int k = 0; k < length*bucket_size; ++k) {
bH[0] = tKEY[party2][3*k];
bH[1] = bH[0] ^ Delta;
HnID(prps+party2, bH, bH, 2*k, 2, tmpH);
tKEYphi[party2][k] = bH[0];
bH[1] = bH[0] ^ bH[1];
bH[1] = phi[k] ^ bH[1];
io->send_data(party2, &bH[1], sizeof(block));
}
io->flush(party2);
}));
res.push_back(pool->enqueue([this, tMAC, tMACphi, tr, length, bucket_size, party2]() {
block bH;
for(int k = 0; k < length*bucket_size; ++k) {
io->recv_data(party2, &bH, sizeof(block));
block hin = sigma(tMAC[party2][3*k]) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi[party2][k] = prps2[party2].H(hin);
if(tr[3*k])tMACphi[party2][k] = tMACphi[party2][k] ^ bH;
}
}));
}
joinNclean(res);
bool * xs = new bool[length*bucket_size];
for(int i = 0; i < length*bucket_size; ++i) xs[i] = tr[3*i];
#ifdef __debug
check_MAC_phi(tMACphi, tKEYphi, phi, xs, length*bucket_size);
#endif
//tKEYphti use as H
for(int k = 0; k < length*bucket_size; ++k) {
tKEYphi[party][k] = zero_block;
for(int i = 1; i <= nP; ++i) if (i != party) {
tKEYphi[party][k] = tKEYphi[party][k] ^ tKEYphi[i][k];
tKEYphi[party][k] = tKEYphi[party][k] ^ tMACphi[i][k];
tKEYphi[party][k] = tKEYphi[party][k] ^ tKEY[i][3*k+2];
tKEYphi[party][k] = tKEYphi[party][k] ^ tMAC[i][3*k+2];
}
if(tr[3*k]) tKEYphi[party][k] = tKEYphi[party][k] ^ phi[k];
if(tr[3*k+2])tKEYphi[party][k] = tKEYphi[party][k] ^ Delta;
}
#ifdef __debug
check_zero(tKEYphi[party], length*bucket_size);
#endif
block prg_key = sampleRandom(io, &prg, pool, party);
PRG prgf(&prg_key);
char (*dgst)[Hash::DIGEST_SIZE] = new char[nP+1][Hash::DIGEST_SIZE];
bool * tmp = new bool[length*bucket_size];
for(int i = 0; i < ssp; ++i) {
prgf.random_bool(tmp, length*bucket_size);
X[party][i] = inProd(tmp, tKEYphi[party], length*bucket_size);
}
Hash::hash_once(dgst[party], X[party], sizeof(block)*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, dgst, party2]() {
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
}));
}
joinNclean(res);
vector<future<bool>> res2;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res2.push_back(pool->enqueue([this, X, dgst, party2]() -> bool {
io->send_data(party2, X[party], sizeof(block)*ssp);
io->recv_data(party2, X[party2], sizeof(block)*ssp);
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, X[party2], sizeof(block)*ssp);
return strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0;
}));
}
if(joinNcleanCheat(res2)) error("commitment");
for(int i = 2; i <= nP; ++i)
xorBlocks_arr(X[1], X[1], X[i], ssp);
for(int i = 0; i < ssp; ++i)X[2][i] = zero_block;
if(!cmpBlock(X[1], X[2], ssp)) error("AND check");
//land -> and
block S = sampleRandom<nP>(io, &prg, pool, party);
int * ind = new int[length*bucket_size];
int *location = new int[length*bucket_size];
bool * d[nP+1];
for(int i = 1; i <= nP; ++i)
d[i] = new bool[length*(bucket_size-1)];
for(int i = 0; i < length*bucket_size; ++i)
location[i] = i;
PRG prg2(&S);
prg2.random_data(ind, length*bucket_size*4);
for(int i = length*bucket_size-1; i>=0; --i) {
int index = ind[i]%(i+1);
index = index>0? index:(-1*index);
int tmp = location[i];
location[i] = location[index];
location[index] = tmp;
}
delete[] ind;
for(int i = 0; i < length; ++i) {
for(int j = 0; j < bucket_size-1; ++j)
d[party][(bucket_size-1)*i+j] = tr[3*location[i*bucket_size]+1] != tr[3*location[i*bucket_size+1+j]+1];
for(int j = 1; j <= nP; ++j) if (j!= party) {
memcpy(MAC[j]+3*i, tMAC[j]+3*location[i*bucket_size], 3*sizeof(block));
memcpy(KEY[j]+3*i, tKEY[j]+3*location[i*bucket_size], 3*sizeof(block));
for(int k = 1; k < bucket_size; ++k) {
MAC[j][3*i] = MAC[j][3*i] ^ tMAC[j][3*location[i*bucket_size+k]];
KEY[j][3*i] = KEY[j][3*i] ^ tKEY[j][3*location[i*bucket_size+k]];
MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]+2];
KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]+2];
}
}
memcpy(r+3*i, tr+3*location[i*bucket_size], 3);
for(int k = 1; k < bucket_size; ++k) {
r[3*i] = r[3*i] != tr[3*location[i*bucket_size+k]];
r[3*i+2] = r[3*i+2] != tr[3*location[i*bucket_size+k]+2];
}
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, d, length, bucket_size, party2]() {
io->send_data(party2, d[party], (bucket_size-1)*length);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, d, length, bucket_size, party2]() {
io->recv_data(party2, d[party2], (bucket_size-1)*length);
}));
}
joinNclean(res);
for(int i = 2; i <= nP; ++i)
for(int j = 0; j < (bucket_size-1)*length; ++j)
d[1][j] = d[1][j]!=d[i][j];
for(int i = 0; i < length; ++i) {
for(int j = 1; j <= nP; ++j)if (j!= party) {
for(int k = 1; k < bucket_size; ++k)
if(d[1][(bucket_size-1)*i+k-1]) {
MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]];
KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]];
}
}
for(int k = 1; k < bucket_size; ++k)
if(d[1][(bucket_size-1)*i+k-1]) {
r[3*i+2] = r[3*i+2] != tr[3*location[i*bucket_size+k]];
}
}
#ifdef __debug
check_MAC(io, MAC, KEY, r, Delta, length*3, party);
check_correctness(io, r, length, party);
#endif
// ret.get();
delete[] tr;
delete[] phi;
delete[] e;
delete[] dgst;
delete[] tmp;
delete[] location;
delete[] xs;
for(int i = 1; i <= nP; ++i) {
delete[] tMAC[i];
delete[] tKEY[i];
delete[] tMACphi[i];
delete[] tKEYphi[i];
delete[] X[i];
delete[] s[i];
delete[] d[i];
}
delete[] s[0];
}
//TODO: change to justGarble
uint8_t garble(block * KEY, bool * r, bool * r2, int i, int I) {
uint8_t data = 0;
block tmp[4], tmp2[4], tmpH[4];
tmp[0] = KEY[3*i];
tmp[1] = tmp[0] ^ Delta;
tmp[2] = KEY[3*i+1];
tmp[3] = tmp[2] ^ Delta;
HnID(prps+I, tmp, tmp, 4*i, 4, tmpH);
tmp2[0] = tmp[0] ^ tmp[2];
tmp2[1] = tmp[1] ^ tmp[2];
tmp2[2] = tmp[0] ^ tmp[3];
tmp2[3] = tmp[1] ^ tmp[3];
data = LSB(tmp2[0]);
data |= (LSB(tmp2[1])<<1);
data |= (LSB(tmp2[2])<<2);
data |= (LSB(tmp2[3])<<3);
if ( ((false != r[3*i] ) && (false != r[3*i+1])) != r2[i] )
data= data ^ 0x1;
if ( ((true != r[3*i] ) && (false != r[3*i+1])) != r2[i] )
data = data ^ 0x2;
if ( ((false != r[3*i] ) && (true != r[3*i+1])) != r2[i] )
data = data ^ 0x4;
if ( ((true != r[3*i] ) && (true != r[3*i+1])) != r2[i] )
data = data ^ 0x8;
return data;
}
bool evaluate(uint8_t tmp, block * MAC, bool * r, int i, int I) {
block hin = sigma(MAC[3*i]) ^ makeBlock(0, 4*i + r[3*i]);
block hin2 = sigma(MAC[3*i+1]) ^ makeBlock(0, 4*i + 2 + r[3*i+1]);
block bH = prps[I].H(hin) ^ prps[I].H(hin2);
uint8_t res = LSB(bH);
tmp >>= (r[3*i+1]*2+r[3*i]);
return (tmp&0x1) != (res&0x1);
}
void check_MAC_phi(block * MAC[nP+1], block * KEY[nP+1], block * phi, bool * r, int length) {
block * tmp = new block[length];
block *tD = new block[length];
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, phi, length*sizeof(block));
io->send_data(j, KEY[j], sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, tD, length*sizeof(block));
io->recv_data(i, tmp, sizeof(block)*length);
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD[k];
}
if(!cmpBlock(MAC[i], tmp, length))
error("check_MAC_phi failed!");
}
}
delete[] tmp;
delete[] tD;
if(party == 1)
cerr<<"check_MAC_phi pass!\n"<<flush;
}
void check_zero(block * b, int l) {
if(party == 1) {
block * tmp1 = new block[l];
block * tmp2 = new block[l];
memcpy(tmp1, b, l*sizeof(block));
for(int i = 2; i <= nP; ++i) {
io->recv_data(i, tmp2, l*sizeof(block));
xorBlocks_arr(tmp1, tmp1, tmp2, l);
}
block z = zero_block;
for(int i = 0; i < l; ++i)
if(!cmpBlock(&z, &tmp1[i], 1))
error("check sum zero failed!");
cerr<<"check zero sum pass!\n"<<flush;
delete[] tmp1;
delete[] tmp2;
} else {
io->send_data(1, b, l*sizeof(block));
io->flush(1);
}
}
void HnID(CRH* crh, block*out, block* in, uint64_t id, int length, block * scratch = nullptr) {
bool del = false;
if(scratch == nullptr) {
del = true;
scratch = new block[length];
}
for(int i = 0; i < length; ++i){
out[i] = scratch[i] = sigma(in[i]) ^ makeBlock(0, id);
++id;
}
crh->permute_block(scratch, length);
xorBlocks_arr(out, scratch, out, length);
if(del) {
delete[] scratch;
scratch = nullptr;
}
}
};
#endif// FPRE_H__

238
src/cpp/emp-agmpc/helper.h Normal file
View File

@@ -0,0 +1,238 @@
#ifndef __HELPER
#define __HELPER
#include <emp-tool/emp-tool.h>
#include "cmpc_config.h"
#include "netmp.h"
#include <future>
using namespace emp;
using std::future;
using std::cout;
using std::max;
using std::cerr;
using std::endl;
using std::flush;
const static block inProdTableBlock[] = {zero_block, all_one_block};
block inProd(bool * b, block * blk, int length) {
block res = zero_block;
for(int i = 0; i < length; ++i)
// if(b[i])
// res = res ^ blk[i];
res = res ^ (inProdTableBlock[b[i]] & blk[i]);
return res;
}
#ifdef __GNUC__
#ifndef __clang__
#pragma GCC push_options
#pragma GCC optimize ("unroll-loops")
#endif
#endif
template<int ssp>
void inProdhelp(block *Ms, bool * tmp[ssp], block * MAC, int i) {
for(int j = 0; j < ssp; ++j)
Ms[j] = Ms[j] ^ (inProdTableBlock[tmp[j][i]] & MAC[i]);
}
#ifdef __GNUC__
#ifndef __clang__
#pragma GCC pop_options
#endif
#endif
template<int ssp>
void inProds(block *Ms, bool * tmp[ssp], block * MAC, int length) {
memset(Ms, 0, sizeof(block)*ssp);
for(int i = 0; i < length; ++i) {
inProdhelp<ssp>(Ms, tmp, MAC, i);
}
}
bool inProd(bool * b, bool* b2, int length) {
bool res = false;
for(int i = 0; i < length; ++i)
res = (res != (b[i] and b2[i]));
return res;
}
template<typename T>
void joinNclean(vector<future<T>>& res) {
for(auto &v: res) v.get();
res.clear();
}
bool joinNcleanCheat(vector<future<bool>>& res) {
bool cheat = false;
for(auto &v: res) cheat = cheat or v.get();
res.clear();
return cheat;
}
void send_bool(NetIO * io, const bool * data, int length) {
if(lan_network) {
io->send_data(data, length);
return;
}
for(int i = 0; i < length;) {
uint64_t tmp = 0;
for(int j = 0; j < 64 and i < length; ++i,++j) {
if(data[i])
tmp|=(0x1ULL<<j);
}
io->send_data(&tmp, 8);
}
}
void recv_bool(NetIO * io, bool * data, int length) {
if(lan_network) {
io->recv_data(data, length);
return;
}
for(int i = 0; i < length;) {
uint64_t tmp = 0;
io->recv_data(&tmp, 8);
for(int j = 63; j >= 0 and i < length; ++i,--j) {
data[i] = (tmp&0x1) == 0x1;
tmp>>=1;
}
}
}
template<int B>
void send_partial_block(NetIO * io, const block * data, int length) {
for(int i = 0; i < length; ++i) {
io->send_data(&(data[i]), B);
}
}
template<int B>
void recv_partial_block(NetIO * io, block * data, int length) {
for(int i = 0; i < length; ++i) {
io->recv_data(&(data[i]), B);
}
}
inline uint8_t LSB(block & b) {
return _mm_extract_epi8(b, 0) & 0x1;
}
template<int nP>
block sampleRandom(NetIOMP<nP> * io, PRG * prg, ThreadPool * pool, int party) {
vector<future<void>> res;
vector<future<bool>> res2;
char (*dgst)[Hash::DIGEST_SIZE] = new char[nP+1][Hash::DIGEST_SIZE];
block *S = new block[nP+1];
prg->random_block(&S[party], 1);
Hash::hash_once(dgst[party], &S[party], sizeof(block));
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([dgst, io, party, party2]() {
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
}));
}
joinNclean(res);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res2.push_back(pool->enqueue([io, S, dgst, party, party2]() -> bool {
io->send_data(party2, &S[party], sizeof(block));
io->recv_data(party2, &S[party2], sizeof(block));
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &S[party2], sizeof(block));
return strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0;
}));
}
bool cheat = joinNcleanCheat(res2);
if(cheat) {
cout <<"cheat in sampleRandom\n"<<flush;
exit(0);
}
for(int i = 2; i <= nP; ++i)
S[1] = S[1] ^ S[i];
block result = S[1];
delete[] S;
delete[] dgst;
return result;
}
template<int nP>
void check_MAC(NetIOMP<nP> * io, block * MAC[nP+1], block * KEY[nP+1], bool * r, block Delta, int length, int party) {
block * tmp = new block[length];
block tD;
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, &Delta, sizeof(block));
io->send_data(j, KEY[j], sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, &tD, sizeof(block));
io->recv_data(i, tmp, sizeof(block)*length);
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD;
}
if(!cmpBlock(MAC[i], tmp, length))
error("check_MAC failed!");
}
}
delete[] tmp;
if(party == 1)
cerr<<"check_MAC pass!\n"<<flush;
}
template<int nP>
void check_correctness(NetIOMP<nP>* io, bool * r, int length, int party) {
if (party == 1) {
bool * tmp1 = new bool[length*3];
bool * tmp2 = new bool[length*3];
memcpy(tmp1, r, length*3);
for(int i = 2; i <= nP; ++i) {
io->recv_data(i, tmp2, length*3);
for(int k = 0; k < length*3; ++k)
tmp1[k] = (tmp1[k] != tmp2[k]);
}
for(int k = 0; k < length; ++k) {
if((tmp1[3*k] and tmp1[3*k+1]) != tmp1[3*k+2])
error("check_correctness failed!");
}
delete[] tmp1;
delete[] tmp2;
cerr<<"check_correctness pass!\n"<<flush;
} else {
io->send_data(1, r, length*3);
io->flush(1);
}
}
inline const char* hex_char_to_bin(char c) {
switch(toupper(c)) {
case '0': return "0000";
case '1': return "0001";
case '2': return "0010";
case '3': return "0011";
case '4': return "0100";
case '5': return "0101";
case '6': return "0110";
case '7': return "0111";
case '8': return "1000";
case '9': return "1001";
case 'A': return "1010";
case 'B': return "1011";
case 'C': return "1100";
case 'D': return "1101";
case 'E': return "1110";
case 'F': return "1111";
default: return "0";
}
}
inline std::string hex_to_binary(std::string hex) {
std::string bin;
for(unsigned i = 0; i != hex.length(); ++i)
bin += hex_char_to_bin(hex[i]);
return bin;
}
#endif// __HELPER

679
src/cpp/emp-agmpc/mpc.h Normal file
View File

@@ -0,0 +1,679 @@
#ifndef CMPC_H__
#define CMPC_H__
#include "fpremp.h"
#include "abitmp.h"
#include "netmp.h"
#include "flexible_input_output.h"
#include <emp-tool/emp-tool.h>
using namespace emp;
template<int nP>
class CMPC { public:
const static int SSP = 5;//5*8 in fact...
const block MASK = makeBlock(0x0ULL, 0xFFFFFULL);
FpreMP<nP>* fpre = nullptr;
block* mac[nP+1];
block* key[nP+1];
bool* value;
block * preprocess_mac[nP+1];
block * preprocess_key[nP+1];
bool* preprocess_value;
block * sigma_mac[nP+1];
block * sigma_key[nP+1];
bool * sigma_value;
block * ANDS_mac[nP+1];
block * ANDS_key[nP+1];
bool * ANDS_value;
block * labels;
BristolFormat * cf;
NetIOMP<nP> * io;
int num_ands = 0, num_in;
int party, total_pre, ssp;
ThreadPool * pool;
block Delta;
block (*GTM)[4][nP+1];
block (*GTK)[4][nP+1];
bool (*GTv)[4];
block (*GT)[nP+1][4][nP+1];
block * eval_labels[nP+1];
PRP prp;
CMPC(NetIOMP<nP> * io[2], ThreadPool * pool, int party, BristolFormat * cf, bool * _delta = nullptr, int ssp = 40) {
this->party = party;
this->io = io[0];
this->cf = cf;
this->ssp = ssp;
this->pool = pool;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE)
++num_ands;
}
num_in = cf->n1+cf->n2;
total_pre = num_in + num_ands + 3*ssp;
fpre = new FpreMP<nP>(io, pool, party, _delta, ssp);
Delta = fpre->Delta;
if(party == 1) {
GTM = new block[num_ands][4][nP+1];
GTK = new block[num_ands][4][nP+1];
GTv = new bool[num_ands][4];
GT = new block[num_ands][nP+1][4][nP+1];
}
labels = new block[cf->num_wire];
for(int i = 1; i <= nP; ++i) {
key[i] = new block[cf->num_wire];
mac[i] = new block[cf->num_wire];
ANDS_key[i] = new block[num_ands*3];
ANDS_mac[i] = new block[num_ands*3];
preprocess_mac[i] = new block[total_pre];
preprocess_key[i] = new block[total_pre];
sigma_mac[i] = new block[num_ands];
sigma_key[i] = new block[num_ands];
eval_labels[i] = new block[cf->num_wire];
}
value = new bool[cf->num_wire];
ANDS_value = new bool[num_ands*3];
preprocess_value = new bool[total_pre];
sigma_value = new bool[num_ands];
}
~CMPC() {
delete fpre;
if(party == 1) {
delete[] GTM;
delete[] GTK;
delete[] GTv;
delete[] GT;
}
delete[] labels;
for(int i = 1; i <= nP; ++i) {
delete[] key[i];
delete[] mac[i];
delete[] ANDS_key[i];
delete[] ANDS_mac[i];
delete[] preprocess_mac[i];
delete[] preprocess_key[i];
delete[] sigma_mac[i];
delete[] sigma_key[i];
delete[] eval_labels[i];
}
delete[] value;
delete[] ANDS_value;
delete[] preprocess_value;
delete[] sigma_value;
}
PRG prg;
void function_independent() {
if(party != 1)
prg.random_block(labels, cf->num_wire);
fpre->compute(ANDS_mac, ANDS_key, ANDS_value, num_ands);
prg.random_bool(preprocess_value, total_pre);
fpre->abit->compute(preprocess_mac, preprocess_key, preprocess_value, total_pre);
auto ret = fpre->abit->check(preprocess_mac, preprocess_key, preprocess_value, total_pre);
ret.get();
for(int i = 1; i <= nP; ++i) {
memcpy(key[i], preprocess_key[i], num_in * sizeof(block));
memcpy(mac[i], preprocess_mac[i], num_in * sizeof(block));
}
memcpy(value, preprocess_value, num_in * sizeof(bool));
#ifdef __debug
check_MAC<nP>(io, ANDS_mac, ANDS_key, ANDS_value, Delta, num_ands*3, party);
check_correctness<nP>(io, ANDS_value, num_ands, party);
#endif
// ret.get();
}
void function_dependent() {
int ands = num_in;
bool * x[nP+1];
bool * y[nP+1];
for(int i = 1; i <= nP; ++i) {
x[i] = new bool[num_ands];
y[i] = new bool[num_ands];
}
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = preprocess_key[j][ands];
mac[j][cf->gates[4*i+2]] = preprocess_mac[j][ands];
}
value[cf->gates[4*i+2]] = preprocess_value[ands];
++ands;
}
}
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]] ^ key[j][cf->gates[4*i+1]];
mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]] ^ mac[j][cf->gates[4*i+1]];
}
value[cf->gates[4*i+2]] = value[cf->gates[4*i]] != value[cf->gates[4*i+1]];
if(party != 1)
labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == NOT_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]];
mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]];
}
value[cf->gates[4*i+2]] = value[cf->gates[4*i]];
if(party != 1)
labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ Delta;
}
}
#ifdef __debug
check_MAC<nP>(io, mac, key, value, Delta, cf->num_wire, party);
#endif
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
x[party][ands] = value[cf->gates[4*i]] != ANDS_value[3*ands];
y[party][ands] = value[cf->gates[4*i+1]] != ANDS_value[3*ands+1];
ands++;
}
}
vector<future<void>> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, x, y, party2]() {
io->send_data(party2, x[party], num_ands);
io->send_data(party2, y[party], num_ands);
io->flush(party2);
}));
res.push_back(pool->enqueue([this, x, y, party2]() {
io->recv_data(party2, x[party2], num_ands);
io->recv_data(party2, y[party2], num_ands);
}));
}
joinNclean(res);
for(int i = 2; i <= nP; ++i) for(int j = 0; j < num_ands; ++j) {
x[1][j] = x[1][j] != x[i][j];
y[1][j] = y[1][j] != y[i][j];
}
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = ANDS_mac[j][3*ands+2];
sigma_key[j][ands] = ANDS_key[j][3*ands+2];
}
sigma_value[ands] = ANDS_value[3*ands+2];
if(x[1][ands]) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands+1];
sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands+1];
}
sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands+1];
}
if(y[1][ands]) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands];
sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands];
}
sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands];
}
if(x[1][ands] and y[1][ands]) {
if(party != 1)
sigma_key[1][ands] = sigma_key[1][ands] ^ Delta;
else
sigma_value[ands] = not sigma_value[ands];
}
ands++;
}
}//sigma_[] stores the and of input wires to each AND gates
#ifdef __debug_
check_MAC<nP>(io, sigma_mac, sigma_key, sigma_value, Delta, num_ands, party);
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
bool tmp[] = { value[cf->gates[4*i]], value[cf->gates[4*i+1]], sigma_value[ands]};
check_correctness(io, tmp, 1, party);
ands++;
}
}
#endif
ands = 0;
block H[4][nP+1];
block K[4][nP+1], M[4][nP+1];
bool r[4];
if(party != 1) {
for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) {
r[0] = sigma_value[ands] != value[cf->gates[4*i+2]];
r[1] = r[0] != value[cf->gates[4*i]];
r[2] = r[0] != value[cf->gates[4*i+1]];
r[3] = r[1] != value[cf->gates[4*i+1]];
for(int j = 1; j <= nP; ++j) {
M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]];
M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]];
M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]];
M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]];
K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]];
K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]];
K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]];
K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]];
}
K[3][1] = K[3][1] ^ Delta;
Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], ands);
for(int j = 0; j < 4; ++j) {
for(int k = 1; k <= nP; ++k) if(k != party) {
H[j][k] = H[j][k] ^ M[j][k];
H[j][party] = H[j][party] ^ K[j][k];
}
H[j][party] = H[j][party] ^ labels[cf->gates[4*i+2]];
if(r[j])
H[j][party] = H[j][party] ^ Delta;
}
for(int j = 0; j < 4; ++j)
io->send_data(1, H[j]+1, sizeof(block)*(nP));
++ands;
}
io->flush(1);
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, party2]() {
for(int i = 0; i < num_ands; ++i)
for(int j = 0; j < 4; ++j)
io->recv_data(party2, GT[i][party2][j]+1, sizeof(block)*(nP));
}));
}
for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) {
r[0] = sigma_value[ands] != value[cf->gates[4*i+2]];
r[1] = r[0] != value[cf->gates[4*i]];
r[2] = r[0] != value[cf->gates[4*i+1]];
r[3] = r[1] != value[cf->gates[4*i+1]];
r[3] = r[3] != true;
for(int j = 1; j <= nP; ++j) {
M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]];
M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]];
M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]];
M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]];
K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]];
K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]];
K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]];
K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]];
}
memcpy(GTK[ands], K, sizeof(block)*4*(nP+1));
memcpy(GTM[ands], M, sizeof(block)*4*(nP+1));
memcpy(GTv[ands], r, sizeof(bool)*4);
++ands;
}
joinNclean(res);
}
for(int i = 1; i <= nP; ++i) {
delete[] x[i];
delete[] y[i];
}
}
void online (bool * input, bool * output) {
bool * mask_input = new bool[cf->num_wire];
for(int i = 0; i < num_in; ++i)
mask_input[i] = input[i] != value[i];
if(party != 1) {
io->send_data(1, mask_input, num_in);
io->flush(1);
io->recv_data(1, mask_input, num_in);
} else {
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i) tmp[i] = new bool[num_in];
vector<future<void>> res;
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, tmp, party2]() {
io->recv_data(party2, tmp[party2], num_in);
}));
}
joinNclean(res);
for(int i = 0; i < num_in; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[i] = tmp[j][i] != mask_input[i];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, mask_input, party2]() {
io->send_data(party2, mask_input, num_in);
io->flush(party2);
}));
}
joinNclean(res);
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
}
if(party!= 1) {
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
}
io->flush(1);
} else {
vector<future<void>> res;
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, party2]() {
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
}));
}
joinNclean(res);
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
block t0 = GTK[ands][index][j] ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}
}
ands++;
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
}
}
}
if(party != 1) {
io->send_data(1, value+cf->num_wire - cf->n3, cf->n3);
io->flush(1);
} else {
vector<future<void>> res;
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i)
tmp[i] = new bool[cf->n3];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, tmp, party2]() {
io->recv_data(party2, tmp[party2], cf->n3);
}));
}
joinNclean(res);
for(int i = 0; i < cf->n3; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 0; i < cf->n3; ++i)
mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3);
}
delete[] mask_input;
}
void Hash(block H[4][nP+1], const block & a, const block & b, uint64_t idx) {
block T[4];
T[0] = sigma(a);
T[1] = sigma(a ^ Delta);
T[2] = sigma(sigma(b));
T[3] = sigma(sigma(b ^ Delta));
H[0][0] = T[0] ^ T[2];
H[1][0] = T[0] ^ T[3];
H[2][0] = T[1] ^ T[2];
H[3][0] = T[1] ^ T[3];
for(int j = 0; j < 4; ++j) for(int i = 1; i <= nP; ++i) {
H[j][i] = H[j][0] ^ makeBlock(4*idx+j, i);
}
for(int j = 0; j < 4; ++j) {
prp.permute_block(H[j]+1, nP);
}
}
void Hash(block H[nP+1], const block &a, const block& b, uint64_t idx, uint64_t row) {
H[0] = sigma(a) ^ sigma(sigma(b));
for(int i = 1; i <= nP; ++i) {
H[i] = H[0] ^ makeBlock(4*idx+row, i);
}
prp.permute_block(H+1, nP);
}
string tostring(bool a) {
if(a) return "T";
else return "F";
}
void online (bool * input, bool * output, int* start, int* end) {
bool * mask_input = new bool[cf->num_wire];
bool * input_mask[nP+1];
for(int i = 0; i <= nP; ++i) input_mask[i] = new bool[end[party] - start[party]];
memcpy(input_mask[party], value+start[party], end[party] - start[party]);
memcpy(input_mask[0], input+start[party], end[party] - start[party]);
vector<future<bool>> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
res.push_back(pool->enqueue([this, start, end, party2]() {
char dig[Hash::DIGEST_SIZE];
io->send_data(party2, value+start[party2], end[party2]-start[party2]);
emp::Hash::hash_once(dig, mac[party2]+start[party2], (end[party2]-start[party2])*sizeof(block));
io->send_data(party2, dig, Hash::DIGEST_SIZE);
io->flush(party2);
return false;
}));
res.push_back(pool->enqueue([this, start, end, input_mask, party2]() {
char dig[Hash::DIGEST_SIZE];
char dig2[Hash::DIGEST_SIZE];
io->recv_data(party2, input_mask[party2], end[party]-start[party]);
block * tmp = new block[end[party]-start[party]];
for(int i = 0; i < end[party] - start[party]; ++i) {
tmp[i] = key[party2][i+start[party]];
if(input_mask[party2][i])tmp[i] = tmp[i] ^ Delta;
}
emp::Hash::hash_once(dig2, tmp, (end[party]-start[party])*sizeof(block));
io->recv_data(party2, dig, Hash::DIGEST_SIZE);
delete[] tmp;
return strncmp(dig, dig2, Hash::DIGEST_SIZE) != 0;
}));
}
if(joinNcleanCheat(res)) error("cheat!");
for(int i = 1; i <= nP; ++i)
for(int j = 0; j < end[party] - start[party]; ++j)
input_mask[0][j] = input_mask[0][j] != input_mask[i][j];
if(party != 1) {
io->send_data(1, input_mask[0], end[party] - start[party]);
io->flush(1);
io->recv_data(1, mask_input, num_in);
} else {
vector<future<void>> res;
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, mask_input, start, end , party2]() {
io->recv_data(party2, mask_input+start[party2], end[party2] - start[party2]);
}));
}
joinNclean(res);
memcpy(mask_input, input_mask[0], end[1]-start[1]);
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, mask_input, party2]() {
io->send_data(party2, mask_input, num_in);
io->flush(party2);
}));
}
joinNclean(res);
}
if(party!= 1) {
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
}
io->flush(1);
} else {
vector<future<void>> res;
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, party2]() {
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
}));
}
joinNclean(res);
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
block t0 = GTK[ands][index][j] ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}
}
ands++;
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
}
}
}
if(party != 1) {
io->send_data(1, value+cf->num_wire - cf->n3, cf->n3);
io->flush(1);
} else {
vector<future<void>> res;
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i)
tmp[i] = new bool[cf->n3];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, tmp, party2]() {
io->recv_data(party2, tmp[party2], cf->n3);
}));
}
joinNclean(res);
for(int i = 0; i < cf->n3; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 0; i < cf->n3; ++i)
mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3);
}
delete[] mask_input;
}
void online (FlexIn<nP> * input, FlexOut<nP> *output) {
bool * mask_input = new bool[cf->num_wire];
input->associate_cmpc(pool, value, mac, key, io, Delta);
input->input(mask_input);
if(party!= 1) {
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
}
io->flush(1);
} else {
vector<future<void>> res;
for(int i = 2; i <= nP; ++i) {
int party2 = i;
res.push_back(pool->enqueue([this, party2]() {
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
}));
}
joinNclean(res);
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
block t0 = GTK[ands][index][j] ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}
}
ands++;
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
}
}
}
output->associate_cmpc(pool, value, mac, key, eval_labels, labels, io, Delta);
output->output(mask_input, cf->num_wire - cf->n3);
delete[] mask_input;
}
};
#endif// CMPC_H__

124
src/cpp/emp-agmpc/netmp.h Normal file
View File

@@ -0,0 +1,124 @@
#ifndef NETIOMP_H__
#define NETIOMP_H__
#include <emp-tool/emp-tool.h>
#include "cmpc_config.h"
using namespace emp;
template<int nP>
class NetIOMP { public:
NetIO*ios[nP+1];
NetIO*ios2[nP+1];
int party;
bool sent[nP+1];
NetIOMP(int party, int port) {
this->party = party;
memset(sent, false, nP+1);
for(int i = 1; i <= nP; ++i)for(int j = 1; j <= nP; ++j)if(i < j){
if(i == party) {
#ifdef LOCALHOST
usleep(1000);
ios[j] = new NetIO(IP[j], port+2*(i*nP+j), true);
#else
usleep(1000);
ios[j] = new NetIO(IP[j], port+2*(i), true);
#endif
ios[j]->set_nodelay();
#ifdef LOCALHOST
usleep(1000);
ios2[j] = new NetIO(nullptr, port+2*(i*nP+j)+1, true);
#else
usleep(1000);
ios2[j] = new NetIO(nullptr, port+2*(j)+1, true);
#endif
ios2[j]->set_nodelay();
} else if(j == party) {
#ifdef LOCALHOST
usleep(1000);
ios[i] = new NetIO(nullptr, port+2*(i*nP+j), true);
#else
usleep(1000);
ios[i] = new NetIO(nullptr, port+2*(i), true);
#endif
ios[i]->set_nodelay();
#ifdef LOCALHOST
usleep(1000);
ios2[i] = new NetIO(IP[i], port+2*(i*nP+j)+1, true);
#else
usleep(1000);
ios2[i] = new NetIO(IP[i], port+2*(j)+1, true);
#endif
ios2[i]->set_nodelay();
}
}
}
int64_t count() {
int64_t res = 0;
for(int i = 1; i <= nP; ++i) if(i != party){
res += ios[i]->counter;
res += ios2[i]->counter;
}
return res;
}
~NetIOMP() {
for(int i = 1; i <= nP; ++i)
if(i != party) {
delete ios[i];
delete ios2[i];
}
}
void send_data(int dst, const void * data, size_t len) {
if(dst != 0 and dst!= party) {
if(party < dst)
ios[dst]->send_data(data, len);
else
ios2[dst]->send_data(data, len);
sent[dst] = true;
}
#ifdef __MORE_FLUSH
flush(dst);
#endif
}
void recv_data(int src, void * data, size_t len) {
if(src != 0 and src!= party) {
if(sent[src])flush(src);
if(src < party)
ios[src]->recv_data(data, len);
else
ios2[src]->recv_data(data, len);
}
}
NetIO*& get(size_t idx, bool b = false){
if (b)
return ios[idx];
else return ios2[idx];
}
void flush(int idx = 0) {
if(idx == 0) {
for(int i = 1; i <= nP; ++i)
if(i != party) {
ios[i]->flush();
ios2[i]->flush();
}
} else {
if(party < idx)
ios[idx]->flush();
else
ios2[idx]->flush();
}
}
void sync() {
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if(i < j) {
if(i == party) {
ios[j]->sync();
ios2[j]->sync();
} else if(j == party) {
ios[i]->sync();
ios2[i]->sync();
}
}
}
};
#endif //NETIOMP_H__