always use io channels

This commit is contained in:
Andrew Morris
2025-01-29 13:12:42 +11:00
parent 23c2bdac0b
commit c0a3e3f292
6 changed files with 71 additions and 90 deletions

View File

@@ -168,13 +168,13 @@ class ABitMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, Ms[party2], sizeof(block)*ssp);
io->send_data(party2, bs[party2], ssp);
io->send_channel(party2).send_data(Ms[party2], sizeof(block)*ssp);
io->send_channel(party2).send_data(bs[party2], ssp);
io->flush(party2);
res.push_back(false);
io->recv_data(party2, tMs[party2], sizeof(block)*ssp);
io->recv_data(party2, tbs[party2], ssp);
io->recv_channel(party2).recv_data(tMs[party2], sizeof(block)*ssp);
io->recv_channel(party2).recv_data(tbs[party2], ssp);
for(int k = 0; k < ssp; ++k) {
if(tbs[party2][k])
Ks[party2][k] = Ks[party2][k] ^ Delta;
@@ -228,12 +228,12 @@ class ABitMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->send_data(party2, dgst0[party*ssp], Hash::DIGEST_SIZE*ssp);
io->send_data(party2, dgst1[party*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst0[party2*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_data(party2, dgst1[party2*ssp], Hash::DIGEST_SIZE*ssp);
io->send_channel(party2).send_data(dgst[party], Hash::DIGEST_SIZE);
io->send_channel(party2).send_data(dgst0[party*ssp], Hash::DIGEST_SIZE*ssp);
io->send_channel(party2).send_data(dgst1[party*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_channel(party2).recv_data(dgst[party2], Hash::DIGEST_SIZE);
io->recv_channel(party2).recv_data(dgst0[party2*ssp], Hash::DIGEST_SIZE*ssp);
io->recv_channel(party2).recv_data(dgst1[party2*ssp], Hash::DIGEST_SIZE*ssp);
}
vector<bool> res2;
@@ -243,16 +243,16 @@ class ABitMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, data + length - 3*ssp, ssp);
io->send_channel(party2).send_data(data + length - 3*ssp, ssp);
for(int k = 1; k <= nP; ++k) if(k != party)
io->send_data(party2, &MAC.at(k, length - 3*ssp), sizeof(block)*ssp);
io->send_channel(party2).send_data(&MAC.at(k, length - 3*ssp), sizeof(block)*ssp);
res2.push_back(false);
Hash h;
io->recv_data(party2, bs[party2], ssp);
io->recv_channel(party2).recv_data(bs[party2], ssp);
h.put(bs[party2], ssp);
for(int k = 1; k <= nP; ++k) if(k != party2) {
io->recv_data(party2, Ms[party2][k], sizeof(block)*ssp);
io->recv_channel(party2).recv_data(Ms[party2][k], sizeof(block)*ssp);
h.put(Ms[party2][k], sizeof(block)*ssp);
}
char tmp[Hash::DIGEST_SIZE];h.digest(tmp);
@@ -268,20 +268,20 @@ class ABitMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, bs[party], ssp);
io->send_channel(party2).send_data(bs[party], ssp);
for(int i = 0; i < ssp; ++i) {
if(bs[party][i])
io->send_data(party2, &Ks[1][i], sizeof(block));
io->send_channel(party2).send_data(&Ks[1][i], sizeof(block));
else
io->send_data(party2, &Ks[0][i], sizeof(block));
io->send_channel(party2).send_data(&Ks[0][i], sizeof(block));
}
io->flush(party2);
res2.push_back(false);
bool cheat = false;
bool *tmp_bool = new bool[ssp];
io->recv_data(party2, tmp_bool, ssp);
io->recv_data(party2, KK[party2], ssp*sizeof(block));
io->recv_channel(party2).recv_data(tmp_bool, ssp);
io->recv_channel(party2).recv_data(KK[party2], ssp*sizeof(block));
for(int i = 0; i < ssp; ++i) {
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &KK[party2][i], sizeof(block));

View File

@@ -162,9 +162,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, open_bit_shares_for_plaintext_input_send[party2].data(), sizeof(BitWithMac) * len);
io->send_channel(party2).send_data(open_bit_shares_for_plaintext_input_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
io->recv_data(party2, open_bit_shares_for_plaintext_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->recv_channel(party2).recv_data(open_bit_shares_for_plaintext_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}
}
@@ -222,9 +222,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, masked_input_sent.data(), sizeof(char) * len);
io->send_channel(party2).send_data(masked_input_sent.data(), sizeof(char) * len);
io->flush(party2);
io->recv_data(party2, masked_input_recv[party2].data(), sizeof(char) * len);
io->recv_channel(party2).recv_data(masked_input_recv[party2].data(), sizeof(char) * len);
io->flush(party2);
}
}
@@ -321,9 +321,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, open_bit_shares_for_authenticated_bits_send[party2].data(), sizeof(BitWithMac) * len);
io->send_channel(party2).send_data(open_bit_shares_for_authenticated_bits_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
io->recv_data(party2, open_bit_shares_for_authenticated_bits_recv[party2].data(), sizeof(BitWithMac) * len);
io->recv_channel(party2).recv_data(open_bit_shares_for_authenticated_bits_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}
}
@@ -394,9 +394,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, open_bit_shares_for_unauthenticated_bits_send.data(), sizeof(char) * len);
io->send_channel(party2).send_data(open_bit_shares_for_unauthenticated_bits_send.data(), sizeof(char) * len);
io->flush(party2);
io->recv_data(party2, open_bit_shares_for_unauthenticated_bits_recv[party2].data(), sizeof(char) * len);
io->recv_channel(party2).recv_data(open_bit_shares_for_unauthenticated_bits_recv[party2].data(), sizeof(char) * len);
io->flush(party2);
}
}
@@ -453,9 +453,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, open_bit_shares_for_public_input_send[party2].data(), sizeof(BitWithMac) * len);
io->send_channel(party2).send_data(open_bit_shares_for_public_input_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
io->recv_data(party2, open_bit_shares_for_public_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->recv_channel(party2).recv_data(open_bit_shares_for_public_input_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}
}
@@ -607,11 +607,11 @@ public:
}
for(int j = 2; j <= nP; j++) {
io->send_data(j, output_wire_label_send[j].data(), sizeof(block) * len);
io->send_channel(j).send_data(output_wire_label_send[j].data(), sizeof(block) * len);
io->flush(j);
}
}else {
io->recv_data(ALICE, output_wire_label_recv.data(), sizeof(block) * len);
io->recv_channel(ALICE).recv_data(output_wire_label_recv.data(), sizeof(block) * len);
io->flush(ALICE);
}
@@ -682,9 +682,9 @@ public:
if ((i < j) and (i == party or j == party)) {
int party2 = i + j - party;
io->send_data(party2, output_mask_send[party2].data(), sizeof(BitWithMac) * len);
io->send_channel(party2).send_data(output_mask_send[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
io->recv_data(party2, output_mask_recv[party2].data(), sizeof(BitWithMac) * len);
io->recv_channel(party2).recv_data(output_mask_recv[party2].data(), sizeof(BitWithMac) * len);
io->flush(party2);
}
}

View File

@@ -68,14 +68,14 @@ class FpreMP { public:
prgs[j].random_bool(&s.at(j, 0), length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = garble(&tKEY.at(j, 0), &tr[0], &s.at(j, 0), k, j);
io->send_data(j, &data, 1);
io->send_channel(j).send_data(&data, 1);
s.at(j, k) = (s.at(j, k) != (tr[3*k] and tr[3*k+1]));
}
io->flush(j);
} else if (j == party) {
for(int k = 0; k < length*bucket_size; ++k) {
uint8_t data = 0;
io->recv_data(i, &data, 1);
io->recv_channel(i).recv_data(&data, 1);
bool tmp = evaluate(data, &tMAC.at(i, 0), &tr[0], k, i);
s.at(i, k) = (tmp != (tr[3*k] and tr[3*k+1]));
}
@@ -97,11 +97,11 @@ class FpreMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, &e[0], length*bucket_size);
io->send_channel(party2).send_data(&e[0], length*bucket_size);
io->flush(party2);
bool * tmp = new bool[length*bucket_size];
io->recv_data(party2, tmp, length*bucket_size);
io->recv_channel(party2).recv_data(tmp, length*bucket_size);
for(int k = 0; k < length*bucket_size; ++k) {
if(tmp[k])
tKEY.at(party2, 3*k+2) = tKEY.at(party2, 3*k+2) ^ Delta;
@@ -135,7 +135,7 @@ class FpreMP { public:
tKEYphi.at(party2, k) = bH[0];
bH[1] = bH[0] ^ bH[1];
bH[1] = phi[k] ^ bH[1];
io->send_data(party2, &bH[1], sizeof(block));
io->send_channel(party2).send_data(&bH[1], sizeof(block));
}
io->flush(party2);
}
@@ -143,7 +143,7 @@ class FpreMP { public:
{
block bH;
for(int k = 0; k < length*bucket_size; ++k) {
io->recv_data(party2, &bH, sizeof(block));
io->recv_channel(party2).recv_data(&bH, sizeof(block));
block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi.at(party2, k) = prps2[party2].H(hin);
if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH;
@@ -153,7 +153,7 @@ class FpreMP { public:
{
block bH;
for(int k = 0; k < length*bucket_size; ++k) {
io->recv_data(party2, &bH, sizeof(block));
io->recv_channel(party2).recv_data(&bH, sizeof(block));
block hin = sigma(tMAC.at(party2, 3*k)) ^ makeBlock(0, 2*k+tr[3*k]);
tMACphi.at(party2, k) = prps2[party2].H(hin);
if(tr[3*k])tMACphi.at(party2, k) = tMACphi.at(party2, k) ^ bH;
@@ -169,7 +169,7 @@ class FpreMP { public:
tKEYphi.at(party2, k) = bH[0];
bH[1] = bH[0] ^ bH[1];
bH[1] = phi[k] ^ bH[1];
io->send_data(party2, &bH[1], sizeof(block));
io->send_channel(party2).send_data(&bH[1], sizeof(block));
}
io->flush(party2);
}
@@ -211,16 +211,16 @@ class FpreMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
io->send_channel(party2).send_data(dgst[party], Hash::DIGEST_SIZE);
io->recv_channel(party2).recv_data(dgst[party2], Hash::DIGEST_SIZE);
}
vector<bool> res2;
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, &X.at(party, 0), sizeof(block)*ssp);
io->recv_data(party2, &X.at(party2, 0), sizeof(block)*ssp);
io->send_channel(party2).send_data(&X.at(party, 0), sizeof(block)*ssp);
io->recv_channel(party2).recv_data(&X.at(party2, 0), sizeof(block)*ssp);
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &X.at(party2, 0), sizeof(block)*ssp);
res2.push_back(strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0);
@@ -276,9 +276,9 @@ class FpreMP { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, d[party], (bucket_size-1)*length);
io->send_channel(party2).send_data(d[party], (bucket_size-1)*length);
io->flush(party2);
io->recv_data(party2, d[party2], (bucket_size-1)*length);
io->recv_channel(party2).recv_data(d[party2], (bucket_size-1)*length);
}
for(int i = 2; i <= nP; ++i)
for(int j = 0; j < (bucket_size-1)*length; ++j)
@@ -349,12 +349,12 @@ class FpreMP { public:
block *tD = new block[length];
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, phi, length*sizeof(block));
io->send_data(j, &KEY.at(j, 0), sizeof(block)*length);
io->send_channel(j).send_data(phi, length*sizeof(block));
io->send_channel(j).send_data(&KEY.at(j, 0), sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, tD, length*sizeof(block));
io->recv_data(i, tmp, sizeof(block)*length);
io->recv_channel(i).recv_data(tD, length*sizeof(block));
io->recv_channel(i).recv_data(tmp, sizeof(block)*length);
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD[k];
}
@@ -375,7 +375,7 @@ class FpreMP { public:
block * tmp2 = new block[l];
memcpy(tmp1, b, l*sizeof(block));
for(int i = 2; i <= nP; ++i) {
io->recv_data(i, tmp2, l*sizeof(block));
io->recv_channel(i).recv_data(tmp2, l*sizeof(block));
xorBlocks_arr(tmp1, tmp1, tmp2, l);
}
block z = zero_block;
@@ -386,7 +386,7 @@ class FpreMP { public:
delete[] tmp1;
delete[] tmp2;
} else {
io->send_data(1, b, l*sizeof(block));
io->send_channel(1).send_data(b, l*sizeof(block));
io->flush(1);
}
}

View File

@@ -91,13 +91,13 @@ block sampleRandom(int nP, NetIOMP * io, PRG * prg, int party) {
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, dgst[party], Hash::DIGEST_SIZE);
io->recv_data(party2, dgst[party2], Hash::DIGEST_SIZE);
io->send_channel(party2).send_data(dgst[party], Hash::DIGEST_SIZE);
io->recv_channel(party2).recv_data(dgst[party2], Hash::DIGEST_SIZE);
}
for(int i = 1; i <= nP; ++i) for(int j = 1; j<= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, &S[party], sizeof(block));
io->recv_data(party2, &S[party2], sizeof(block));
io->send_channel(party2).send_data(&S[party], sizeof(block));
io->recv_channel(party2).recv_data(&S[party2], sizeof(block));
char tmp[Hash::DIGEST_SIZE];
Hash::hash_once(tmp, &S[party2], sizeof(block));
bool cheat = strncmp(tmp, dgst[party2], Hash::DIGEST_SIZE)!=0;
@@ -121,12 +121,12 @@ void check_MAC(int nP, NetIOMP * io, const NVec<block>& MAC, const NVec<block>&
block tD;
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if (i < j) {
if(party == i) {
io->send_data(j, &Delta, sizeof(block));
io->send_data(j, &KEY.at(j, 0), sizeof(block)*length);
io->send_channel(j).send_data(&Delta, sizeof(block));
io->send_channel(j).send_data(&KEY.at(j, 0), sizeof(block)*length);
io->flush(j);
} else if(party == j) {
io->recv_data(i, &tD, sizeof(block));
io->recv_data(i, tmp, sizeof(block)*length);
io->recv_channel(i).recv_data(&tD, sizeof(block));
io->recv_channel(i).recv_data(tmp, sizeof(block)*length);
for(int k = 0; k < length; ++k) {
if(r[k])tmp[k] = tmp[k] ^ tD;
}
@@ -145,7 +145,7 @@ void check_correctness(int nP, NetIOMP* io, bool * r, int length, int party) {
bool * tmp2 = new bool[length*3];
memcpy(tmp1, r, length*3);
for(int i = 2; i <= nP; ++i) {
io->recv_data(i, tmp2, length*3);
io->recv_channel(i).recv_data(tmp2, length*3);
for(int k = 0; k < length*3; ++k)
tmp1[k] = (tmp1[k] != tmp2[k]);
}
@@ -157,7 +157,7 @@ void check_correctness(int nP, NetIOMP* io, bool * r, int length, int party) {
delete[] tmp2;
cerr<<"check_correctness pass!\n"<<flush;
} else {
io->send_data(1, r, length*3);
io->send_channel(1).send_data(r, length*3);
io->flush(1);
}
}
@@ -193,4 +193,3 @@ inline std::string hex_to_binary(std::string hex) {
}
#endif// __HELPER

View File

@@ -171,12 +171,12 @@ class CMPC { public:
for(int i = 1; i <= nP; ++i) for(int j = 1; j <= nP; ++j) if( (i < j) and (i == party or j == party) ) {
int party2 = i + j - party;
io->send_data(party2, &x.at(party, 0), num_ands);
io->send_data(party2, &y.at(party, 0), num_ands);
io->send_channel(party2).send_data(&x.at(party, 0), num_ands);
io->send_channel(party2).send_data(&y.at(party, 0), num_ands);
io->flush(party2);
io->recv_data(party2, &x.at(party2, 0), num_ands);
io->recv_data(party2, &y.at(party2, 0), num_ands);
io->recv_channel(party2).recv_data(&x.at(party2, 0), num_ands);
io->recv_channel(party2).recv_data(&y.at(party2, 0), num_ands);
}
for(int i = 2; i <= nP; ++i) for(int j = 0; j < num_ands; ++j) {
x.at(1, j) = x.at(1, j) != x.at(i, j);
@@ -263,7 +263,7 @@ class CMPC { public:
H.at(j, party) = H.at(j, party) ^ Delta;
}
for(int j = 0; j < 4; ++j)
io->send_data(1, &H.at(j, 1), sizeof(block)*(nP));
io->send_channel(1).send_data(&H.at(j, 1), sizeof(block)*(nP));
++ands;
}
io->flush(1);
@@ -272,7 +272,7 @@ class CMPC { public:
int party2 = i;
for(int i = 0; i < num_ands; ++i)
for(int j = 0; j < 4; ++j)
io->recv_data(party2, &GT.at(i, party2, j, 1), sizeof(block)*(nP));
io->recv_channel(party2).recv_data(&GT.at(i, party2, j, 1), sizeof(block)*(nP));
}
for(int i = 0; i < cf->num_gate; ++i) if(cf->gates[4*i+3] == AND_GATE) {
r[0] = sigma_value[ands] != value[cf->gates[4*i+2]];
@@ -340,13 +340,13 @@ class CMPC { public:
for(int i = 0; i < num_in; ++i) {
block tmp = labels[i];
if(mask_input[i]) tmp = tmp ^ Delta;
io->send_data(1, &tmp, sizeof(block));
io->send_channel(1).send_data(&tmp, sizeof(block));
}
io->flush(1);
} else {
for(int i = 2; i <= nP; ++i) {
int party2 = i;
io->recv_data(party2, &eval_labels.at(party2, 0), num_in*sizeof(block));
io->recv_channel(party2).recv_data(&eval_labels.at(party2, 0), num_in*sizeof(block));
}
int ands = 0;

View File

@@ -66,24 +66,6 @@ class NetIOMP { public:
return res;
}
void send_data(int dst, const void * data, size_t len) {
assert(dst != 0);
assert(dst != party);
if(party < dst)
ios[dst]->send_data(data, len);
else
ios2[dst]->send_data(data, len);
}
void recv_data(int src, void * data, size_t len) {
assert(src != 0);
assert(src != party);
if(src < party)
ios[src]->recv_data(data, len);
else
ios2[src]->recv_data(data, len);
}
IOChannel& send_channel(int party2) {
assert(party2 != 0);
assert(party2 != party);