Start using Vec, NVec to avoid nP template

This commit is contained in:
Andrew Morris
2025-01-28 17:07:50 +11:00
parent b36132f8e1
commit b00d157a60
9 changed files with 401 additions and 514 deletions

View File

@@ -16,7 +16,7 @@ int main(int argc, char** argv) {
NetIOMP<nP> *ios[2] = {&io, &io2};
BristolFormat cf(circuit_file_location.c_str());
CMPC<nP>* mpc = new CMPC<nP>(ios, party, &cf);
CMPC<nP>* mpc = new CMPC<nP>(nP, ios, party, &cf);
cout <<"Setup:\t"<<party<<"\n";
mpc->function_independent();

View File

@@ -4,6 +4,7 @@
#include <emp-ot/emp-ot.h>
#include "netmp.h"
#include "helper.h"
#include "nvec.h"
template<int nP>
class ABitMP { public:
@@ -64,21 +65,21 @@ class ABitMP { public:
delete abit2[i];
}
}
void compute(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
void compute(NVec<block>& MAC, NVec<block>& KEY, bool* data, int length) {
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
if (party < party2) {
abit2[party2]->recv_cot(MAC[party2], data, length);
abit2[party2]->recv_cot(&MAC.at(party2, 0), data, length);
io->flush(party2);
abit1[party2]->send_cot(KEY[party2], length);
abit1[party2]->send_cot(&KEY.at(party2, 0), length);
io->flush(party2);
} else {
abit1[party2]->send_cot(KEY[party2], length);
abit1[party2]->send_cot(&KEY.at(party2, 0), length);
io->flush(party2);
abit2[party2]->recv_cot(MAC[party2], data, length);
abit2[party2]->recv_cot(&MAC.at(party2, 0), data, length);
io->flush(party2);
}
}
@@ -87,12 +88,12 @@ class ABitMP { public:
#endif
}
void check(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
void check(const NVec<block>& MAC, const NVec<block>& KEY, bool* data, int length) {
check1(MAC, KEY, data, length);
check2(MAC, KEY, data, length);
}
void check1(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
void check1(const NVec<block>& MAC, const NVec<block>& KEY, bool* data, int length) {
block seed = sampleRandom(io, &prg, party);
PRG prg2(&seed);
uint8_t * tmp;
@@ -141,13 +142,13 @@ class ABitMP { public:
for(int tt = 0; tt < length/SIZE; tt++) {
int start = SIZE*tt;
for(int i = SIZE*tt; i < SIZE*(tt+1) and i < length; i+=chk) {
tMAC[(i-start)/chk][1] = MAC[k][i];
tMAC[(i-start)/chk][2] = MAC[k][i+1];
tMAC[(i-start)/chk][3] = MAC[k][i] ^ MAC[k][i+1];
tMAC[(i-start)/chk][1] = MAC.at(k, i);
tMAC[(i-start)/chk][2] = MAC.at(k, i+1);
tMAC[(i-start)/chk][3] = MAC.at(k, i) ^ MAC.at(k, i+1);
tKEY[(i-start)/chk][1] = KEY[k][i];
tKEY[(i-start)/chk][2] = KEY[k][i+1];
tKEY[(i-start)/chk][3] = KEY[k][i] ^ KEY[k][i+1];
tKEY[(i-start)/chk][1] = KEY.at(k, i);
tKEY[(i-start)/chk][2] = KEY.at(k, i+1);
tKEY[(i-start)/chk][3] = KEY.at(k, i) ^ KEY.at(k, i+1);
for(int j = 0; j < ssp; ++j) {
Ms[k][j] = Ms[k][j] ^ tMAC[(i-start)/chk][*tmpptr];
Ks[k][j] = Ks[k][j] ^ tKEY[(i-start)/chk][*tmpptr];
@@ -187,7 +188,7 @@ class ABitMP { public:
}
}
void check2(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) {
void check2(const NVec<block>& MAC, const NVec<block> KEY, bool* data, int length) {
//last 2*ssp are garbage already.
block * Ks[2], *Ms[nP+1][nP+1];
block * KK[nP+1];
@@ -208,7 +209,7 @@ class ABitMP { public:
for(int i = 0; i < ssp; ++i) {
Ks[0][i] = zero_block;
for(int j = 1; j <= nP; ++j) if(j != party)
Ks[0][i] = Ks[0][i] ^ KEY[j][length-3*ssp+i];
Ks[0][i] = Ks[0][i] ^ KEY.at(j, length-3*ssp+i);
Ks[1][i] = Ks[0][i] ^ Delta;
Hash::hash_once(dgst0[party*ssp+i], &Ks[0][i], sizeof(block));
@@ -217,7 +218,7 @@ class ABitMP { public:
Hash h;
h.put(data+length-3*ssp, ssp);
for(int j = 1; j <= nP; ++j) if(j != party) {
h.put(&MAC[j][length-3*ssp], ssp*sizeof(block));
h.put(&MAC.at(j, length-3*ssp), ssp*sizeof(block));
}
h.digest(dgst[party]);
@@ -233,14 +234,14 @@ class ABitMP { public:
vector<bool> res2;
for(int k = 1; k <= nP; ++k) if(k!= party)
memcpy(Ms[party][k], MAC[k]+length-3*ssp, sizeof(block)*ssp);
memcpy(Ms[party][k], &MAC.at(k, length-3*ssp), sizeof(block)*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, data + length - 3*ssp, ssp);
for(int k = 1; k <= nP; ++k) if(k != party)
io->send_data(party2, MAC[k] + length - 3*ssp, sizeof(block)*ssp);
io->send_data(party2, &MAC.at(k, length - 3*ssp), sizeof(block)*ssp);
res2.push_back(false);
Hash h;

View File

@@ -6,5 +6,6 @@
#include "emp-agmpc/helper.h"
#include "emp-agmpc/mpc.h"
#include "emp-agmpc/netmp.h"
#include "emp-agmpc/nvector.h"
#include "emp-agmpc/vec.h"
#include "emp-agmpc/nvec.h"
#endif// EMP_AGMPC_H

View File

@@ -58,12 +58,12 @@ public:
authenticated_share_assignment.clear();
}
void associate_cmpc(bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], NetIOMP<nP> *associated_io, block associated_Delta) {
void associate_cmpc(bool *associated_value, NVec<block>& associated_mac, NVec<block>& associated_key, NetIOMP<nP> *associated_io, block associated_Delta) {
this->cmpc_associated = true;
this->value = associated_value;
for(int j = 1; j <= nP; j++) {
this->mac[j] = associated_mac[j];
this->key[j] = associated_key[j];
this->mac[j] = &associated_mac.at(j, 0);
this->key[j] = &associated_key.at(j, 0);
}
this->io = associated_io;
this->Delta = associated_Delta;
@@ -539,17 +539,17 @@ public:
authenticated_share_results.clear();
}
void associate_cmpc(bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], block *associated_eval_labels[nP + 1], block *associated_labels, NetIOMP<nP> *associated_io, block associated_Delta) {
void associate_cmpc(bool *associated_value, NVec<block>& associated_mac, NVec<block>& associated_key, NVec<block>& associated_eval_labels, Vec<block>& associated_labels, NetIOMP<nP> *associated_io, block associated_Delta) {
this->cmpc_associated = true;
this->value = associated_value;
this->labels = associated_labels;
this->labels = &associated_labels.at(0);
for (int j = 1; j <= nP; j++) {
this->mac[j] = associated_mac[j];
this->key[j] = associated_key[j];
this->mac[j] = &associated_mac.at(j, 0);
this->key[j] = &associated_key.at(j, 0);
}
if (party == ALICE){
for (int j = 2; j <= nP; j++) {
this->eval_labels[j] = associated_eval_labels[j];
this->eval_labels[j] = &associated_eval_labels.at(j, 0);
}
}
this->io = associated_io;

View File

@@ -6,6 +6,7 @@
#include "abitmp.h"
#include "netmp.h"
#include "cmpc_config.h"
#include "nvec.h"
using namespace emp;
template<int nP>
@@ -44,87 +45,77 @@ class FpreMP { public:
return 4;
else return 5;
}
void compute(block * MAC[nP+1], block * KEY[nP+1], bool * r, int length) {
void compute(NVec<block>& MAC, NVec<block>& KEY, bool* r, int length) {
int64_t bucket_size = get_bucket_size(length);
block * tMAC[nP+1];
block * tKEY[nP+1];
block * tKEYphi[nP+1];
block * tMACphi[nP+1];
block * phi;
block *X [nP+1];
bool *tr = new bool[length*bucket_size*3+3*ssp];
phi = new block[length*bucket_size];
bool *s[nP+1], *e = new bool[length*bucket_size];
for(int i = 1; i <= nP; ++i) {
tMAC[i] = new block[length*bucket_size*3+3*ssp];
tKEY[i] = new block[length*bucket_size*3+3*ssp];
tKEYphi[i] = new block[length*bucket_size+3*ssp];
tMACphi[i] = new block[length*bucket_size+3*ssp];
X[i] = new block[ssp];
}
for(int i = 0; i <= nP; ++i) {
s[i] = new bool[length*bucket_size];
memset(s[i], 0, length*bucket_size);
}
prg.random_bool(tr, length*bucket_size*3+3*ssp);
NVec<block> tMAC(nP+1, length*bucket_size*3+3*ssp);
NVec<block> tKEY(nP+1, length*bucket_size*3+3*ssp);
NVec<block> tKEYphi(nP+1, length*bucket_size*3+3*ssp);
NVec<block> tMACphi(nP+1, length*bucket_size*3+3*ssp);
Vec<block> phi(length*bucket_size);
NVec<block> X(nP+1, ssp);
Vec<bool> tr(length*bucket_size*3+3*ssp);
NVec<bool> s(nP+1, length*bucket_size);
Vec<bool> e(length*bucket_size);
prg.random_bool(&tr[0], length*bucket_size*3+3*ssp);
// memset(tr, false, length*bucket_size*3+3*ssp);
abit->compute(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp);
abit->compute(tMAC, tKEY, &tr[0], length*bucket_size*3 + 3*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j ) {
if(i == party) {
prgs[j].random_bool(s[j], length*bucket_size);
prgs[j].random_bool(&s.at(j, 0), length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = garble(tKEY[j], tr, s[j], k, j);
uint8_t data = garble(&tKEY.at(j, 0), &tr[0], &s.at(j, 0), k, j);
io->send_data(j, &data, 1);
s[j][k] = (s[j][k] != (tr[3*k] and tr[3*k+1]));
s.at(j, k) = (s.at(j, k) != (tr[3*k] and tr[3*k+1]));
}
io->flush(j);
} else if (j == party) {
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = 0;
io->recv_data(i, &data, 1);
bool tmp = evaluate(data, tMAC[i], tr, k, i);
s[i][k] = (tmp != (tr[3*k] and tr[3*k+1]));
bool tmp = evaluate(data, &tMAC.at(i, 0), &tr[0], k, i);
s.at(i, k) = (tmp != (tr[3*k] and tr[3*k+1]));
}
}
}
for(int k = 0; k < length*bucket_size; ++k) {
s[0][k] = (tr[3*k] and tr[3*k+1]);
s.at(0, k) = (tr[3*k] and tr[3*k+1]);
for(int i = 1; i <= nP; ++i)
if (i != party) {
s[0][k] = (s[0][k] != s[i][k]);
s.at(0, k) = (s.at(0, k) != s.at(i, k));
}
e[k] = (s[0][k] != tr[3*k+2]);
tr[3*k+2] = s[0][k];
e[k] = (s.at(0, k) != tr[3*k+2]);
tr[3*k+2] = s.at(0, k);
}
#ifdef __debug
check_correctness(io, tr, length*bucket_size, party);
check_correctness(io, &tr[0], length*bucket_size, party);
#endif
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, e, length*bucket_size);
io->send_data(party2, &e[0], length*bucket_size);
io->flush(party2);
bool * tmp = new bool[length*bucket_size];
io->recv_data(party2, tmp, length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
if(tmp[k])
tKEY[party2][3*k+2] = tKEY[party2][3*k+2] ^ Delta;
tKEY.at(party2, 3*k+2) = tKEY.at(party2, 3*k+2) ^ Delta;
}
delete[] tmp;
}
#ifdef __debug
check_MAC(io, tMAC, tKEY, tr, Delta, length*bucket_size*3, party);
#endif
abit->check(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp);
abit->check(tMAC, tKEY, &tr[0], length*bucket_size*3 + 3*ssp);
//check compute phi
for(int k = 0; k < length*bucket_size; ++k) {
phi[k] = zero_block;
for(int i = 1; i <= nP; ++i) if (i != party) {
phi[k] = phi[k] ^ tKEY[i][3*k+1];
phi[k] = phi[k] ^ tMAC[i][3*k+1];
phi[k] = phi[k] ^ tKEY.at(i, 3*k+1);
phi[k] = phi[k] ^ tMAC.at(i, 3*k+1);
}
if(tr[3*k+1])phi[k] = phi[k] ^ Delta;
}
@@ -136,10 +127,10 @@ class FpreMP { public:
{
block bH[2], tmpH[2];
for(int k = 0; k < length*bucket_size; ++k) {
bH[0] = tKEY[party2][3*k];
bH[0] = tKEY.at(party2, 3*k);
bH[1] = bH[0] ^ Delta;
HnID(prps+party2, bH, bH, 2*k, 2, tmpH);
tKEYphi[party2][k] = bH[0];
tKEYphi.at(party2, k) = bH[0];
bH[1] = bH[0] ^ bH[1];
bH[1] = phi[k] ^ bH[1];
io->send_data(party2, &bH[1], sizeof(block));
@@ -151,9 +142,9 @@ class FpreMP { public:
block bH;
for(int k = 0; k < length*bucket_size; ++k) {
io->recv_data(party2, &bH, sizeof(block));
block hin = sigma(tMAC[party2][3*k]) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi[party2][k] = prps2[party2].H(hin);
if(tr[3*k])tMACphi[party2][k] = tMACphi[party2][k] ^ bH;
block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi.at(party2, k) = prps2[party2].H(hin);
if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH;
}
}
} else {
@@ -161,19 +152,19 @@ class FpreMP { public:
block bH;
for(int k = 0; k < length*bucket_size; ++k) {
io->recv_data(party2, &bH, sizeof(block));
block hin = sigma(tMAC[party2][3*k]) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi[party2][k] = prps2[party2].H(hin);
if(tr[3*k])tMACphi[party2][k] = tMACphi[party2][k] ^ bH;
block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi.at(party2, k) = prps2[party2].H(hin);
if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH;
}
}
{
block bH[2], tmpH[2];
for(int k = 0; k < length*bucket_size; ++k) {
bH[0] = tKEY[party2][3*k];
bH[0] = tKEY.at(party2, 3*k);
bH[1] = bH[0] ^ Delta;
HnID(prps+party2, bH, bH, 2*k, 2, tmpH);
tKEYphi[party2][k] = bH[0];
tKEYphi.at(party2, k) = bH[0];
bH[1] = bH[0] ^ bH[1];
bH[1] = phi[k] ^ bH[1];
io->send_data(party2, &bH[1], sizeof(block));
@@ -191,19 +182,19 @@ class FpreMP { public:
#endif
//tKEYphti use as H
for(int k = 0; k < length*bucket_size; ++k) {
tKEYphi[party][k] = zero_block;
tKEYphi.at(party, k) = zero_block;
for(int i = 1; i <= nP; ++i) if (i != party) {
tKEYphi[party][k] = tKEYphi[party][k] ^ tKEYphi[i][k];
tKEYphi[party][k] = tKEYphi[party][k] ^ tMACphi[i][k];
tKEYphi[party][k] = tKEYphi[party][k] ^ tKEY[i][3*k+2];
tKEYphi[party][k] = tKEYphi[party][k] ^ tMAC[i][3*k+2];
tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tKEYphi.at(i, k);
tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tMACphi.at(i, k);
tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tKEY.at(i, 3*k+2);
tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tMAC.at(i, 3*k+2);
}
if(tr[3*k]) tKEYphi[party][k] = tKEYphi[party][k] ^ phi[k];
if(tr[3*k+2])tKEYphi[party][k] = tKEYphi[party][k] ^ Delta;
if(tr[3*k]) tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ phi[k];
if(tr[3*k+2])tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ Delta;
}
#ifdef __debug
check_zero(tKEYphi[party], length*bucket_size);
check_zero(&tKEYphi.at(party, 0), length*bucket_size);
#endif
block prg_key = sampleRandom(io, &prg, party);
@@ -212,9 +203,9 @@ class FpreMP { public:
bool * tmp = new bool[length*bucket_size];
for(int i = 0; i < ssp; ++i) {
prgf.random_bool(tmp, length*bucket_size);
X[party][i] = inProd(tmp, tKEYphi[party], length*bucket_size);
X.at(party, i) = inProd(tmp, &tKEYphi.at(party, 0), length*bucket_size);
}
Hash::hash_once(dgst[party], X[party], sizeof(block)*ssp);
Hash::hash_once(dgst[party], &X.at(party, 0), sizeof(block)*ssp);
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
@@ -226,18 +217,18 @@ class FpreMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, X[party], sizeof(block)*ssp);
io->recv_data(party2, X[party2], sizeof(block)*ssp);
io->send_data(party2, &X.at(party, 0), sizeof(block)*ssp);
io->recv_data(party2, &X.at(party2, 0), sizeof(block)*ssp);
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, X[party2], sizeof(block)*ssp);
Hash::hash_once(tmp, &X.at(party2, 0), sizeof(block)*ssp);
res2.push_back(strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0);
}
if(checkCheat(res2)) error("commitment");
for(int i = 2; i <= nP; ++i)
xorBlocks_arr(X[1], X[1], X[i], ssp);
for(int i = 0; i < ssp; ++i)X[2][i] = zero_block;
if(!cmpBlock(X[1], X[2], ssp)) error("AND check");
xorBlocks_arr(&X.at(1, 0), &X.at(1, 0), &X.at(i, 0), ssp);
for(int i = 0; i < ssp; ++i)X.at(2, i) = zero_block;
if(!cmpBlock(&X.at(1, 0), &X.at(2, 0), ssp)) error("AND check");
//land -> and
block S = sampleRandom<nP>(io, &prg, party);
@@ -264,17 +255,17 @@ class FpreMP { public:
for(int j = 0; j < bucket_size-1; ++j)
d[party][(bucket_size-1)*i+j] = tr[3*location[i*bucket_size]+1] != tr[3*location[i*bucket_size+1+j]+1];
for(int j = 1; j <= nP; ++j) if (j!= party) {
memcpy(MAC[j]+3*i, tMAC[j]+3*location[i*bucket_size], 3*sizeof(block));
memcpy(KEY[j]+3*i, tKEY[j]+3*location[i*bucket_size], 3*sizeof(block));
memcpy(&MAC.at(j, 3*i), &tMAC.at(j, 3*location[i*bucket_size]), 3*sizeof(block));
memcpy(&KEY.at(j, 3*i), &tKEY.at(j, 3*location[i*bucket_size]), 3*sizeof(block));
for(int k = 1; k < bucket_size; ++k) {
MAC[j][3*i] = MAC[j][3*i] ^ tMAC[j][3*location[i*bucket_size+k]];
KEY[j][3*i] = KEY[j][3*i] ^ tKEY[j][3*location[i*bucket_size+k]];
MAC.at(j, 3*i) = MAC.at(j, 3*i) ^ tMAC.at(j, 3*location[i*bucket_size+k]);
KEY.at(j, 3*i) = KEY.at(j, 3*i) ^ tKEY.at(j, 3*location[i*bucket_size+k]);
MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]+2];
KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]+2];
MAC.at(j, 3*i+2) = MAC.at(j, 3*i+2) ^ tMAC.at(j, 3*location[i*bucket_size+k]+2);
KEY.at(j, 3*i+2) = KEY.at(j, 3*i+2) ^ tKEY.at(j, 3*location[i*bucket_size+k]+2);
}
}
memcpy(r+3*i, tr+3*location[i*bucket_size], 3);
memcpy(&r[3*i], &tr[3*location[i*bucket_size]], 3);
for(int k = 1; k < bucket_size; ++k) {
r[3*i] = r[3*i] != tr[3*location[i*bucket_size+k]];
r[3*i+2] = r[3*i+2] != tr[3*location[i*bucket_size+k]+2];
@@ -295,8 +286,8 @@ class FpreMP { public:
for(int j = 1; j <= nP; ++j)if (j!= party) {
for(int k = 1; k < bucket_size; ++k)
if(d[1][(bucket_size-1)*i+k-1]) {
MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]];
KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]];
MAC.at(j, 3*i+2) = MAC.at(j, 3*i+2) ^ tMAC.at(j, 3*location[i*bucket_size+k]);
KEY.at(j, 3*i+2) = KEY.at(j, 3*i+2) ^ tKEY.at(j, 3*location[i*bucket_size+k]);
}
}
for(int k = 1; k < bucket_size; ++k)
@@ -311,23 +302,6 @@ class FpreMP { public:
#endif
// ret.get();
delete[] tr;
delete[] phi;
delete[] e;
delete[] dgst;
delete[] tmp;
delete[] location;
delete[] xs;
for(int i = 1; i <= nP; ++i) {
delete[] tMAC[i];
delete[] tKEY[i];
delete[] tMACphi[i];
delete[] tKEYphi[i];
delete[] X[i];
delete[] s[i];
delete[] d[i];
}
delete[] s[0];
}
//TODO: change to justGarble
@@ -368,13 +342,13 @@ class FpreMP { public:
return (tmp&0x1) != (res&0x1);
}
void check_MAC_phi(block * MAC[nP+1], block * KEY[nP+1], block * phi, bool * r, int length) {
void check_MAC_phi(const NVec<block>& MAC, const NVec<block>& KEY, block * phi, bool * r, int length) {
block * tmp = new block[length];
block *tD = new block[length];
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, phi, length*sizeof(block));
io->send_data(j, KEY[j], sizeof(block)*length);
io->send_data(j, &KEY.at(j, 0), sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, tD, length*sizeof(block));
@@ -382,7 +356,7 @@ class FpreMP { public:
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD[k];
}
if(!cmpBlock(MAC[i], tmp, length))
if(!cmpBlock(&MAC.at(i, 0), tmp, length))
error("check_MAC_phi failed!");
}
}

View File

@@ -3,6 +3,7 @@
#include <emp-tool/emp-tool.h>
#include "cmpc_config.h"
#include "netmp.h"
#include "nvec.h"
#include <future>
using namespace emp;
using std::future;
@@ -117,13 +118,13 @@ block sampleRandom(NetIOMP<nP> * io, PRG * prg, int party) {
}
template<int nP>
void check_MAC(NetIOMP<nP> * io, block * MAC[nP+1], block * KEY[nP+1], bool * r, block Delta, int length, int party) {
void check_MAC(NetIOMP<nP> * io, const NVec<block>& MAC, const NVec<block>& KEY, bool * r, block Delta, int length, int party) {
block * tmp = new block[length];
block tD;
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, &Delta, sizeof(block));
io->send_data(j, KEY[j], sizeof(block)*length);
io->send_data(j, &KEY.at(j, 0), sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, &tD, sizeof(block));
@@ -131,7 +132,7 @@ void check_MAC(NetIOMP<nP> * io, block * MAC[nP+1], block * KEY[nP+1], bool * r,
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD;
}
if(!cmpBlock(MAC[i], tmp, length))
if(!cmpBlock(&MAC.at(i, 0), tmp, length))
error("check_MAC failed!");
}
}

View File

@@ -3,45 +3,59 @@
#include "fpremp.h"
#include "abitmp.h"
#include "netmp.h"
#include "vec.h"
#include "nvec.h"
#include "flexible_input_output.h"
#include <emp-tool/emp-tool.h>
using namespace emp;
template<int nP>
template <int nP_deprecated>
class CMPC { public:
const static int SSP = 5;//5*8 in fact...
const block MASK = makeBlock(0x0ULL, 0xFFFFFULL);
FpreMP<nP>* fpre = nullptr;
block* mac[nP+1];
block* key[nP+1];
bool* value;
FpreMP<nP_deprecated>* fpre = nullptr;
block * preprocess_mac[nP+1];
block * preprocess_key[nP+1];
bool* preprocess_value;
int nP;
block * sigma_mac[nP+1];
block * sigma_key[nP+1];
bool * sigma_value;
NVec<block> mac; // dim: parties, wires
NVec<block> key; // dim: parties, wires
Vec<bool> value; // dim: wires
block * ANDS_mac[nP+1];
block * ANDS_key[nP+1];
bool * ANDS_value;
NVec<block> preprocess_mac; // dim: parties, total_pre
NVec<block> preprocess_key; // dim: parties, total_pre
Vec<bool> preprocess_value; // dim: total_pre
block * labels;
NVec<block> sigma_mac; // dim: parties, num_ands
NVec<block> sigma_key; // dim: parties, num_ands
Vec<bool> sigma_value; // dim: num_ands
NVec<block> ANDS_mac; // dim: parties, num_ands*3
NVec<block> ANDS_key; // dim: parties, num_ands*3
Vec<bool> ANDS_value; // dim: num_ands*3
Vec<block> labels; // dim: wires
BristolFormat * cf;
NetIOMP<nP> * io;
NetIOMP<nP_deprecated> * io;
int num_ands = 0, num_in;
int party, total_pre, ssp;
block Delta;
block (*GTM)[4][nP+1];
block (*GTK)[4][nP+1];
bool (*GTv)[4];
block (*GT)[nP+1][4][nP+1];
block * eval_labels[nP+1];
NVec<block> GTM; // dim: num_ands, 4, parties
NVec<block> GTK; // dim: num_ands, 4, parties
NVec<bool> GTv; // dim: num_ands, 4
NVec<block> GT; // dim: num_ands, parties, 4, parties
NVec<block> eval_labels; // dim: parties, wires
PRP prp;
CMPC(NetIOMP<nP> * io[2], int party, BristolFormat * cf, bool * _delta = nullptr, int ssp = 40) {
CMPC(
int nP,
NetIOMP<nP_deprecated> * io[2],
int party,
BristolFormat * cf,
bool * _delta = nullptr,
int ssp = 40
) {
this->nP = nP;
this->party = party;
this->io = io[0];
this->cf = cf;
@@ -53,96 +67,69 @@ class CMPC { public:
}
num_in = cf->n1+cf->n2;
total_pre = num_in + num_ands + 3*ssp;
fpre = new FpreMP<nP>(io, party, _delta, ssp);
fpre = new FpreMP<nP_deprecated>(io, party, _delta, ssp);
Delta = fpre->Delta;
if(party == 1) {
GTM = new block[num_ands][4][nP+1];
GTK = new block[num_ands][4][nP+1];
GTv = new bool[num_ands][4];
GT = new block[num_ands][nP+1][4][nP+1];
GTM.resize(num_ands, 4, nP+1);
GTK.resize(num_ands, 4, nP+1);
GTv.resize(num_ands, 4);
GT.resize(num_ands, nP+1, 4, nP+1);
}
labels = new block[cf->num_wire];
for(int i = 1; i <= nP; ++i) {
key[i] = new block[cf->num_wire];
mac[i] = new block[cf->num_wire];
ANDS_key[i] = new block[num_ands*3];
ANDS_mac[i] = new block[num_ands*3];
preprocess_mac[i] = new block[total_pre];
preprocess_key[i] = new block[total_pre];
sigma_mac[i] = new block[num_ands];
sigma_key[i] = new block[num_ands];
eval_labels[i] = new block[cf->num_wire];
}
value = new bool[cf->num_wire];
ANDS_value = new bool[num_ands*3];
preprocess_value = new bool[total_pre];
sigma_value = new bool[num_ands];
labels.resize(cf->num_wire);
key.resize(nP+1, cf->num_wire);
mac.resize(nP+1, cf->num_wire);
ANDS_key.resize(nP+1, num_ands*3);
ANDS_mac.resize(nP+1, num_ands*3);
preprocess_mac.resize(nP+1, total_pre);
preprocess_key.resize(nP+1, total_pre);
sigma_mac.resize(nP+1, num_ands);
sigma_key.resize(nP+1, num_ands);
eval_labels.resize(nP+1, cf->num_wire);
value.resize(cf->num_wire);
ANDS_value.resize(num_ands*3);
preprocess_value.resize(total_pre);
sigma_value.resize(num_ands);
}
~CMPC() {
delete fpre;
if(party == 1) {
delete[] GTM;
delete[] GTK;
delete[] GTv;
delete[] GT;
}
delete[] labels;
for(int i = 1; i <= nP; ++i) {
delete[] key[i];
delete[] mac[i];
delete[] ANDS_key[i];
delete[] ANDS_mac[i];
delete[] preprocess_mac[i];
delete[] preprocess_key[i];
delete[] sigma_mac[i];
delete[] sigma_key[i];
delete[] eval_labels[i];
}
delete[] value;
delete[] ANDS_value;
delete[] preprocess_value;
delete[] sigma_value;
}
PRG prg;
void function_independent() {
if(party != 1)
prg.random_block(labels, cf->num_wire);
prg.random_block(&labels[0], cf->num_wire);
fpre->compute(ANDS_mac, ANDS_key, ANDS_value, num_ands);
fpre->compute(ANDS_mac, ANDS_key, &ANDS_value[0], num_ands);
prg.random_bool(preprocess_value, total_pre);
fpre->abit->compute(preprocess_mac, preprocess_key, preprocess_value, total_pre);
fpre->abit->check(preprocess_mac, preprocess_key, preprocess_value, total_pre);
prg.random_bool(&preprocess_value[0], total_pre);
fpre->abit->compute(preprocess_mac, preprocess_key, &preprocess_value[0], total_pre);
fpre->abit->check(preprocess_mac, preprocess_key, &preprocess_value[0], total_pre);
for(int i = 1; i <= nP; ++i) {
memcpy(key[i], preprocess_key[i], num_in * sizeof(block));
memcpy(mac[i], preprocess_mac[i], num_in * sizeof(block));
memcpy(&key.at(i, 0), &preprocess_key.at(i, 0), num_in * sizeof(block));
memcpy(&mac.at(i, 0), &preprocess_mac.at(i, 0), num_in * sizeof(block));
}
memcpy(value, preprocess_value, num_in * sizeof(bool));
memcpy(&value[0], &preprocess_value[0], num_in * sizeof(bool));
#ifdef __debug
check_MAC<nP>(io, ANDS_mac, ANDS_key, ANDS_value, Delta, num_ands*3, party);
check_correctness<nP>(io, ANDS_value, num_ands, party);
check_MAC(nP, io, ANDS_mac, ANDS_key, ANDS_value, Delta, num_ands*3, party);
check_correctness(nP, io, ANDS_value, num_ands, party);
#endif
// ret.get();
}
void function_dependent() {
int ands = num_in;
bool * x[nP+1];
bool * y[nP+1];
for(int i = 1; i <= nP; ++i) {
x[i] = new bool[num_ands];
y[i] = new bool[num_ands];
}
NVec<bool> x(nP+1, num_ands);
NVec<bool> y(nP+1, num_ands);
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = preprocess_key[j][ands];
mac[j][cf->gates[4*i+2]] = preprocess_mac[j][ands];
key.at(j, cf->gates[4*i+2]) = preprocess_key.at(j, ands);
mac.at(j, cf->gates[4*i+2]) = preprocess_mac.at(j, ands);
}
value[cf->gates[4*i+2]] = preprocess_value[ands];
++ands;
@@ -152,16 +139,16 @@ class CMPC { public:
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]] ^ key[j][cf->gates[4*i+1]];
mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]] ^ mac[j][cf->gates[4*i+1]];
key.at(j, cf->gates[4*i+2]) = key.at(j, cf->gates[4*i]) ^ key.at(j, cf->gates[4*i+1]);
mac.at(j, cf->gates[4*i+2]) = mac.at(j, cf->gates[4*i]) ^ mac.at(j, cf->gates[4*i+1]);
}
value[cf->gates[4*i+2]] = value[cf->gates[4*i]] != value[cf->gates[4*i+1]];
if(party != 1)
labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == NOT_GATE) {
for(int j = 1; j <= nP; ++j) {
key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]];
mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]];
key.at(j, cf->gates[4*i+2]) = key.at(j, cf->gates[4*i]);
mac.at(j, cf->gates[4*i+2]) = mac.at(j, cf->gates[4*i]);
}
value[cf->gates[4*i+2]] = value[cf->gates[4*i]];
if(party != 1)
@@ -170,14 +157,14 @@ class CMPC { public:
}
#ifdef __debug
check_MAC<nP>(io, mac, key, value, Delta, cf->num_wire, party);
check_MAC(nP, io, mac, key, value, Delta, cf->num_wire, party);
#endif
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
x[party][ands] = value[cf->gates[4*i]] != ANDS_value[3*ands];
y[party][ands] = value[cf->gates[4*i+1]] != ANDS_value[3*ands+1];
x.at(party, ands) = value[cf->gates[4*i]] != ANDS_value[3*ands];
y.at(party, ands) = value[cf->gates[4*i+1]] != ANDS_value[3*ands+1];
ands++;
}
}
@@ -185,44 +172,44 @@ class CMPC { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, x[party], num_ands);
io->send_data(party2, y[party], num_ands);
io->send_data(party2, &x.at(party, 0), num_ands);
io->send_data(party2, &y.at(party, 0), num_ands);
io->flush(party2);
io->recv_data(party2, x[party2], num_ands);
io->recv_data(party2, y[party2], num_ands);
io->recv_data(party2, &x.at(party2, 0), num_ands);
io->recv_data(party2, &y.at(party2, 0), num_ands);
}
for(int i = 2; i <= nP; ++i) for(int j = 0; j < num_ands; ++j) {
x[1][j] = x[1][j] != x[i][j];
y[1][j] = y[1][j] != y[i][j];
x.at(1, j) = x.at(1, j) != x.at(i, j);
y.at(1, j) = y.at(1, j) != y.at(i, j);
}
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = ANDS_mac[j][3*ands+2];
sigma_key[j][ands] = ANDS_key[j][3*ands+2];
sigma_mac.at(j, ands) = ANDS_mac.at(j, 3*ands+2);
sigma_key.at(j, ands) = ANDS_key.at(j, 3*ands+2);
}
sigma_value[ands] = ANDS_value[3*ands+2];
if(x[1][ands]) {
if(x.at(1, ands)) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands+1];
sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands+1];
sigma_mac.at(j, ands) = sigma_mac.at(j, ands) ^ ANDS_mac.at(j, 3*ands+1);
sigma_key.at(j, ands) = sigma_key.at(j, ands) ^ ANDS_key.at(j, 3*ands+1);
}
sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands+1];
}
if(y[1][ands]) {
if(y.at(1, ands)) {
for(int j = 1; j <= nP; ++j) {
sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands];
sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands];
sigma_mac.at(j, ands) = sigma_mac.at(j, ands) ^ ANDS_mac.at(j, 3*ands);
sigma_key.at(j, ands) = sigma_key.at(j, ands) ^ ANDS_key.at(j, 3*ands);
}
sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands];
}
if(x[1][ands] and y[1][ands]) {
if(x.at(1, ands) and y.at(1, ands)) {
if(party != 1)
sigma_key[1][ands] = sigma_key[1][ands] ^ Delta;
sigma_key.at(1, ands) = sigma_key.at(1, ands) ^ Delta;
else
sigma_value[ands] = not sigma_value[ands];
}
@@ -230,7 +217,7 @@ class CMPC { public:
}
}//sigma_[] stores the and of input wires to each AND gates
#ifdef __debug_
check_MAC<nP>(io, sigma_mac, sigma_key, sigma_value, Delta, num_ands, party);
check_MAC(nP, io, sigma_mac, sigma_key, sigma_value, Delta, num_ands, party);
ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == AND_GATE) {
@@ -242,8 +229,9 @@ class CMPC { public:
#endif
ands = 0;
block H[4][nP+1];
block K[4][nP+1], M[4][nP+1];
NVec<block> H(4, nP+1);
NVec<block> K(4, nP+1);
NVec<block> M(4, nP+1);
bool r[4];
if(party != 1) {
for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) {
@@ -253,30 +241,30 @@ class CMPC { public:
r[3] = r[1] != value[cf->gates[4*i+1]];
for(int j = 1; j <= nP; ++j) {
M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]];
M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]];
M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]];
M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]];
M.at(0, j) = sigma_mac.at(j, ands) ^ mac.at(j, cf->gates[4*i+2]);
M.at(1, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i]);
M.at(2, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i+1]);
M.at(3, j) = M.at(1, j) ^ mac.at(j, cf->gates[4*i+1]);
K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]];
K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]];
K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]];
K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]];
K.at(0, j) = sigma_key.at(j, ands) ^ key.at(j, cf->gates[4*i+2]);
K.at(1, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i]);
K.at(2, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i+1]);
K.at(3, j) = K.at(1, j) ^ key.at(j, cf->gates[4*i+1]);
}
K[3][1] = K[3][1] ^ Delta;
K.at(3, 1) = K.at(3, 1) ^ Delta;
Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], ands);
for(int j = 0; j < 4; ++j) {
for(int k = 1; k <= nP; ++k) if(k != party) {
H[j][k] = H[j][k] ^ M[j][k];
H[j][party] = H[j][party] ^ K[j][k];
H.at(j, k) = H.at(j, k) ^ M.at(j, k);
H.at(j, party) = H.at(j, party) ^ K.at(j, k);
}
H[j][party] = H[j][party] ^ labels[cf->gates[4*i+2]];
H.at(j, party) = H.at(j, party) ^ labels[cf->gates[4*i+2]];
if(r[j])
H[j][party] = H[j][party] ^ Delta;
H.at(j, party) = H.at(j, party) ^ Delta;
}
for(int j = 0; j < 4; ++j)
io->send_data(1, H[j]+1, sizeof(block)*(nP));
io->send_data(1, &H.at(j, 1), sizeof(block)*(nP));
++ands;
}
io->flush(1);
@@ -285,7 +273,7 @@ class CMPC { public:
int party2 = i;
for(int i = 0; i < num_ands; ++i)
for(int j = 0; j < 4; ++j)
io->recv_data(party2, GT[i][party2][j]+1, sizeof(block)*(nP));
io->recv_data(party2, &GT.at(i, party2, j, 1), sizeof(block)*(nP));
}
for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) {
r[0] = sigma_value[ands] != value[cf->gates[4*i+2]];
@@ -295,144 +283,43 @@ class CMPC { public:
r[3] = r[3] != true;
for(int j = 1; j <= nP; ++j) {
M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]];
M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]];
M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]];
M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]];
M.at(0, j) = sigma_mac.at(j, ands) ^ mac.at(j, cf->gates[4*i+2]);
M.at(1, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i]);
M.at(2, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i+1]);
M.at(3, j) = M.at(1, j) ^ mac.at(j, cf->gates[4*i+1]);
K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]];
K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]];
K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]];
K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]];
K.at(0, j) = sigma_key.at(j, ands) ^ key.at(j, cf->gates[4*i+2]);
K.at(1, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i]);
K.at(2, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i+1]);
K.at(3, j) = K.at(1, j) ^ key.at(j, cf->gates[4*i+1]);
}
memcpy(GTK[ands], K, sizeof(block)*4*(nP+1));
memcpy(GTM[ands], M, sizeof(block)*4*(nP+1));
memcpy(GTv[ands], r, sizeof(bool)*4);
memcpy(&GTK.at(ands, 0, 0), &K.at(0, 0), sizeof(block)*4*(nP+1));
memcpy(&GTM.at(ands, 0, 0), &M.at(0, 0), sizeof(block)*4*(nP+1));
memcpy(&GTv.at(ands, 0), r, sizeof(bool)*4);
++ands;
}
}
for(int i = 1; i <= nP; ++i) {
delete[] x[i];
delete[] y[i];
}
}
void online (bool * input, bool * output) {
bool * mask_input = new bool[cf->num_wire];
for(int i = 0; i < num_in; ++i)
mask_input[i] = input[i] != value[i];
if(party != 1) {
io->send_data(1, mask_input, num_in);
io->flush(1);
io->recv_data(1, mask_input, num_in);
} else {
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i) tmp[i] = new bool[num_in];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, tmp[party2], num_in);
}
for(int i = 0; i < num_in; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[i] = tmp[j][i] != mask_input[i];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->send_data(party2, mask_input, num_in);
io->flush(party2);
}
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
}
if(party!= 1) {
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
}
io->flush(1);
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
}
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
block t0 = GTK[ands][index][j] ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}
}
ands++;
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
}
}
}
if(party != 1) {
io->send_data(1, value+cf->num_wire - cf->n3, cf->n3);
io->flush(1);
} else {
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i)
tmp[i] = new bool[cf->n3];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, tmp[party2], cf->n3);
}
for(int i = 0; i < cf->n3; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 0; i < cf->n3; ++i)
mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3);
}
delete[] mask_input;
}
void Hash(block H[4][nP+1], const block & a, const block & b, uint64_t idx) {
void Hash(NVec<block>& H, const block & a, const block & b, uint64_t idx) {
block T[4];
T[0] = sigma(a);
T[1] = sigma(a ^ Delta);
T[2] = sigma(sigma(b));
T[3] = sigma(sigma(b ^ Delta));
H[0][0] = T[0] ^ T[2];
H[1][0] = T[0] ^ T[3];
H[2][0] = T[1] ^ T[2];
H[3][0] = T[1] ^ T[3];
H.at(0, 0) = T[0] ^ T[2];
H.at(1, 0) = T[0] ^ T[3];
H.at(2, 0) = T[1] ^ T[2];
H.at(3, 0) = T[1] ^ T[3];
for(int j = 0; j < 4; ++j) for(int i = 1; i <= nP; ++i) {
H[j][i] = H[j][0] ^ makeBlock(4*idx+j, i);
H.at(j, i) = H.at(j, 0) ^ makeBlock(4*idx+j, i);
}
for(int j = 0; j < 4; ++j) {
prp.permute_block(H[j]+1, nP);
prp.permute_block(&H.at(j, 1), nP);
}
}
void Hash(block H[nP+1], const block &a, const block& b, uint64_t idx, uint64_t row) {
void Hash(block* H, const block &a, const block& b, uint64_t idx, uint64_t row) {
H[0] = sigma(a) ^ sigma(sigma(b));
for(int i = 1; i <= nP; ++i) {
H[i] = H[0] ^ makeBlock(4*idx+row, i);
@@ -445,137 +332,9 @@ class CMPC { public:
else return "F";
}
void online (bool * input, bool * output, int* start, int* end) {
void online (FlexIn<nP_deprecated> * input, FlexOut<nP_deprecated> *output) {
bool * mask_input = new bool[cf->num_wire];
bool * input_mask[nP+1];
for(int i = 0; i <= nP; ++i) input_mask[i] = new bool[end[party] - start[party]];
memcpy(input_mask[party], value+start[party], end[party] - start[party]);
memcpy(input_mask[0], input+start[party], end[party] - start[party]);
vector<bool> res;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
{
char dig[Hash::DIGEST_SIZE];
io->send_data(party2, value+start[party2], end[party2]-start[party2]);
emp::Hash::hash_once(dig, mac[party2]+start[party2], (end[party2]-start[party2])*sizeof(block));
io->send_data(party2, dig, Hash::DIGEST_SIZE);
io->flush(party2);
res.push_back(false);
}
{
char dig[Hash::DIGEST_SIZE];
char dig2[Hash::DIGEST_SIZE];
io->recv_data(party2, input_mask[party2], end[party]-start[party]);
block * tmp = new block[end[party]-start[party]];
for(int i = 0; i < end[party] - start[party]; ++i) {
tmp[i] = key[party2][i+start[party]];
if(input_mask[party2][i])tmp[i] = tmp[i] ^ Delta;
}
emp::Hash::hash_once(dig2, tmp, (end[party]-start[party])*sizeof(block));
io->recv_data(party2, dig, Hash::DIGEST_SIZE);
delete[] tmp;
res.push_back(strncmp(dig, dig2, Hash::DIGEST_SIZE) != 0);
}
}
if(checkCheat(res)) error("cheat!");
for(int i = 1; i <= nP; ++i)
for(int j = 0; j < end[party] - start[party]; ++j)
input_mask[0][j] = input_mask[0][j] != input_mask[i][j];
if(party != 1) {
io->send_data(1, input_mask[0], end[party] - start[party]);
io->flush(1);
io->recv_data(1, mask_input, num_in);
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, mask_input+start[party2], end[party2] - start[party2]);
}
memcpy(mask_input, input_mask[0], end[1]-start[1]);
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->send_data(party2, mask_input, num_in);
io->flush(party2);
}
}
if(party!= 1) {
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
}
io->flush(1);
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
}
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
block t0 = GTK[ands][index][j] ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
else {cout <<ands <<"no match GT!"<<endl<<flush;
}
}
ands++;
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
}
}
}
if(party != 1) {
io->send_data(1, value+cf->num_wire - cf->n3, cf->n3);
io->flush(1);
} else {
bool * tmp[nP+1];
for(int i = 2; i <= nP; ++i)
tmp[i] = new bool[cf->n3];
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, tmp[party2], cf->n3);
}
for(int i = 0; i < cf->n3; ++i)
for(int j = 2; j <= nP; ++j)
mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 0; i < cf->n3; ++i)
mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i];
for(int i = 2; i <= nP; ++i) delete[] tmp[i];
memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3);
}
delete[] mask_input;
}
void online (FlexIn<nP> * input, FlexOut<nP> *output) {
bool * mask_input = new bool[cf->num_wire];
input->associate_cmpc(value, mac, key, io, Delta);
input->associate_cmpc(&value[0], mac, key, io, Delta);
input->input(mask_input);
if(party!= 1) {
@@ -588,30 +347,30 @@ class CMPC { public:
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, eval_labels[party2], num_in*sizeof(block));
io->recv_data(party2, &eval_labels.at(party2, 0), num_in*sizeof(block));
}
int ands = 0;
for(int i = 0; i < cf->num_gate; ++i) {
if (cf->gates[4*i+3] == XOR_GATE) {
for(int j = 2; j<= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]];
eval_labels.at(j, cf->gates[4*i+2]) = eval_labels.at(j, cf->gates[4*i]) ^ eval_labels.at(j, cf->gates[4*i+1]);
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]];
} else if (cf->gates[4*i+3] == AND_GATE) {
int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]];
block H[nP+1];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j];
mask_input[cf->gates[4*i+2]] = GTv[ands][index];
eval_labels.at(j, cf->gates[4*i+2]) = GTM.at(ands, index, j);
mask_input[cf->gates[4*i+2]] = GTv.at(ands, index);
for(int j = 2; j <= nP; ++j) {
Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index);
xorBlocks_arr(H, H, GT[ands][j][index], nP+1);
Hash(H, eval_labels.at(j, cf->gates[4*i]), eval_labels.at(j, cf->gates[4*i+1]), ands, index);
xorBlocks_arr(H, H, &GT.at(ands, j, index, 0), nP+1);
for(int k = 2; k <= nP; ++k)
eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]];
eval_labels.at(k, cf->gates[4*i+2]) = H[k] ^ eval_labels.at(k, cf->gates[4*i+2]);
block t0 = GTK[ands][index][j] ^ Delta;
block t0 = GTK.at(ands, index, j) ^ Delta;
if(cmpBlock(&H[1], &GTK[ands][index][j], 1))
if(cmpBlock(&H[1], &GTK.at(ands, index, j), 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false;
else if(cmpBlock(&H[1], &t0, 1))
mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true;
@@ -622,12 +381,12 @@ class CMPC { public:
} else {
mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]];
for(int j = 2; j <= nP; ++j)
eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]];
eval_labels.at(j, cf->gates[4*i+2]) = eval_labels.at(j, cf->gates[4*i]);
}
}
}
output->associate_cmpc(value, mac, key, eval_labels, labels, io, Delta);
output->associate_cmpc(&value[0], mac, key, eval_labels, labels, io, Delta);
output->output(mask_input, cf->num_wire - cf->n3);
delete[] mask_input;

