From b00d157a6007bce64c6f1f589b606f36f46e2ef5 Mon Sep 17 00:00:00 2001 From: Andrew Morris Date: Tue, 28 Jan 2025 17:07:50 +1100 Subject: [PATCH] Start using Vec, NVec to avoid nP template --- programs/test_mpc.cpp | 2 +- src/cpp/emp-agmpc/abitmp.h | 37 +- src/cpp/emp-agmpc/emp-agmpc.h | 3 +- src/cpp/emp-agmpc/flexible_input_output.h | 16 +- src/cpp/emp-agmpc/fpremp.h | 160 +++---- src/cpp/emp-agmpc/helper.h | 7 +- src/cpp/emp-agmpc/mpc.h | 529 ++++++---------------- src/cpp/emp-agmpc/{nvector.h => nvec.h} | 11 +- src/cpp/emp-agmpc/vec.h | 150 ++++++ 9 files changed, 401 insertions(+), 514 deletions(-) rename src/cpp/emp-agmpc/{nvector.h => nvec.h} (93%) create mode 100644 src/cpp/emp-agmpc/vec.h diff --git a/programs/test_mpc.cpp b/programs/test_mpc.cpp index e89a1b6..4b23bba 100644 --- a/programs/test_mpc.cpp +++ b/programs/test_mpc.cpp @@ -16,7 +16,7 @@ int main(int argc, char** argv) { NetIOMP *ios[2] = {&io, &io2}; BristolFormat cf(circuit_file_location.c_str()); - CMPC* mpc = new CMPC(ios, party, &cf); + CMPC* mpc = new CMPC(nP, ios, party, &cf); cout <<"Setup:\t"<function_independent(); diff --git a/src/cpp/emp-agmpc/abitmp.h b/src/cpp/emp-agmpc/abitmp.h index 69aae6c..205a727 100644 --- a/src/cpp/emp-agmpc/abitmp.h +++ b/src/cpp/emp-agmpc/abitmp.h @@ -4,6 +4,7 @@ #include #include "netmp.h" #include "helper.h" +#include "nvec.h" template class ABitMP { public: @@ -64,21 +65,21 @@ class ABitMP { public: delete abit2[i]; } } - void compute(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) { + void compute(NVec& MAC, NVec& KEY, bool* data, int length) { for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; if (party < party2) { - abit2[party2]->recv_cot(MAC[party2], data, length); + abit2[party2]->recv_cot(&MAC.at(party2, 0), data, length); io->flush(party2); - abit1[party2]->send_cot(KEY[party2], length); + abit1[party2]->send_cot(&KEY.at(party2, 0), length); io->flush(party2); } else { - abit1[party2]->send_cot(KEY[party2], length); + abit1[party2]->send_cot(&KEY.at(party2, 0), length); io->flush(party2); - abit2[party2]->recv_cot(MAC[party2], data, length); + abit2[party2]->recv_cot(&MAC.at(party2, 0), data, length); io->flush(party2); } } @@ -87,12 +88,12 @@ class ABitMP { public: #endif } - void check(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) { + void check(const NVec& MAC, const NVec& KEY, bool* data, int length) { check1(MAC, KEY, data, length); check2(MAC, KEY, data, length); } - void check1(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) { + void check1(const NVec& MAC, const NVec& KEY, bool* data, int length) { block seed = sampleRandom(io, &prg, party); PRG prg2(&seed); uint8_t * tmp; @@ -141,13 +142,13 @@ class ABitMP { public: for(int tt = 0; tt < length/SIZE; tt++) { int start = SIZE*tt; for(int i = SIZE*tt; i < SIZE*(tt+1) and i < length; i+=chk) { - tMAC[(i-start)/chk][1] = MAC[k][i]; - tMAC[(i-start)/chk][2] = MAC[k][i+1]; - tMAC[(i-start)/chk][3] = MAC[k][i] ^ MAC[k][i+1]; + tMAC[(i-start)/chk][1] = MAC.at(k, i); + tMAC[(i-start)/chk][2] = MAC.at(k, i+1); + tMAC[(i-start)/chk][3] = MAC.at(k, i) ^ MAC.at(k, i+1); - tKEY[(i-start)/chk][1] = KEY[k][i]; - tKEY[(i-start)/chk][2] = KEY[k][i+1]; - tKEY[(i-start)/chk][3] = KEY[k][i] ^ KEY[k][i+1]; + tKEY[(i-start)/chk][1] = KEY.at(k, i); + tKEY[(i-start)/chk][2] = KEY.at(k, i+1); + tKEY[(i-start)/chk][3] = KEY.at(k, i) ^ KEY.at(k, i+1); for(int j = 0; j < ssp; ++j) { Ms[k][j] = Ms[k][j] ^ tMAC[(i-start)/chk][*tmpptr]; Ks[k][j] = Ks[k][j] ^ tKEY[(i-start)/chk][*tmpptr]; @@ -187,7 +188,7 @@ class ABitMP { public: } } - void check2(block * MAC[nP+1], block * KEY[nP+1], bool* data, int length) { + void check2(const NVec& MAC, const NVec KEY, bool* data, int length) { //last 2*ssp are garbage already. block * Ks[2], *Ms[nP+1][nP+1]; block * KK[nP+1]; @@ -208,7 +209,7 @@ class ABitMP { public: for(int i = 0; i < ssp; ++i) { Ks[0][i] = zero_block; for(int j = 1; j <= nP; ++j) if(j != party) - Ks[0][i] = Ks[0][i] ^ KEY[j][length-3*ssp+i]; + Ks[0][i] = Ks[0][i] ^ KEY.at(j, length-3*ssp+i); Ks[1][i] = Ks[0][i] ^ Delta; Hash::hash_once(dgst0[party*ssp+i], &Ks[0][i], sizeof(block)); @@ -217,7 +218,7 @@ class ABitMP { public: Hash h; h.put(data+length-3*ssp, ssp); for(int j = 1; j <= nP; ++j) if(j != party) { - h.put(&MAC[j][length-3*ssp], ssp*sizeof(block)); + h.put(&MAC.at(j, length-3*ssp), ssp*sizeof(block)); } h.digest(dgst[party]); @@ -233,14 +234,14 @@ class ABitMP { public: vector res2; for(int k = 1; k <= nP; ++k) if(k!= party) - memcpy(Ms[party][k], MAC[k]+length-3*ssp, sizeof(block)*ssp); + memcpy(Ms[party][k], &MAC.at(k, length-3*ssp), sizeof(block)*ssp); for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; io->send_data(party2, data + length - 3*ssp, ssp); for(int k = 1; k <= nP; ++k) if(k != party) - io->send_data(party2, MAC[k] + length - 3*ssp, sizeof(block)*ssp); + io->send_data(party2, &MAC.at(k, length - 3*ssp), sizeof(block)*ssp); res2.push_back(false); Hash h; diff --git a/src/cpp/emp-agmpc/emp-agmpc.h b/src/cpp/emp-agmpc/emp-agmpc.h index fa2ef0b..1da7fda 100644 --- a/src/cpp/emp-agmpc/emp-agmpc.h +++ b/src/cpp/emp-agmpc/emp-agmpc.h @@ -6,5 +6,6 @@ #include "emp-agmpc/helper.h" #include "emp-agmpc/mpc.h" #include "emp-agmpc/netmp.h" -#include "emp-agmpc/nvector.h" +#include "emp-agmpc/vec.h" +#include "emp-agmpc/nvec.h" #endif// EMP_AGMPC_H diff --git a/src/cpp/emp-agmpc/flexible_input_output.h b/src/cpp/emp-agmpc/flexible_input_output.h index e6e4d41..a1c84bf 100644 --- a/src/cpp/emp-agmpc/flexible_input_output.h +++ b/src/cpp/emp-agmpc/flexible_input_output.h @@ -58,12 +58,12 @@ public: authenticated_share_assignment.clear(); } - void associate_cmpc(bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], NetIOMP *associated_io, block associated_Delta) { + void associate_cmpc(bool *associated_value, NVec& associated_mac, NVec& associated_key, NetIOMP *associated_io, block associated_Delta) { this->cmpc_associated = true; this->value = associated_value; for(int j = 1; j <= nP; j++) { - this->mac[j] = associated_mac[j]; - this->key[j] = associated_key[j]; + this->mac[j] = &associated_mac.at(j, 0); + this->key[j] = &associated_key.at(j, 0); } this->io = associated_io; this->Delta = associated_Delta; @@ -539,17 +539,17 @@ public: authenticated_share_results.clear(); } - void associate_cmpc(bool *associated_value, block *associated_mac[nP + 1], block *associated_key[nP + 1], block *associated_eval_labels[nP + 1], block *associated_labels, NetIOMP *associated_io, block associated_Delta) { + void associate_cmpc(bool *associated_value, NVec& associated_mac, NVec& associated_key, NVec& associated_eval_labels, Vec& associated_labels, NetIOMP *associated_io, block associated_Delta) { this->cmpc_associated = true; this->value = associated_value; - this->labels = associated_labels; + this->labels = &associated_labels.at(0); for (int j = 1; j <= nP; j++) { - this->mac[j] = associated_mac[j]; - this->key[j] = associated_key[j]; + this->mac[j] = &associated_mac.at(j, 0); + this->key[j] = &associated_key.at(j, 0); } if (party == ALICE){ for (int j = 2; j <= nP; j++) { - this->eval_labels[j] = associated_eval_labels[j]; + this->eval_labels[j] = &associated_eval_labels.at(j, 0); } } this->io = associated_io; diff --git a/src/cpp/emp-agmpc/fpremp.h b/src/cpp/emp-agmpc/fpremp.h index b8b4ad2..c2f44f0 100644 --- a/src/cpp/emp-agmpc/fpremp.h +++ b/src/cpp/emp-agmpc/fpremp.h @@ -6,6 +6,7 @@ #include "abitmp.h" #include "netmp.h" #include "cmpc_config.h" +#include "nvec.h" using namespace emp; template @@ -44,87 +45,77 @@ class FpreMP { public: return 4; else return 5; } - void compute(block * MAC[nP+1], block * KEY[nP+1], bool * r, int length) { + void compute(NVec& MAC, NVec& KEY, bool* r, int length) { int64_t bucket_size = get_bucket_size(length); - block * tMAC[nP+1]; - block * tKEY[nP+1]; - block * tKEYphi[nP+1]; - block * tMACphi[nP+1]; - block * phi; - block *X [nP+1]; - bool *tr = new bool[length*bucket_size*3+3*ssp]; - phi = new block[length*bucket_size]; - bool *s[nP+1], *e = new bool[length*bucket_size]; - for(int i = 1; i <= nP; ++i) { - tMAC[i] = new block[length*bucket_size*3+3*ssp]; - tKEY[i] = new block[length*bucket_size*3+3*ssp]; - tKEYphi[i] = new block[length*bucket_size+3*ssp]; - tMACphi[i] = new block[length*bucket_size+3*ssp]; - X[i] = new block[ssp]; - } - for(int i = 0; i <= nP; ++i) { - s[i] = new bool[length*bucket_size]; - memset(s[i], 0, length*bucket_size); - } - prg.random_bool(tr, length*bucket_size*3+3*ssp); + NVec tMAC(nP+1, length*bucket_size*3+3*ssp); + NVec tKEY(nP+1, length*bucket_size*3+3*ssp); + NVec tKEYphi(nP+1, length*bucket_size*3+3*ssp); + NVec tMACphi(nP+1, length*bucket_size*3+3*ssp); + Vec phi(length*bucket_size); + NVec X(nP+1, ssp); + Vec tr(length*bucket_size*3+3*ssp); + NVec s(nP+1, length*bucket_size); + Vec e(length*bucket_size); + + prg.random_bool(&tr[0], length*bucket_size*3+3*ssp); // memset(tr, false, length*bucket_size*3+3*ssp); - abit->compute(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp); + abit->compute(tMAC, tKEY, &tr[0], length*bucket_size*3 + 3*ssp); for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j ) { if(i == party) { - prgs[j].random_bool(s[j], length*bucket_size); + prgs[j].random_bool(&s.at(j, 0), length*bucket_size); for(int k = 0; k < length*bucket_size; ++k) { - uint8_t data = garble(tKEY[j], tr, s[j], k, j); + uint8_t data = garble(&tKEY.at(j, 0), &tr[0], &s.at(j, 0), k, j); io->send_data(j, &data, 1); - s[j][k] = (s[j][k] != (tr[3*k] and tr[3*k+1])); + s.at(j, k) = (s.at(j, k) != (tr[3*k] and tr[3*k+1])); } io->flush(j); } else if (j == party) { for(int k = 0; k < length*bucket_size; ++k) { uint8_t data = 0; io->recv_data(i, &data, 1); - bool tmp = evaluate(data, tMAC[i], tr, k, i); - s[i][k] = (tmp != (tr[3*k] and tr[3*k+1])); + bool tmp = evaluate(data, &tMAC.at(i, 0), &tr[0], k, i); + s.at(i, k) = (tmp != (tr[3*k] and tr[3*k+1])); } } } for(int k = 0; k < length*bucket_size; ++k) { - s[0][k] = (tr[3*k] and tr[3*k+1]); + s.at(0, k) = (tr[3*k] and tr[3*k+1]); for(int i = 1; i <= nP; ++i) if (i != party) { - s[0][k] = (s[0][k] != s[i][k]); + s.at(0, k) = (s.at(0, k) != s.at(i, k)); } - e[k] = (s[0][k] != tr[3*k+2]); - tr[3*k+2] = s[0][k]; + e[k] = (s.at(0, k) != tr[3*k+2]); + tr[3*k+2] = s.at(0, k); } #ifdef __debug - check_correctness(io, tr, length*bucket_size, party); + check_correctness(io, &tr[0], length*bucket_size, party); #endif for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; - io->send_data(party2, e, length*bucket_size); + io->send_data(party2, &e[0], length*bucket_size); io->flush(party2); bool * tmp = new bool[length*bucket_size]; io->recv_data(party2, tmp, length*bucket_size); for(int k = 0; k < length*bucket_size; ++k) { if(tmp[k]) - tKEY[party2][3*k+2] = tKEY[party2][3*k+2] ^ Delta; + tKEY.at(party2, 3*k+2) = tKEY.at(party2, 3*k+2) ^ Delta; } delete[] tmp; } #ifdef __debug check_MAC(io, tMAC, tKEY, tr, Delta, length*bucket_size*3, party); #endif - abit->check(tMAC, tKEY, tr, length*bucket_size*3 + 3*ssp); + abit->check(tMAC, tKEY, &tr[0], length*bucket_size*3 + 3*ssp); //check compute phi for(int k = 0; k < length*bucket_size; ++k) { phi[k] = zero_block; for(int i = 1; i <= nP; ++i) if (i != party) { - phi[k] = phi[k] ^ tKEY[i][3*k+1]; - phi[k] = phi[k] ^ tMAC[i][3*k+1]; + phi[k] = phi[k] ^ tKEY.at(i, 3*k+1); + phi[k] = phi[k] ^ tMAC.at(i, 3*k+1); } if(tr[3*k+1])phi[k] = phi[k] ^ Delta; } @@ -136,10 +127,10 @@ class FpreMP { public: { block bH[2], tmpH[2]; for(int k = 0; k < length*bucket_size; ++k) { - bH[0] = tKEY[party2][3*k]; + bH[0] = tKEY.at(party2, 3*k); bH[1] = bH[0] ^ Delta; HnID(prps+party2, bH, bH, 2*k, 2, tmpH); - tKEYphi[party2][k] = bH[0]; + tKEYphi.at(party2, k) = bH[0]; bH[1] = bH[0] ^ bH[1]; bH[1] = phi[k] ^ bH[1]; io->send_data(party2, &bH[1], sizeof(block)); @@ -151,9 +142,9 @@ class FpreMP { public: block bH; for(int k = 0; k < length*bucket_size; ++k) { io->recv_data(party2, &bH, sizeof(block)); - block hin = sigma(tMAC[party2][3*k]) ^ makeBlock(0, 2*k+tr[3*k]); - tMACphi[party2][k] = prps2[party2].H(hin); - if(tr[3*k])tMACphi[party2][k] = tMACphi[party2][k] ^ bH; + block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]); + tMACphi.at(party2, k) = prps2[party2].H(hin); + if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH; } } } else { @@ -161,19 +152,19 @@ class FpreMP { public: block bH; for(int k = 0; k < length*bucket_size; ++k) { io->recv_data(party2, &bH, sizeof(block)); - block hin = sigma(tMAC[party2][3*k]) ^ makeBlock(0, 2*k+tr[3*k]); - tMACphi[party2][k] = prps2[party2].H(hin); - if(tr[3*k])tMACphi[party2][k] = tMACphi[party2][k] ^ bH; + block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]); + tMACphi.at(party2, k) = prps2[party2].H(hin); + if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH; } } { block bH[2], tmpH[2]; for(int k = 0; k < length*bucket_size; ++k) { - bH[0] = tKEY[party2][3*k]; + bH[0] = tKEY.at(party2, 3*k); bH[1] = bH[0] ^ Delta; HnID(prps+party2, bH, bH, 2*k, 2, tmpH); - tKEYphi[party2][k] = bH[0]; + tKEYphi.at(party2, k) = bH[0]; bH[1] = bH[0] ^ bH[1]; bH[1] = phi[k] ^ bH[1]; io->send_data(party2, &bH[1], sizeof(block)); @@ -191,19 +182,19 @@ class FpreMP { public: #endif //tKEYphti use as H for(int k = 0; k < length*bucket_size; ++k) { - tKEYphi[party][k] = zero_block; + tKEYphi.at(party, k) = zero_block; for(int i = 1; i <= nP; ++i) if (i != party) { - tKEYphi[party][k] = tKEYphi[party][k] ^ tKEYphi[i][k]; - tKEYphi[party][k] = tKEYphi[party][k] ^ tMACphi[i][k]; - tKEYphi[party][k] = tKEYphi[party][k] ^ tKEY[i][3*k+2]; - tKEYphi[party][k] = tKEYphi[party][k] ^ tMAC[i][3*k+2]; + tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tKEYphi.at(i, k); + tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tMACphi.at(i, k); + tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tKEY.at(i, 3*k+2); + tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ tMAC.at(i, 3*k+2); } - if(tr[3*k]) tKEYphi[party][k] = tKEYphi[party][k] ^ phi[k]; - if(tr[3*k+2])tKEYphi[party][k] = tKEYphi[party][k] ^ Delta; + if(tr[3*k]) tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ phi[k]; + if(tr[3*k+2])tKEYphi.at(party, k) = tKEYphi.at(party, k) ^ Delta; } #ifdef __debug - check_zero(tKEYphi[party], length*bucket_size); + check_zero(&tKEYphi.at(party, 0), length*bucket_size); #endif block prg_key = sampleRandom(io, &prg, party); @@ -212,9 +203,9 @@ class FpreMP { public: bool * tmp = new bool[length*bucket_size]; for(int i = 0; i < ssp; ++i) { prgf.random_bool(tmp, length*bucket_size); - X[party][i] = inProd(tmp, tKEYphi[party], length*bucket_size); + X.at(party, i) = inProd(tmp, &tKEYphi.at(party, 0), length*bucket_size); } - Hash::hash_once(dgst[party], X[party], sizeof(block)*ssp); + Hash::hash_once(dgst[party], &X.at(party, 0), sizeof(block)*ssp); for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; @@ -226,18 +217,18 @@ class FpreMP { public: for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; - io->send_data(party2, X[party], sizeof(block)*ssp); - io->recv_data(party2, X[party2], sizeof(block)*ssp); + io->send_data(party2, &X.at(party, 0), sizeof(block)*ssp); + io->recv_data(party2, &X.at(party2, 0), sizeof(block)*ssp); char tmp[Hash::DIGEST_SIZE]; - Hash::hash_once(tmp, X[party2], sizeof(block)*ssp); + Hash::hash_once(tmp, &X.at(party2, 0), sizeof(block)*ssp); res2.push_back(strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0); } if(checkCheat(res2)) error("commitment"); for(int i = 2; i <= nP; ++i) - xorBlocks_arr(X[1], X[1], X[i], ssp); - for(int i = 0; i < ssp; ++i)X[2][i] = zero_block; - if(!cmpBlock(X[1], X[2], ssp)) error("AND check"); + xorBlocks_arr(&X.at(1, 0), &X.at(1, 0), &X.at(i, 0), ssp); + for(int i = 0; i < ssp; ++i)X.at(2, i) = zero_block; + if(!cmpBlock(&X.at(1, 0), &X.at(2, 0), ssp)) error("AND check"); //land -> and block S = sampleRandom(io, &prg, party); @@ -264,17 +255,17 @@ class FpreMP { public: for(int j = 0; j < bucket_size-1; ++j) d[party][(bucket_size-1)*i+j] = tr[3*location[i*bucket_size]+1] != tr[3*location[i*bucket_size+1+j]+1]; for(int j = 1; j <= nP; ++j) if (j!= party) { - memcpy(MAC[j]+3*i, tMAC[j]+3*location[i*bucket_size], 3*sizeof(block)); - memcpy(KEY[j]+3*i, tKEY[j]+3*location[i*bucket_size], 3*sizeof(block)); + memcpy(&MAC.at(j, 3*i), &tMAC.at(j, 3*location[i*bucket_size]), 3*sizeof(block)); + memcpy(&KEY.at(j, 3*i), &tKEY.at(j, 3*location[i*bucket_size]), 3*sizeof(block)); for(int k = 1; k < bucket_size; ++k) { - MAC[j][3*i] = MAC[j][3*i] ^ tMAC[j][3*location[i*bucket_size+k]]; - KEY[j][3*i] = KEY[j][3*i] ^ tKEY[j][3*location[i*bucket_size+k]]; + MAC.at(j, 3*i) = MAC.at(j, 3*i) ^ tMAC.at(j, 3*location[i*bucket_size+k]); + KEY.at(j, 3*i) = KEY.at(j, 3*i) ^ tKEY.at(j, 3*location[i*bucket_size+k]); - MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]+2]; - KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]+2]; + MAC.at(j, 3*i+2) = MAC.at(j, 3*i+2) ^ tMAC.at(j, 3*location[i*bucket_size+k]+2); + KEY.at(j, 3*i+2) = KEY.at(j, 3*i+2) ^ tKEY.at(j, 3*location[i*bucket_size+k]+2); } } - memcpy(r+3*i, tr+3*location[i*bucket_size], 3); + memcpy(&r[3*i], &tr[3*location[i*bucket_size]], 3); for(int k = 1; k < bucket_size; ++k) { r[3*i] = r[3*i] != tr[3*location[i*bucket_size+k]]; r[3*i+2] = r[3*i+2] != tr[3*location[i*bucket_size+k]+2]; @@ -295,8 +286,8 @@ class FpreMP { public: for(int j = 1; j <= nP; ++j)if (j!= party) { for(int k = 1; k < bucket_size; ++k) if(d[1][(bucket_size-1)*i+k-1]) { - MAC[j][3*i+2] = MAC[j][3*i+2] ^ tMAC[j][3*location[i*bucket_size+k]]; - KEY[j][3*i+2] = KEY[j][3*i+2] ^ tKEY[j][3*location[i*bucket_size+k]]; + MAC.at(j, 3*i+2) = MAC.at(j, 3*i+2) ^ tMAC.at(j, 3*location[i*bucket_size+k]); + KEY.at(j, 3*i+2) = KEY.at(j, 3*i+2) ^ tKEY.at(j, 3*location[i*bucket_size+k]); } } for(int k = 1; k < bucket_size; ++k) @@ -311,23 +302,6 @@ class FpreMP { public: #endif // ret.get(); - delete[] tr; - delete[] phi; - delete[] e; - delete[] dgst; - delete[] tmp; - delete[] location; - delete[] xs; - for(int i = 1; i <= nP; ++i) { - delete[] tMAC[i]; - delete[] tKEY[i]; - delete[] tMACphi[i]; - delete[] tKEYphi[i]; - delete[] X[i]; - delete[] s[i]; - delete[] d[i]; - } - delete[] s[0]; } //TODO: change to justGarble @@ -368,13 +342,13 @@ class FpreMP { public: return (tmp&0x1) != (res&0x1); } - void check_MAC_phi(block * MAC[nP+1], block * KEY[nP+1], block * phi, bool * r, int length) { + void check_MAC_phi(const NVec& MAC, const NVec& KEY, block * phi, bool * r, int length) { block * tmp = new block[length]; block *tD = new block[length]; for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) { if(party == i) { io->send_data(j, phi, length*sizeof(block)); - io->send_data(j, KEY[j], sizeof(block)*length); + io->send_data(j, &KEY.at(j, 0), sizeof(block)*length); io->flush(j); } else if(party == j) { io->recv_data(i, tD, length*sizeof(block)); @@ -382,7 +356,7 @@ class FpreMP { public: for(int k = 0; k < length; ++k) { if(r[k])tmp[k] = tmp[k] ^ tD[k]; } - if(!cmpBlock(MAC[i], tmp, length)) + if(!cmpBlock(&MAC.at(i, 0), tmp, length)) error("check_MAC_phi failed!"); } } diff --git a/src/cpp/emp-agmpc/helper.h b/src/cpp/emp-agmpc/helper.h index 7f55dc2..7011415 100644 --- a/src/cpp/emp-agmpc/helper.h +++ b/src/cpp/emp-agmpc/helper.h @@ -3,6 +3,7 @@ #include #include "cmpc_config.h" #include "netmp.h" +#include "nvec.h" #include using namespace emp; using std::future; @@ -117,13 +118,13 @@ block sampleRandom(NetIOMP * io, PRG * prg, int party) { } template -void check_MAC(NetIOMP * io, block * MAC[nP+1], block * KEY[nP+1], bool * r, block Delta, int length, int party) { +void check_MAC(NetIOMP * io, const NVec& MAC, const NVec& KEY, bool * r, block Delta, int length, int party) { block * tmp = new block[length]; block tD; for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) { if(party == i) { io->send_data(j, &Delta, sizeof(block)); - io->send_data(j, KEY[j], sizeof(block)*length); + io->send_data(j, &KEY.at(j, 0), sizeof(block)*length); io->flush(j); } else if(party == j) { io->recv_data(i, &tD, sizeof(block)); @@ -131,7 +132,7 @@ void check_MAC(NetIOMP * io, block * MAC[nP+1], block * KEY[nP+1], bool * r, for(int k = 0; k < length; ++k) { if(r[k])tmp[k] = tmp[k] ^ tD; } - if(!cmpBlock(MAC[i], tmp, length)) + if(!cmpBlock(&MAC.at(i, 0), tmp, length)) error("check_MAC failed!"); } } diff --git a/src/cpp/emp-agmpc/mpc.h b/src/cpp/emp-agmpc/mpc.h index 11e54f4..6e0ccb1 100644 --- a/src/cpp/emp-agmpc/mpc.h +++ b/src/cpp/emp-agmpc/mpc.h @@ -3,45 +3,59 @@ #include "fpremp.h" #include "abitmp.h" #include "netmp.h" +#include "vec.h" +#include "nvec.h" #include "flexible_input_output.h" #include using namespace emp; -template +template class CMPC { public: const static int SSP = 5;//5*8 in fact... const block MASK = makeBlock(0x0ULL, 0xFFFFFULL); - FpreMP* fpre = nullptr; - block* mac[nP+1]; - block* key[nP+1]; - bool* value; + FpreMP* fpre = nullptr; - block * preprocess_mac[nP+1]; - block * preprocess_key[nP+1]; - bool* preprocess_value; + int nP; - block * sigma_mac[nP+1]; - block * sigma_key[nP+1]; - bool * sigma_value; + NVec mac; // dim: parties, wires + NVec key; // dim: parties, wires + Vec value; // dim: wires - block * ANDS_mac[nP+1]; - block * ANDS_key[nP+1]; - bool * ANDS_value; + NVec preprocess_mac; // dim: parties, total_pre + NVec preprocess_key; // dim: parties, total_pre + Vec preprocess_value; // dim: total_pre - block * labels; + NVec sigma_mac; // dim: parties, num_ands + NVec sigma_key; // dim: parties, num_ands + Vec sigma_value; // dim: num_ands + + NVec ANDS_mac; // dim: parties, num_ands*3 + NVec ANDS_key; // dim: parties, num_ands*3 + Vec ANDS_value; // dim: num_ands*3 + + Vec labels; // dim: wires BristolFormat * cf; - NetIOMP * io; + NetIOMP * io; int num_ands = 0, num_in; int party, total_pre, ssp; block Delta; - block (*GTM)[4][nP+1]; - block (*GTK)[4][nP+1]; - bool (*GTv)[4]; - block (*GT)[nP+1][4][nP+1]; - block * eval_labels[nP+1]; + NVec GTM; // dim: num_ands, 4, parties + NVec GTK; // dim: num_ands, 4, parties + NVec GTv; // dim: num_ands, 4 + NVec GT; // dim: num_ands, parties, 4, parties + NVec eval_labels; // dim: parties, wires PRP prp; - CMPC(NetIOMP * io[2], int party, BristolFormat * cf, bool * _delta = nullptr, int ssp = 40) { + + CMPC( + int nP, + NetIOMP * io[2], + int party, + BristolFormat * cf, + bool * _delta = nullptr, + int ssp = 40 + ) { + this->nP = nP; this->party = party; this->io = io[0]; this->cf = cf; @@ -53,96 +67,69 @@ class CMPC { public: } num_in = cf->n1+cf->n2; total_pre = num_in + num_ands + 3*ssp; - fpre = new FpreMP(io, party, _delta, ssp); + fpre = new FpreMP(io, party, _delta, ssp); Delta = fpre->Delta; if(party == 1) { - GTM = new block[num_ands][4][nP+1]; - GTK = new block[num_ands][4][nP+1]; - GTv = new bool[num_ands][4]; - GT = new block[num_ands][nP+1][4][nP+1]; + GTM.resize(num_ands, 4, nP+1); + GTK.resize(num_ands, 4, nP+1); + GTv.resize(num_ands, 4); + GT.resize(num_ands, nP+1, 4, nP+1); } - labels = new block[cf->num_wire]; - for(int i = 1; i <= nP; ++i) { - key[i] = new block[cf->num_wire]; - mac[i] = new block[cf->num_wire]; - ANDS_key[i] = new block[num_ands*3]; - ANDS_mac[i] = new block[num_ands*3]; - preprocess_mac[i] = new block[total_pre]; - preprocess_key[i] = new block[total_pre]; - sigma_mac[i] = new block[num_ands]; - sigma_key[i] = new block[num_ands]; - eval_labels[i] = new block[cf->num_wire]; - } - value = new bool[cf->num_wire]; - ANDS_value = new bool[num_ands*3]; - preprocess_value = new bool[total_pre]; - sigma_value = new bool[num_ands]; + labels.resize(cf->num_wire); + key.resize(nP+1, cf->num_wire); + mac.resize(nP+1, cf->num_wire); + ANDS_key.resize(nP+1, num_ands*3); + ANDS_mac.resize(nP+1, num_ands*3); + preprocess_mac.resize(nP+1, total_pre); + preprocess_key.resize(nP+1, total_pre); + sigma_mac.resize(nP+1, num_ands); + sigma_key.resize(nP+1, num_ands); + eval_labels.resize(nP+1, cf->num_wire); + + value.resize(cf->num_wire); + ANDS_value.resize(num_ands*3); + preprocess_value.resize(total_pre); + sigma_value.resize(num_ands); } ~CMPC() { delete fpre; - if(party == 1) { - delete[] GTM; - delete[] GTK; - delete[] GTv; - delete[] GT; - } - delete[] labels; - for(int i = 1; i <= nP; ++i) { - delete[] key[i]; - delete[] mac[i]; - delete[] ANDS_key[i]; - delete[] ANDS_mac[i]; - delete[] preprocess_mac[i]; - delete[] preprocess_key[i]; - delete[] sigma_mac[i]; - delete[] sigma_key[i]; - delete[] eval_labels[i]; - } - delete[] value; - delete[] ANDS_value; - delete[] preprocess_value; - delete[] sigma_value; } PRG prg; void function_independent() { if(party != 1) - prg.random_block(labels, cf->num_wire); + prg.random_block(&labels[0], cf->num_wire); - fpre->compute(ANDS_mac, ANDS_key, ANDS_value, num_ands); + fpre->compute(ANDS_mac, ANDS_key, &ANDS_value[0], num_ands); - prg.random_bool(preprocess_value, total_pre); - fpre->abit->compute(preprocess_mac, preprocess_key, preprocess_value, total_pre); - fpre->abit->check(preprocess_mac, preprocess_key, preprocess_value, total_pre); + prg.random_bool(&preprocess_value[0], total_pre); + fpre->abit->compute(preprocess_mac, preprocess_key, &preprocess_value[0], total_pre); + fpre->abit->check(preprocess_mac, preprocess_key, &preprocess_value[0], total_pre); for(int i = 1; i <= nP; ++i) { - memcpy(key[i], preprocess_key[i], num_in * sizeof(block)); - memcpy(mac[i], preprocess_mac[i], num_in * sizeof(block)); + memcpy(&key.at(i, 0), &preprocess_key.at(i, 0), num_in * sizeof(block)); + memcpy(&mac.at(i, 0), &preprocess_mac.at(i, 0), num_in * sizeof(block)); } - memcpy(value, preprocess_value, num_in * sizeof(bool)); + memcpy(&value[0], &preprocess_value[0], num_in * sizeof(bool)); #ifdef __debug - check_MAC(io, ANDS_mac, ANDS_key, ANDS_value, Delta, num_ands*3, party); - check_correctness(io, ANDS_value, num_ands, party); + check_MAC(nP, io, ANDS_mac, ANDS_key, ANDS_value, Delta, num_ands*3, party); + check_correctness(nP, io, ANDS_value, num_ands, party); #endif // ret.get(); } void function_dependent() { int ands = num_in; - bool * x[nP+1]; - bool * y[nP+1]; - for(int i = 1; i <= nP; ++i) { - x[i] = new bool[num_ands]; - y[i] = new bool[num_ands]; - } + NVec x(nP+1, num_ands); + NVec y(nP+1, num_ands); for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == AND_GATE) { for(int j = 1; j <= nP; ++j) { - key[j][cf->gates[4*i+2]] = preprocess_key[j][ands]; - mac[j][cf->gates[4*i+2]] = preprocess_mac[j][ands]; + key.at(j, cf->gates[4*i+2]) = preprocess_key.at(j, ands); + mac.at(j, cf->gates[4*i+2]) = preprocess_mac.at(j, ands); } value[cf->gates[4*i+2]] = preprocess_value[ands]; ++ands; @@ -152,16 +139,16 @@ class CMPC { public: for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == XOR_GATE) { for(int j = 1; j <= nP; ++j) { - key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]] ^ key[j][cf->gates[4*i+1]]; - mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]] ^ mac[j][cf->gates[4*i+1]]; + key.at(j, cf->gates[4*i+2]) = key.at(j, cf->gates[4*i]) ^ key.at(j, cf->gates[4*i+1]); + mac.at(j, cf->gates[4*i+2]) = mac.at(j, cf->gates[4*i]) ^ mac.at(j, cf->gates[4*i+1]); } value[cf->gates[4*i+2]] = value[cf->gates[4*i]] != value[cf->gates[4*i+1]]; if(party != 1) labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]]; } else if (cf->gates[4*i+3] == NOT_GATE) { for(int j = 1; j <= nP; ++j) { - key[j][cf->gates[4*i+2]] = key[j][cf->gates[4*i]]; - mac[j][cf->gates[4*i+2]] = mac[j][cf->gates[4*i]]; + key.at(j, cf->gates[4*i+2]) = key.at(j, cf->gates[4*i]); + mac.at(j, cf->gates[4*i+2]) = mac.at(j, cf->gates[4*i]); } value[cf->gates[4*i+2]] = value[cf->gates[4*i]]; if(party != 1) @@ -170,14 +157,14 @@ class CMPC { public: } #ifdef __debug - check_MAC(io, mac, key, value, Delta, cf->num_wire, party); + check_MAC(nP, io, mac, key, value, Delta, cf->num_wire, party); #endif ands = 0; for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == AND_GATE) { - x[party][ands] = value[cf->gates[4*i]] != ANDS_value[3*ands]; - y[party][ands] = value[cf->gates[4*i+1]] != ANDS_value[3*ands+1]; + x.at(party, ands) = value[cf->gates[4*i]] != ANDS_value[3*ands]; + y.at(party, ands) = value[cf->gates[4*i+1]] != ANDS_value[3*ands+1]; ands++; } } @@ -185,44 +172,44 @@ class CMPC { public: for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if( (i < j) and (i == party or j == party) ) { int party2 = i + j - party; - io->send_data(party2, x[party], num_ands); - io->send_data(party2, y[party], num_ands); + io->send_data(party2, &x.at(party, 0), num_ands); + io->send_data(party2, &y.at(party, 0), num_ands); io->flush(party2); - io->recv_data(party2, x[party2], num_ands); - io->recv_data(party2, y[party2], num_ands); + io->recv_data(party2, &x.at(party2, 0), num_ands); + io->recv_data(party2, &y.at(party2, 0), num_ands); } for(int i = 2; i <= nP; ++i) for(int j = 0; j < num_ands; ++j) { - x[1][j] = x[1][j] != x[i][j]; - y[1][j] = y[1][j] != y[i][j]; + x.at(1, j) = x.at(1, j) != x.at(i, j); + y.at(1, j) = y.at(1, j) != y.at(i, j); } ands = 0; for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == AND_GATE) { for(int j = 1; j <= nP; ++j) { - sigma_mac[j][ands] = ANDS_mac[j][3*ands+2]; - sigma_key[j][ands] = ANDS_key[j][3*ands+2]; + sigma_mac.at(j, ands) = ANDS_mac.at(j, 3*ands+2); + sigma_key.at(j, ands) = ANDS_key.at(j, 3*ands+2); } sigma_value[ands] = ANDS_value[3*ands+2]; - if(x[1][ands]) { + if(x.at(1, ands)) { for(int j = 1; j <= nP; ++j) { - sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands+1]; - sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands+1]; + sigma_mac.at(j, ands) = sigma_mac.at(j, ands) ^ ANDS_mac.at(j, 3*ands+1); + sigma_key.at(j, ands) = sigma_key.at(j, ands) ^ ANDS_key.at(j, 3*ands+1); } sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands+1]; } - if(y[1][ands]) { + if(y.at(1, ands)) { for(int j = 1; j <= nP; ++j) { - sigma_mac[j][ands] = sigma_mac[j][ands] ^ ANDS_mac[j][3*ands]; - sigma_key[j][ands] = sigma_key[j][ands] ^ ANDS_key[j][3*ands]; + sigma_mac.at(j, ands) = sigma_mac.at(j, ands) ^ ANDS_mac.at(j, 3*ands); + sigma_key.at(j, ands) = sigma_key.at(j, ands) ^ ANDS_key.at(j, 3*ands); } sigma_value[ands] = sigma_value[ands] != ANDS_value[3*ands]; } - if(x[1][ands] and y[1][ands]) { + if(x.at(1, ands) and y.at(1, ands)) { if(party != 1) - sigma_key[1][ands] = sigma_key[1][ands] ^ Delta; + sigma_key.at(1, ands) = sigma_key.at(1, ands) ^ Delta; else sigma_value[ands] = not sigma_value[ands]; } @@ -230,7 +217,7 @@ class CMPC { public: } }//sigma_[] stores the and of input wires to each AND gates #ifdef __debug_ - check_MAC(io, sigma_mac, sigma_key, sigma_value, Delta, num_ands, party); + check_MAC(nP, io, sigma_mac, sigma_key, sigma_value, Delta, num_ands, party); ands = 0; for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == AND_GATE) { @@ -242,8 +229,9 @@ class CMPC { public: #endif ands = 0; - block H[4][nP+1]; - block K[4][nP+1], M[4][nP+1]; + NVec H(4, nP+1); + NVec K(4, nP+1); + NVec M(4, nP+1); bool r[4]; if(party != 1) { for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) { @@ -253,30 +241,30 @@ class CMPC { public: r[3] = r[1] != value[cf->gates[4*i+1]]; for(int j = 1; j <= nP; ++j) { - M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]]; - M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]]; - M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]]; - M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]]; + M.at(0, j) = sigma_mac.at(j, ands) ^ mac.at(j, cf->gates[4*i+2]); + M.at(1, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i]); + M.at(2, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i+1]); + M.at(3, j) = M.at(1, j) ^ mac.at(j, cf->gates[4*i+1]); - K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]]; - K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]]; - K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]]; - K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]]; + K.at(0, j) = sigma_key.at(j, ands) ^ key.at(j, cf->gates[4*i+2]); + K.at(1, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i]); + K.at(2, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i+1]); + K.at(3, j) = K.at(1, j) ^ key.at(j, cf->gates[4*i+1]); } - K[3][1] = K[3][1] ^ Delta; + K.at(3, 1) = K.at(3, 1) ^ Delta; Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], ands); for(int j = 0; j < 4; ++j) { for(int k = 1; k <= nP; ++k) if(k != party) { - H[j][k] = H[j][k] ^ M[j][k]; - H[j][party] = H[j][party] ^ K[j][k]; + H.at(j, k) = H.at(j, k) ^ M.at(j, k); + H.at(j, party) = H.at(j, party) ^ K.at(j, k); } - H[j][party] = H[j][party] ^ labels[cf->gates[4*i+2]]; + H.at(j, party) = H.at(j, party) ^ labels[cf->gates[4*i+2]]; if(r[j]) - H[j][party] = H[j][party] ^ Delta; + H.at(j, party) = H.at(j, party) ^ Delta; } for(int j = 0; j < 4; ++j) - io->send_data(1, H[j]+1, sizeof(block)*(nP)); + io->send_data(1, &H.at(j, 1), sizeof(block)*(nP)); ++ands; } io->flush(1); @@ -285,7 +273,7 @@ class CMPC { public: int party2 = i; for(int i = 0; i < num_ands; ++i) for(int j = 0; j < 4; ++j) - io->recv_data(party2, GT[i][party2][j]+1, sizeof(block)*(nP)); + io->recv_data(party2, >.at(i, party2, j, 1), sizeof(block)*(nP)); } for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) { r[0] = sigma_value[ands] != value[cf->gates[4*i+2]]; @@ -295,144 +283,43 @@ class CMPC { public: r[3] = r[3] != true; for(int j = 1; j <= nP; ++j) { - M[0][j] = sigma_mac[j][ands] ^ mac[j][cf->gates[4*i+2]]; - M[1][j] = M[0][j] ^ mac[j][cf->gates[4*i]]; - M[2][j] = M[0][j] ^ mac[j][cf->gates[4*i+1]]; - M[3][j] = M[1][j] ^ mac[j][cf->gates[4*i+1]]; + M.at(0, j) = sigma_mac.at(j, ands) ^ mac.at(j, cf->gates[4*i+2]); + M.at(1, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i]); + M.at(2, j) = M.at(0, j) ^ mac.at(j, cf->gates[4*i+1]); + M.at(3, j) = M.at(1, j) ^ mac.at(j, cf->gates[4*i+1]); - K[0][j] = sigma_key[j][ands] ^ key[j][cf->gates[4*i+2]]; - K[1][j] = K[0][j] ^ key[j][cf->gates[4*i]]; - K[2][j] = K[0][j] ^ key[j][cf->gates[4*i+1]]; - K[3][j] = K[1][j] ^ key[j][cf->gates[4*i+1]]; + K.at(0, j) = sigma_key.at(j, ands) ^ key.at(j, cf->gates[4*i+2]); + K.at(1, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i]); + K.at(2, j) = K.at(0, j) ^ key.at(j, cf->gates[4*i+1]); + K.at(3, j) = K.at(1, j) ^ key.at(j, cf->gates[4*i+1]); } - memcpy(GTK[ands], K, sizeof(block)*4*(nP+1)); - memcpy(GTM[ands], M, sizeof(block)*4*(nP+1)); - memcpy(GTv[ands], r, sizeof(bool)*4); + memcpy(>K.at(ands, 0, 0), &K.at(0, 0), sizeof(block)*4*(nP+1)); + memcpy(>M.at(ands, 0, 0), &M.at(0, 0), sizeof(block)*4*(nP+1)); + memcpy(>v.at(ands, 0), r, sizeof(bool)*4); ++ands; } } - for(int i = 1; i <= nP; ++i) { - delete[] x[i]; - delete[] y[i]; - } } - - void online (bool * input, bool * output) { - bool * mask_input = new bool[cf->num_wire]; - for(int i = 0; i < num_in; ++i) - mask_input[i] = input[i] != value[i]; - if(party != 1) { - io->send_data(1, mask_input, num_in); - io->flush(1); - io->recv_data(1, mask_input, num_in); - } else { - bool * tmp[nP+1]; - for(int i = 2; i <= nP; ++i) tmp[i] = new bool[num_in]; - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, tmp[party2], num_in); - } - for(int i = 0; i < num_in; ++i) - for(int j = 2; j <= nP; ++j) - mask_input[i] = tmp[j][i] != mask_input[i]; - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->send_data(party2, mask_input, num_in); - io->flush(party2); - } - for(int i = 2; i <= nP; ++i) delete[] tmp[i]; - } - - if(party!= 1) { - for(int i = 0; i < num_in; ++i) { - block tmp = labels[i]; - if(mask_input[i]) tmp = tmp ^ Delta; - io->send_data(1, &tmp, sizeof(block)); - } - io->flush(1); - } else { - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, eval_labels[party2], num_in*sizeof(block)); - } - - int ands = 0; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == XOR_GATE) { - for(int j = 2; j<= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]]; - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]]; - } else if (cf->gates[4*i+3] == AND_GATE) { - int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]]; - block H[nP+1]; - for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j]; - mask_input[cf->gates[4*i+2]] = GTv[ands][index]; - for(int j = 2; j <= nP; ++j) { - Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index); - xorBlocks_arr(H, H, GT[ands][j][index], nP+1); - for(int k = 2; k <= nP; ++k) - eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]]; - - block t0 = GTK[ands][index][j] ^ Delta; - - if(cmpBlock(&H[1], >K[ands][index][j], 1)) - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false; - else if(cmpBlock(&H[1], &t0, 1)) - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true; - else {cout <gates[4*i+2]] = not mask_input[cf->gates[4*i]]; - for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]]; - } - } - } - if(party != 1) { - io->send_data(1, value+cf->num_wire - cf->n3, cf->n3); - io->flush(1); - } else { - bool * tmp[nP+1]; - for(int i = 2; i <= nP; ++i) - tmp[i] = new bool[cf->n3]; - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, tmp[party2], cf->n3); - } - for(int i = 0; i < cf->n3; ++i) - for(int j = 2; j <= nP; ++j) - mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i]; - for(int i = 0; i < cf->n3; ++i) - mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i]; - - for(int i = 2; i <= nP; ++i) delete[] tmp[i]; - memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3); - } - delete[] mask_input; - } - void Hash(block H[4][nP+1], const block & a, const block & b, uint64_t idx) { + void Hash(NVec& H, const block & a, const block & b, uint64_t idx) { block T[4]; T[0] = sigma(a); T[1] = sigma(a ^ Delta); T[2] = sigma(sigma(b)); T[3] = sigma(sigma(b ^ Delta)); - H[0][0] = T[0] ^ T[2]; - H[1][0] = T[0] ^ T[3]; - H[2][0] = T[1] ^ T[2]; - H[3][0] = T[1] ^ T[3]; + H.at(0, 0) = T[0] ^ T[2]; + H.at(1, 0) = T[0] ^ T[3]; + H.at(2, 0) = T[1] ^ T[2]; + H.at(3, 0) = T[1] ^ T[3]; for(int j = 0; j < 4; ++j) for(int i = 1; i <= nP; ++i) { - H[j][i] = H[j][0] ^ makeBlock(4*idx+j, i); + H.at(j, i) = H.at(j, 0) ^ makeBlock(4*idx+j, i); } for(int j = 0; j < 4; ++j) { - prp.permute_block(H[j]+1, nP); + prp.permute_block(&H.at(j, 1), nP); } } - void Hash(block H[nP+1], const block &a, const block& b, uint64_t idx, uint64_t row) { + void Hash(block* H, const block &a, const block& b, uint64_t idx, uint64_t row) { H[0] = sigma(a) ^ sigma(sigma(b)); for(int i = 1; i <= nP; ++i) { H[i] = H[0] ^ makeBlock(4*idx+row, i); @@ -445,137 +332,9 @@ class CMPC { public: else return "F"; } - void online (bool * input, bool * output, int* start, int* end) { + void online (FlexIn * input, FlexOut *output) { bool * mask_input = new bool[cf->num_wire]; - bool * input_mask[nP+1]; - for(int i = 0; i <= nP; ++i) input_mask[i] = new bool[end[party] - start[party]]; - memcpy(input_mask[party], value+start[party], end[party] - start[party]); - memcpy(input_mask[0], input+start[party], end[party] - start[party]); - - vector res; - for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) { - int party2 = i + j - party; - - { - char dig[Hash::DIGEST_SIZE]; - io->send_data(party2, value+start[party2], end[party2]-start[party2]); - emp::Hash::hash_once(dig, mac[party2]+start[party2], (end[party2]-start[party2])*sizeof(block)); - io->send_data(party2, dig, Hash::DIGEST_SIZE); - io->flush(party2); - res.push_back(false); - } - - { - char dig[Hash::DIGEST_SIZE]; - char dig2[Hash::DIGEST_SIZE]; - io->recv_data(party2, input_mask[party2], end[party]-start[party]); - block * tmp = new block[end[party]-start[party]]; - for(int i = 0; i < end[party] - start[party]; ++i) { - tmp[i] = key[party2][i+start[party]]; - if(input_mask[party2][i])tmp[i] = tmp[i] ^ Delta; - } - emp::Hash::hash_once(dig2, tmp, (end[party]-start[party])*sizeof(block)); - io->recv_data(party2, dig, Hash::DIGEST_SIZE); - delete[] tmp; - res.push_back(strncmp(dig, dig2, Hash::DIGEST_SIZE) != 0); - } - } - if(checkCheat(res)) error("cheat!"); - for(int i = 1; i <= nP; ++i) - for(int j = 0; j < end[party] - start[party]; ++j) - input_mask[0][j] = input_mask[0][j] != input_mask[i][j]; - - if(party != 1) { - io->send_data(1, input_mask[0], end[party] - start[party]); - io->flush(1); - io->recv_data(1, mask_input, num_in); - } else { - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, mask_input+start[party2], end[party2] - start[party2]); - } - memcpy(mask_input, input_mask[0], end[1]-start[1]); - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->send_data(party2, mask_input, num_in); - io->flush(party2); - } - } - - if(party!= 1) { - for(int i = 0; i < num_in; ++i) { - block tmp = labels[i]; - if(mask_input[i]) tmp = tmp ^ Delta; - io->send_data(1, &tmp, sizeof(block)); - } - io->flush(1); - } else { - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, eval_labels[party2], num_in*sizeof(block)); - } - - int ands = 0; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == XOR_GATE) { - for(int j = 2; j<= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]]; - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]]; - } else if (cf->gates[4*i+3] == AND_GATE) { - int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]]; - block H[nP+1]; - for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j]; - mask_input[cf->gates[4*i+2]] = GTv[ands][index]; - for(int j = 2; j <= nP; ++j) { - Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index); - xorBlocks_arr(H, H, GT[ands][j][index], nP+1); - for(int k = 2; k <= nP; ++k) - eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]]; - - block t0 = GTK[ands][index][j] ^ Delta; - - if(cmpBlock(&H[1], >K[ands][index][j], 1)) - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false; - else if(cmpBlock(&H[1], &t0, 1)) - mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true; - else {cout <gates[4*i+2]] = not mask_input[cf->gates[4*i]]; - for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]]; - } - } - } - if(party != 1) { - io->send_data(1, value+cf->num_wire - cf->n3, cf->n3); - io->flush(1); - } else { - bool * tmp[nP+1]; - for(int i = 2; i <= nP; ++i) - tmp[i] = new bool[cf->n3]; - for(int i = 2; i <= nP; ++i) { - int party2 = i; - io->recv_data(party2, tmp[party2], cf->n3); - } - for(int i = 0; i < cf->n3; ++i) - for(int j = 2; j <= nP; ++j) - mask_input[cf->num_wire - cf->n3 + i] = tmp[j][i] != mask_input[cf->num_wire - cf->n3 + i]; - for(int i = 0; i < cf->n3; ++i) - mask_input[cf->num_wire - cf->n3 + i] = value[cf->num_wire - cf->n3 + i] != mask_input[cf->num_wire - cf->n3 + i]; - - for(int i = 2; i <= nP; ++i) delete[] tmp[i]; - memcpy(output, mask_input + cf->num_wire - cf->n3, cf->n3); - } - delete[] mask_input; - } - - void online (FlexIn * input, FlexOut *output) { - bool * mask_input = new bool[cf->num_wire]; - input->associate_cmpc(value, mac, key, io, Delta); + input->associate_cmpc(&value[0], mac, key, io, Delta); input->input(mask_input); if(party!= 1) { @@ -588,30 +347,30 @@ class CMPC { public: } else { for(int i = 2; i <= nP; ++i) { int party2 = i; - io->recv_data(party2, eval_labels[party2], num_in*sizeof(block)); + io->recv_data(party2, &eval_labels.at(party2, 0), num_in*sizeof(block)); } int ands = 0; for(int i = 0; i < cf->num_gate; ++i) { if (cf->gates[4*i+3] == XOR_GATE) { for(int j = 2; j<= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]] ^ eval_labels[j][cf->gates[4*i+1]]; + eval_labels.at(j, cf->gates[4*i+2]) = eval_labels.at(j, cf->gates[4*i]) ^ eval_labels.at(j, cf->gates[4*i+1]); mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i]] != mask_input[cf->gates[4*i+1]]; } else if (cf->gates[4*i+3] == AND_GATE) { int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]]; block H[nP+1]; for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = GTM[ands][index][j]; - mask_input[cf->gates[4*i+2]] = GTv[ands][index]; + eval_labels.at(j, cf->gates[4*i+2]) = GTM.at(ands, index, j); + mask_input[cf->gates[4*i+2]] = GTv.at(ands, index); for(int j = 2; j <= nP; ++j) { - Hash(H, eval_labels[j][cf->gates[4*i]], eval_labels[j][cf->gates[4*i+1]], ands, index); - xorBlocks_arr(H, H, GT[ands][j][index], nP+1); + Hash(H, eval_labels.at(j, cf->gates[4*i]), eval_labels.at(j, cf->gates[4*i+1]), ands, index); + xorBlocks_arr(H, H, >.at(ands, j, index, 0), nP+1); for(int k = 2; k <= nP; ++k) - eval_labels[k][cf->gates[4*i+2]] = H[k] ^ eval_labels[k][cf->gates[4*i+2]]; + eval_labels.at(k, cf->gates[4*i+2]) = H[k] ^ eval_labels.at(k, cf->gates[4*i+2]); - block t0 = GTK[ands][index][j] ^ Delta; + block t0 = GTK.at(ands, index, j) ^ Delta; - if(cmpBlock(&H[1], >K[ands][index][j], 1)) + if(cmpBlock(&H[1], >K.at(ands, index, j), 1)) mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != false; else if(cmpBlock(&H[1], &t0, 1)) mask_input[cf->gates[4*i+2]] = mask_input[cf->gates[4*i+2]] != true; @@ -622,12 +381,12 @@ class CMPC { public: } else { mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]]; for(int j = 2; j <= nP; ++j) - eval_labels[j][cf->gates[4*i+2]] = eval_labels[j][cf->gates[4*i]]; + eval_labels.at(j, cf->gates[4*i+2]) = eval_labels.at(j, cf->gates[4*i]); } } } - output->associate_cmpc(value, mac, key, eval_labels, labels, io, Delta); + output->associate_cmpc(&value[0], mac, key, eval_labels, labels, io, Delta); output->output(mask_input, cf->num_wire - cf->n3); delete[] mask_input; diff --git a/src/cpp/emp-agmpc/nvector.h b/src/cpp/emp-agmpc/nvec.h similarity index 93% rename from src/cpp/emp-agmpc/nvector.h rename to src/cpp/emp-agmpc/nvec.h index c47207e..48f33bf 100644 --- a/src/cpp/emp-agmpc/nvector.h +++ b/src/cpp/emp-agmpc/nvec.h @@ -1,23 +1,24 @@ #ifndef NVECTOR_H #define NVECTOR_H -#include #include #include #include #include #include +#include "vec.h" + // N-dimensional vector class template -class NVector { +class NVec { public: // Default constructor - NVector() : total_size(0) {} + NVec() : total_size(0) {} // Constructor taking sizes of each dimension template - explicit NVector(Dims... dims) { + explicit NVec(Dims... dims) { resize(dims...); } @@ -61,7 +62,7 @@ public: private: std::vector dimensions; // Sizes of each dimension size_t total_size; // Total size of the data - std::vector data; // Linear storage for the elements + Vec data; // Linear storage for the elements // Compute the flat index from multi-dimensional indices size_t compute_flat_index(const std::vector& indices) const { diff --git a/src/cpp/emp-agmpc/vec.h b/src/cpp/emp-agmpc/vec.h new file mode 100644 index 0000000..bf86527 --- /dev/null +++ b/src/cpp/emp-agmpc/vec.h @@ -0,0 +1,150 @@ +#ifndef VEC_H +#define VEC_H + +#include +#include +#include +#include +#include + +// Like std::vector but without specialization. +// This is important because std::vector is a bitset which breaks assumptions +// made by code designed for bool* arrays. +template +class Vec { +private: + T* data; // Raw pointer to the array + size_t capacity; // Allocated capacity + size_t size_; // Current size + + void grow(size_t new_capacity = 0) { + if (new_capacity == 0) { + new_capacity = capacity == 0 ? 1 : capacity * 2; + } + T* new_data = new T[new_capacity]; + for (size_t i = 0; i < size_; ++i) { + new_data[i] = std::move(data[i]); + } + delete[] data; + data = new_data; + capacity = new_capacity; + } + +public: + // Constructors and destructor + Vec() : data(nullptr), capacity(0), size_(0) {} + explicit Vec(size_t n, const T& value = T()) + : data(new T[n]), capacity(n), size_(n) { + for (size_t i = 0; i < n; ++i) { + data[i] = value; + } + } + ~Vec() { delete[] data; } + + // Copy constructor + Vec(const Vec& other) + : data(new T[other.capacity]), capacity(other.capacity), size_(other.size_) { + for (size_t i = 0; i < size_; ++i) { + data[i] = other.data[i]; + } + } + + // Move constructor + Vec(Vec&& other) noexcept + : data(other.data), capacity(other.capacity), size_(other.size_) { + other.data = nullptr; + other.capacity = 0; + other.size_ = 0; + } + + // Copy assignment + Vec& operator=(const Vec& other) { + if (this != &other) { + delete[] data; + data = new T[other.capacity]; + capacity = other.capacity; + size_ = other.size_; + for (size_t i = 0; i < size_; ++i) { + data[i] = other.data[i]; + } + } + return *this; + } + + // Move assignment + Vec& operator=(Vec&& other) noexcept { + if (this != &other) { + delete[] data; + data = other.data; + capacity = other.capacity; + size_ = other.size_; + other.data = nullptr; + other.capacity = 0; + other.size_ = 0; + } + return *this; + } + + // Accessors + T& operator[](size_t index) { + assert(index < size_); + return data[index]; + } + + const T& operator[](size_t index) const { + assert(index < size_); + return data[index]; + } + + T& at(size_t index) { + if (index >= size_) { + throw std::out_of_range("Index out of range"); + } + return data[index]; + } + + const T& at(size_t index) const { + if (index >= size_) { + throw std::out_of_range("Index out of range"); + } + return data[index]; + } + + // Member functions + void push_back(const T& value) { + if (size_ == capacity) { + grow(); + } + data[size_++] = value; + } + + void pop_back() { + assert(size_ > 0); + --size_; + } + + void resize(size_t new_size, const T& value = T()) { + if (new_size > capacity) { + grow(new_size); + } + if (new_size > size_) { + for (size_t i = size_; i < new_size; ++i) { + data[i] = value; + } + } + size_ = new_size; + } + + size_t size() const { return size_; } + size_t get_capacity() const { return capacity; } + bool empty() const { return size_ == 0; } + + void clear() { + delete[] data; + data = nullptr; + capacity = 0; + size_ = 0; + } +}; + +#endif // VEC_H