View File

@@ -1,23 +1,24 @@
#ifndef NVECTOR_H
#define NVECTOR_H
#include <vector>
#include <stdexcept>
#include <initializer_list>
#include <numeric>
#include <cstddef>
#include <utility>
#include "vec.h"
// N-dimensional vector class
template <typename T>
class NVector {
class NVec {
public:
// Default constructor
NVector() : total_size(0) {}
NVec() : total_size(0) {}
// Constructor taking sizes of each dimension
template <typename... Dims>
explicit NVector(Dims... dims) {
explicit NVec(Dims... dims) {
resize(dims...);
}
@@ -61,7 +62,7 @@ public:
private:
std::vector<size_t> dimensions; // Sizes of each dimension
size_t total_size; // Total size of the data
std::vector<T> data; // Linear storage for the elements
Vec<T> data; // Linear storage for the elements
// Compute the flat index from multi-dimensional indices
size_t compute_flat_index(const std::vector<size_t>& indices) const {

150
src/cpp/emp-agmpc/vec.h Normal file
View File

@@ -0,0 +1,150 @@
#ifndef VEC_H
#define VEC_H
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <memory>
#include <utility>
// Like std::vector but without <bool> specialization.
// This is important because std::vector<bool> is a bitset which breaks assumptions
// made by code designed for bool* arrays.
template <typename T>
class Vec {
private:
T* data; // Raw pointer to the array
size_t capacity; // Allocated capacity
size_t size_; // Current size
void grow(size_t new_capacity = 0) {
if (new_capacity == 0) {
new_capacity = capacity == 0 ? 1 : capacity * 2;
}
T* new_data = new T[new_capacity];
for (size_t i = 0; i < size_; ++i) {
new_data[i] = std::move(data[i]);
}
delete[] data;
data = new_data;
capacity = new_capacity;
}
public:
// Constructors and destructor
Vec() : data(nullptr), capacity(0), size_(0) {}
explicit Vec(size_t n, const T& value = T())
: data(new T[n]), capacity(n), size_(n) {
for (size_t i = 0; i < n; ++i) {
data[i] = value;
}
}
~Vec() { delete[] data; }
// Copy constructor
Vec(const Vec& other)
: data(new T[other.capacity]), capacity(other.capacity), size_(other.size_) {
for (size_t i = 0; i < size_; ++i) {
data[i] = other.data[i];
}
}
// Move constructor
Vec(Vec&& other) noexcept
: data(other.data), capacity(other.capacity), size_(other.size_) {
other.data = nullptr;
other.capacity = 0;
other.size_ = 0;
}
// Copy assignment
Vec& operator=(const Vec& other) {
if (this != &other) {
delete[] data;
data = new T[other.capacity];
capacity = other.capacity;
size_ = other.size_;
for (size_t i = 0; i < size_; ++i) {
data[i] = other.data[i];
}
}
return *this;
}
// Move assignment
Vec& operator=(Vec&& other) noexcept {
if (this != &other) {
delete[] data;
data = other.data;
capacity = other.capacity;
size_ = other.size_;
other.data = nullptr;
other.capacity = 0;
other.size_ = 0;
}
return *this;
}
// Accessors
T& operator[](size_t index) {
assert(index < size_);
return data[index];
}
const T& operator[](size_t index) const {
assert(index < size_);
return data[index];
}
T& at(size_t index) {
if (index >= size_) {
throw std::out_of_range("Index out of range");
}
return data[index];
}
const T& at(size_t index) const {
if (index >= size_) {
throw std::out_of_range("Index out of range");
}
return data[index];
}
// Member functions
void push_back(const T& value) {
if (size_ == capacity) {
grow();
}
data[size_++] = value;
}
void pop_back() {
assert(size_ > 0);
--size_;
}
void resize(size_t new_size, const T& value = T()) {
if (new_size > capacity) {
grow(new_size);
}
if (new_size > size_) {
for (size_t i = size_; i < new_size; ++i) {
data[i] = value;
}
}
size_ = new_size;
}
size_t size() const { return size_; }
size_t get_capacity() const { return capacity; }
bool empty() const { return size_ == 0; }
void clear() {
delete[] data;
data = nullptr;
capacity = 0;
size_ = 0;
}
};
#endif // VEC_H