diff --git a/programs/2pc.cpp b/programs/2pc.cpp index 7aa1a8b..eaf6ddb 100644 --- a/programs/2pc.cpp +++ b/programs/2pc.cpp @@ -44,7 +44,7 @@ int main(int argc, char** argv) { auto net_io = std::make_shared(party == ALICE ? nullptr : IP, port); - IOChannel io(net_io); + IOChannel io(net_io); string file = circuit_file_location; @@ -63,21 +63,21 @@ int main(int argc, char** argv) { io.flush(); cout << "dep:\t" << party << "\t" << time_from(t1) << endl; - int input_size = party == ALICE ? 0 : 512; - std::vector in(input_size); + int input_size = party == ALICE ? 0 : 512; + std::vector in(input_size); - if (party == BOB) { - // we need a single starting 1 for a valid sha-1 block - // this will result in sha1("") == da39a3ee5e6b4b0d3255bfef95601890afd80709 - in[0] = true; + if (party == BOB) { + // we need a single starting 1 for a valid sha-1 block + // this will result in sha1("") == da39a3ee5e6b4b0d3255bfef95601890afd80709 + in[0] = true; - // also btw it's bob's input that matters for our circuit (see sha-1.txt) - // 512 0 160 - // | | ^ 160 output bits - // | ^ 0 input bits from Alice - // ^ 512 input bits from Bob - // (idk why but Bob's input comes first 🤷‍♂️) - } + // also btw it's bob's input that matters for our circuit (see sha-1.txt) + // 512 0 160 + // | | ^ 160 output bits + // | ^ 0 input bits from Alice + // ^ 512 input bits from Bob + // (idk why but Bob's input comes first 🤷‍♂️) + } t1 = clock_start(); std::vector out = twopc.online(in); diff --git a/programs/bench_lpn.cpp b/programs/bench_lpn.cpp index 1e087a4..454cd12 100644 --- a/programs/bench_lpn.cpp +++ b/programs/bench_lpn.cpp @@ -32,37 +32,37 @@ using namespace std; using namespace emp; int main(int argc, char** argv) { - PRG prg; - int k, n; - if (argc >= 3) { - k = atoi(argv[1]); - n = atoi(argv[2]); - } else { - k = 11; - n = 20; - } - if(n > 30 or k > 30) { - cout <<"Large test size! comment me if you want to run this size\n"; - exit(1); - } + PRG prg; + int k, n; + if (argc >= 3) { + k = atoi(argv[1]); + n = atoi(argv[2]); + } else { + k = 11; + n = 20; + } + if(n > 30 or k > 30) { + cout <<"Large test size! comment me if you want to run this size\n"; + exit(1); + } - block seed; - block * kk = new block[1< lpn(ALICE, 1<size()); - lpn.bench(nn, kk); - kk[0] = nn[0]; - } - cout << n<<"\t"< lpn(ALICE, 1<size()); + lpn.bench(nn, kk); + kk[0] = nn[0]; + } + cout << n<<"\t"<party = party; - this->cf = cf; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == AND_GATE) - ++num_ands; - } - cout << cf->n1<<" "<n2<<" "<n3<<" "<n1 + cf->n2 + num_ands; - fpre = new Fpre(io, party, num_ands); + C2PC(IOChannel io, int party, BristolFormat* cf) + : + io(io) + { + this->party = party; + this->cf = cf; + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == AND_GATE) + ++num_ands; + } + cout << cf->n1<<" "<n2<<" "<n3<<" "<n1 + cf->n2 + num_ands; + fpre = new Fpre(io, party, num_ands); - key = new block[cf->num_wire]; - mac = new block[cf->num_wire]; + key = new block[cf->num_wire]; + mac = new block[cf->num_wire]; - preprocess_mac = new block[total_pre]; - preprocess_key = new block[total_pre]; + preprocess_mac = new block[total_pre]; + preprocess_key = new block[total_pre]; - //sigma values in the paper - sigma_mac = new block[num_ands]; - sigma_key = new block[num_ands]; + //sigma values in the paper + sigma_mac = new block[num_ands]; + sigma_key = new block[num_ands]; - labels = new block[cf->num_wire]; + labels = new block[cf->num_wire]; - mask = new bool[cf->n1 + cf->n2]; - } - ~C2PC(){ - delete[] key; - delete[] mac; - delete[] mask; - delete[] GT; - delete[] GTK; - delete[] GTM; + mask = new bool[cf->n1 + cf->n2]; + } + ~C2PC(){ + delete[] key; + delete[] mac; + delete[] mask; + delete[] GT; + delete[] GTK; + delete[] GTM; - delete[] preprocess_mac; - delete[] preprocess_key; + delete[] preprocess_mac; + delete[] preprocess_key; - delete[] sigma_mac; - delete[] sigma_key; + delete[] sigma_mac; + delete[] sigma_key; - delete[] labels; - delete fpre; - } - PRG prg; - PRP prp; - block (* GT)[4][2] = nullptr; - block (* GTK)[4] = nullptr; - block (* GTM)[4] = nullptr; + delete[] labels; + delete fpre; + } + PRG prg; + PRP prp; + block (* GT)[4][2] = nullptr; + block (* GTK)[4] = nullptr; + block (* GTM)[4] = nullptr; - //not allocation - block * ANDS_mac = nullptr; - block * ANDS_key = nullptr; - void function_independent() { - if(party == ALICE) - prg.random_block(labels, cf->num_wire); + //not allocation + block * ANDS_mac = nullptr; + block * ANDS_key = nullptr; + void function_independent() { + if(party == ALICE) + prg.random_block(labels, cf->num_wire); - fpre->refill(); - ANDS_mac = fpre->MAC_res; - ANDS_key = fpre->KEY_res; + fpre->refill(); + ANDS_mac = fpre->MAC_res; + ANDS_key = fpre->KEY_res; - if(fpre->party == ALICE) { - fpre->abit1->send_dot(preprocess_key, total_pre); - fpre->abit2->recv_dot(preprocess_mac, total_pre); - } else { - fpre->abit1->recv_dot(preprocess_mac, total_pre); - fpre->abit2->send_dot(preprocess_key, total_pre); - } - memcpy(key, preprocess_key, (cf->n1+cf->n2)*sizeof(block)); - memcpy(mac, preprocess_mac, (cf->n1+cf->n2)*sizeof(block)); - } + if(fpre->party == ALICE) { + fpre->abit1->send_dot(preprocess_key, total_pre); + fpre->abit2->recv_dot(preprocess_mac, total_pre); + } else { + fpre->abit1->recv_dot(preprocess_mac, total_pre); + fpre->abit2->send_dot(preprocess_key, total_pre); + } + memcpy(key, preprocess_key, (cf->n1+cf->n2)*sizeof(block)); + memcpy(mac, preprocess_mac, (cf->n1+cf->n2)*sizeof(block)); + } - void function_dependent() { - int ands = cf->n1+cf->n2; - bool * x1 = new bool[num_ands]; - bool * y1 = new bool[num_ands]; - bool * x2 = new bool[num_ands]; - bool * y2 = new bool[num_ands]; + void function_dependent() { + int ands = cf->n1+cf->n2; + bool * x1 = new bool[num_ands]; + bool * y1 = new bool[num_ands]; + bool * x2 = new bool[num_ands]; + bool * y2 = new bool[num_ands]; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == AND_GATE) { - key[cf->gates[4*i+2]] = preprocess_key[ands]; - mac[cf->gates[4*i+2]] = preprocess_mac[ands]; - ++ands; - } - } + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == AND_GATE) { + key[cf->gates[4*i+2]] = preprocess_key[ands]; + mac[cf->gates[4*i+2]] = preprocess_mac[ands]; + ++ands; + } + } - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == XOR_GATE) { - key[cf->gates[4*i+2]] = key[cf->gates[4*i]] ^ key[cf->gates[4*i+1]]; - mac[cf->gates[4*i+2]] = mac[cf->gates[4*i]] ^ mac[cf->gates[4*i+1]]; - if(party == ALICE) - labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]]; - } else if (cf->gates[4*i+3] == NOT_GATE) { - if(party == ALICE) - labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ fpre->Delta; - key[cf->gates[4*i+2]] = key[cf->gates[4*i]]; - mac[cf->gates[4*i+2]] = mac[cf->gates[4*i]]; - } - } + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == XOR_GATE) { + key[cf->gates[4*i+2]] = key[cf->gates[4*i]] ^ key[cf->gates[4*i+1]]; + mac[cf->gates[4*i+2]] = mac[cf->gates[4*i]] ^ mac[cf->gates[4*i+1]]; + if(party == ALICE) + labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]]; + } else if (cf->gates[4*i+3] == NOT_GATE) { + if(party == ALICE) + labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ fpre->Delta; + key[cf->gates[4*i+2]] = key[cf->gates[4*i]]; + mac[cf->gates[4*i+2]] = mac[cf->gates[4*i]]; + } + } - ands = 0; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == AND_GATE) { - x1[ands] = getLSB(mac[cf->gates[4*i]] ^ANDS_mac[3*ands]); - y1[ands] = getLSB(mac[cf->gates[4*i+1]]^ANDS_mac[3*ands+1]); - ands++; - } - } - if(party == ALICE) { - io.send_bool(x1, num_ands); - io.send_bool(y1, num_ands); - io.recv_bool(x2, num_ands); - io.recv_bool(y2, num_ands); - } else { - io.recv_bool(x2, num_ands); - io.recv_bool(y2, num_ands); - io.send_bool(x1, num_ands); - io.send_bool(y1, num_ands); - } - io.flush(); - for(int i = 0; i < num_ands; ++i) { - x1[i] = logic_xor(x1[i], x2[i]); - y1[i] = logic_xor(y1[i], y2[i]); - } - ands = 0; - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == AND_GATE) { - sigma_mac[ands] = ANDS_mac[3*ands+2]; - sigma_key[ands] = ANDS_key[3*ands+2]; + ands = 0; + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == AND_GATE) { + x1[ands] = getLSB(mac[cf->gates[4*i]] ^ANDS_mac[3*ands]); + y1[ands] = getLSB(mac[cf->gates[4*i+1]]^ANDS_mac[3*ands+1]); + ands++; + } + } + if(party == ALICE) { + io.send_bool(x1, num_ands); + io.send_bool(y1, num_ands); + io.recv_bool(x2, num_ands); + io.recv_bool(y2, num_ands); + } else { + io.recv_bool(x2, num_ands); + io.recv_bool(y2, num_ands); + io.send_bool(x1, num_ands); + io.send_bool(y1, num_ands); + } + io.flush(); + for(int i = 0; i < num_ands; ++i) { + x1[i] = logic_xor(x1[i], x2[i]); + y1[i] = logic_xor(y1[i], y2[i]); + } + ands = 0; + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == AND_GATE) { + sigma_mac[ands] = ANDS_mac[3*ands+2]; + sigma_key[ands] = ANDS_key[3*ands+2]; - if(x1[ands]) { - sigma_mac[ands] = sigma_mac[ands] ^ ANDS_mac[3*ands+1]; - sigma_key[ands] = sigma_key[ands] ^ ANDS_key[3*ands+1]; - } - if(y1[ands]) { - sigma_mac[ands] = sigma_mac[ands] ^ ANDS_mac[3*ands]; - sigma_key[ands] = sigma_key[ands] ^ ANDS_key[3*ands]; - } - if(x1[ands] and y1[ands]) { - if(party == ALICE) - sigma_key[ands] = sigma_key[ands] ^ fpre->ZDelta; - else - sigma_mac[ands] = sigma_mac[ands] ^ fpre->one; - } + if(x1[ands]) { + sigma_mac[ands] = sigma_mac[ands] ^ ANDS_mac[3*ands+1]; + sigma_key[ands] = sigma_key[ands] ^ ANDS_key[3*ands+1]; + } + if(y1[ands]) { + sigma_mac[ands] = sigma_mac[ands] ^ ANDS_mac[3*ands]; + sigma_key[ands] = sigma_key[ands] ^ ANDS_key[3*ands]; + } + if(x1[ands] and y1[ands]) { + if(party == ALICE) + sigma_key[ands] = sigma_key[ands] ^ fpre->ZDelta; + else + sigma_mac[ands] = sigma_mac[ands] ^ fpre->one; + } - ands++; - } - }//sigma_[] stores the and of input wires to each AND gates + ands++; + } + }//sigma_[] stores the and of input wires to each AND gates - delete[] fpre->MAC; - delete[] fpre->KEY; - fpre->MAC = nullptr; - fpre->KEY = nullptr; - GT = new block[num_ands][4][2]; - GTK = new block[num_ands][4]; - GTM = new block[num_ands][4]; - - ands = 0; - block H[4][2]; - block K[4], M[4]; - for(int i = 0; i < cf->num_gate; ++i) { - if(cf->gates[4*i+3] == AND_GATE) { - M[0] = sigma_mac[ands] ^ mac[cf->gates[4*i+2]]; - M[1] = M[0] ^ mac[cf->gates[4*i]]; - M[2] = M[0] ^ mac[cf->gates[4*i+1]]; - M[3] = M[1] ^ mac[cf->gates[4*i+1]]; - if(party == BOB) - M[3] = M[3] ^ fpre->one; + delete[] fpre->MAC; + delete[] fpre->KEY; + fpre->MAC = nullptr; + fpre->KEY = nullptr; + GT = new block[num_ands][4][2]; + GTK = new block[num_ands][4]; + GTM = new block[num_ands][4]; - K[0] = sigma_key[ands] ^ key[cf->gates[4*i+2]]; - K[1] = K[0] ^ key[cf->gates[4*i]]; - K[2] = K[0] ^ key[cf->gates[4*i+1]]; - K[3] = K[1] ^ key[cf->gates[4*i+1]]; - if(party == ALICE) - K[3] = K[3] ^ fpre->ZDelta; + ands = 0; + block H[4][2]; + block K[4], M[4]; + for(int i = 0; i < cf->num_gate; ++i) { + if(cf->gates[4*i+3] == AND_GATE) { + M[0] = sigma_mac[ands] ^ mac[cf->gates[4*i+2]]; + M[1] = M[0] ^ mac[cf->gates[4*i]]; + M[2] = M[0] ^ mac[cf->gates[4*i+1]]; + M[3] = M[1] ^ mac[cf->gates[4*i+1]]; + if(party == BOB) + M[3] = M[3] ^ fpre->one; - if(party == ALICE) { - Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], i); - for(int j = 0; j < 4; ++j) { - H[j][0] = H[j][0] ^ M[j]; - H[j][1] = H[j][1] ^ K[j] ^ labels[cf->gates[4*i+2]]; - if(getLSB(M[j])) - H[j][1] = H[j][1] ^fpre->Delta; + K[0] = sigma_key[ands] ^ key[cf->gates[4*i+2]]; + K[1] = K[0] ^ key[cf->gates[4*i]]; + K[2] = K[0] ^ key[cf->gates[4*i+1]]; + K[3] = K[1] ^ key[cf->gates[4*i+1]]; + if(party == ALICE) + K[3] = K[3] ^ fpre->ZDelta; + + if(party == ALICE) { + Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], i); + for(int j = 0; j < 4; ++j) { + H[j][0] = H[j][0] ^ M[j]; + H[j][1] = H[j][1] ^ K[j] ^ labels[cf->gates[4*i+2]]; + if(getLSB(M[j])) + H[j][1] = H[j][1] ^fpre->Delta; #ifdef __debug - check2(M[j], K[j]); + check2(M[j], K[j]); #endif - } - for(int j = 0; j < 4; ++j ) { - send_partial_block(io, &H[j][0], 1); - io.send_block(&H[j][1], 1); - } - } else { - memcpy(GTK[ands], K, sizeof(block)*4); - memcpy(GTM[ands], M, sizeof(block)*4); + } + for(int j = 0; j < 4; ++j ) { + send_partial_block(io, &H[j][0], 1); + io.send_block(&H[j][1], 1); + } + } else { + memcpy(GTK[ands], K, sizeof(block)*4); + memcpy(GTM[ands], M, sizeof(block)*4); #ifdef __debug - for(int j = 0; j < 4; ++j) - check2(M[j], K[j]); + for(int j = 0; j < 4; ++j) + check2(M[j], K[j]); #endif - for(int j = 0; j < 4; ++j ) { - recv_partial_block(io, >[ands][j][0], 1); - io.recv_block(>[ands][j][1], 1); - } - } - ++ands; - } - } - delete[] x1; - delete[] x2; - delete[] y1; - delete[] y2; + for(int j = 0; j < 4; ++j ) { + recv_partial_block(io, >[ands][j][0], 1); + io.recv_block(>[ands][j][1], 1); + } + } + ++ands; + } + } + delete[] x1; + delete[] x2; + delete[] y1; + delete[] y2; - block tmp; - if(party == ALICE) { - send_partial_block(io, mac, cf->n1); - for(int i = cf->n1; i < cf->n1+cf->n2; ++i) { - recv_partial_block(io, &tmp, 1); - block ttt = key[i] ^ fpre->Delta; - ttt = ttt & MASK; - block mask_key = key[i] & MASK; - tmp = tmp & MASK; - if(cmpBlock(&tmp, &mask_key, 1)) - mask[i] = false; - else if(cmpBlock(&tmp, &ttt, 1)) - mask[i] = true; - else cout <<"no match! ALICE\t"<n1; ++i) { - recv_partial_block(io, &tmp, 1); - block ttt = key[i] ^ fpre->Delta; - ttt = ttt & MASK; - tmp = tmp & MASK; - block mask_key = key[i] & MASK; - if(cmpBlock(&tmp, &mask_key, 1)) { - mask[i] = false; - } else if(cmpBlock(&tmp, &ttt, 1)) { - mask[i] = true; - } - else cout <<"no match! BOB\t"<(io, mac, cf->n1); + for(int i = cf->n1; i < cf->n1+cf->n2; ++i) { + recv_partial_block(io, &tmp, 1); + block ttt = key[i] ^ fpre->Delta; + ttt = ttt & MASK; + block mask_key = key[i] & MASK; + tmp = tmp & MASK; + if(cmpBlock(&tmp, &mask_key, 1)) + mask[i] = false; + else if(cmpBlock(&tmp, &ttt, 1)) + mask[i] = true; + else cout <<"no match! ALICE\t"<n1; ++i) { + recv_partial_block(io, &tmp, 1); + block ttt = key[i] ^ fpre->Delta; + ttt = ttt & MASK; + tmp = tmp & MASK; + block mask_key = key[i] & MASK; + if(cmpBlock(&tmp, &mask_key, 1)) { + mask[i] = false; + } else if(cmpBlock(&tmp, &ttt, 1)) { + mask[i] = true; + } + else cout <<"no match! BOB\t"<(io, mac+cf->n1, cf->n2); - } - io.flush(); - } + send_partial_block(io, mac+cf->n1, cf->n2); + } + io.flush(); + } - std::vector online( - const std::vector& input, - bool alice_output = false - ) { - std::vector output(cf->n3); + std::vector online( + const std::vector& input, + bool alice_output = false + ) { + std::vector output(cf->n3); - int correct_input_size = party == ALICE ? cf->n2 : cf->n1; + int correct_input_size = party == ALICE ? cf->n2 : cf->n1; - if (input.size() != correct_input_size) { - throw std::invalid_argument("input size does not match circuit"); - } + if (input.size() != correct_input_size) { + throw std::invalid_argument("input size does not match circuit"); + } - uint8_t * mask_input = new uint8_t[cf->num_wire]; - memset(mask_input, 0, cf->num_wire); - block tmp; + uint8_t * mask_input = new uint8_t[cf->num_wire]; + memset(mask_input, 0, cf->num_wire); + block tmp; #ifdef __debug - for(int i = 0; i < cf->n1+cf->n2; ++i) - check2(mac[i], key[i]); + for(int i = 0; i < cf->n1+cf->n2; ++i) + check2(mac[i], key[i]); #endif - if(party == ALICE) { - for(int i = cf->n1; i < cf->n1+cf->n2; ++i) { - mask_input[i] = logic_xor(input[i - cf->n1], getLSB(mac[i])); - mask_input[i] = logic_xor(mask_input[i], mask[i]); - } - io.recv_data(mask_input, cf->n1); - io.send_data(mask_input+cf->n1, cf->n2); - for(int i = 0; i < cf->n1 + cf->n2; ++i) { - tmp = labels[i]; - if(mask_input[i]) tmp = tmp ^ fpre->Delta; - io.send_block(&tmp, 1); - } - //send output mask data - send_partial_block(io, mac+cf->num_wire - cf->n3, cf->n3); - } else { - for(int i = 0; i < cf->n1; ++i) { - mask_input[i] = logic_xor(input[i], getLSB(mac[i])); - mask_input[i] = logic_xor(mask_input[i], mask[i]); - } - io.send_data(mask_input, cf->n1); - io.recv_data(mask_input+cf->n1, cf->n2); - io.recv_block(labels, cf->n1 + cf->n2); - } - int ands = 0; - if(party == BOB) { - for(int i = 0; i < cf->num_gate; ++i) { - if (cf->gates[4*i+3] == XOR_GATE) { - labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]]; - mask_input[cf->gates[4*i+2]] = logic_xor(mask_input[cf->gates[4*i]], mask_input[cf->gates[4*i+1]]); - } else if (cf->gates[4*i+3] == AND_GATE) { - int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]]; - block H[2]; - Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], i, index); - GT[ands][index][0] = GT[ands][index][0] ^ H[0]; - GT[ands][index][1] = GT[ands][index][1] ^ H[1]; + if(party == ALICE) { + for(int i = cf->n1; i < cf->n1+cf->n2; ++i) { + mask_input[i] = logic_xor(input[i - cf->n1], getLSB(mac[i])); + mask_input[i] = logic_xor(mask_input[i], mask[i]); + } + io.recv_data(mask_input, cf->n1); + io.send_data(mask_input+cf->n1, cf->n2); + for(int i = 0; i < cf->n1 + cf->n2; ++i) { + tmp = labels[i]; + if(mask_input[i]) tmp = tmp ^ fpre->Delta; + io.send_block(&tmp, 1); + } + //send output mask data + send_partial_block(io, mac+cf->num_wire - cf->n3, cf->n3); + } else { + for(int i = 0; i < cf->n1; ++i) { + mask_input[i] = logic_xor(input[i], getLSB(mac[i])); + mask_input[i] = logic_xor(mask_input[i], mask[i]); + } + io.send_data(mask_input, cf->n1); + io.recv_data(mask_input+cf->n1, cf->n2); + io.recv_block(labels, cf->n1 + cf->n2); + } + int ands = 0; + if(party == BOB) { + for(int i = 0; i < cf->num_gate; ++i) { + if (cf->gates[4*i+3] == XOR_GATE) { + labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]] ^ labels[cf->gates[4*i+1]]; + mask_input[cf->gates[4*i+2]] = logic_xor(mask_input[cf->gates[4*i]], mask_input[cf->gates[4*i+1]]); + } else if (cf->gates[4*i+3] == AND_GATE) { + int index = 2*mask_input[cf->gates[4*i]] + mask_input[cf->gates[4*i+1]]; + block H[2]; + Hash(H, labels[cf->gates[4*i]], labels[cf->gates[4*i+1]], i, index); + GT[ands][index][0] = GT[ands][index][0] ^ H[0]; + GT[ands][index][1] = GT[ands][index][1] ^ H[1]; - block ttt = GTK[ands][index] ^ fpre->Delta; - ttt = ttt & MASK; - GTK[ands][index] = GTK[ands][index] & MASK; - GT[ands][index][0] = GT[ands][index][0] & MASK; + block ttt = GTK[ands][index] ^ fpre->Delta; + ttt = ttt & MASK; + GTK[ands][index] = GTK[ands][index] & MASK; + GT[ands][index][0] = GT[ands][index][0] & MASK; - if(cmpBlock(>[ands][index][0], >K[ands][index], 1)) - mask_input[cf->gates[4*i+2]] = false; - else if(cmpBlock(>[ands][index][0], &ttt, 1)) - mask_input[cf->gates[4*i+2]] = true; - else cout <gates[4*i+2]] = logic_xor(mask_input[cf->gates[4*i+2]], getLSB(GTM[ands][index])); + if(cmpBlock(>[ands][index][0], >K[ands][index], 1)) + mask_input[cf->gates[4*i+2]] = false; + else if(cmpBlock(>[ands][index][0], &ttt, 1)) + mask_input[cf->gates[4*i+2]] = true; + else cout <gates[4*i+2]] = logic_xor(mask_input[cf->gates[4*i+2]], getLSB(GTM[ands][index])); - labels[cf->gates[4*i+2]] = GT[ands][index][1] ^ GTM[ands][index]; - ands++; - } else { - mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]]; - labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]]; - } - } - } - if (party == BOB) { - bool * o = new bool[cf->n3]; - for(int i = 0; i < cf->n3; ++i) { - block tmp; - recv_partial_block(io, &tmp, 1); - tmp = tmp & MASK; + labels[cf->gates[4*i+2]] = GT[ands][index][1] ^ GTM[ands][index]; + ands++; + } else { + mask_input[cf->gates[4*i+2]] = not mask_input[cf->gates[4*i]]; + labels[cf->gates[4*i+2]] = labels[cf->gates[4*i]]; + } + } + } + if (party == BOB) { + bool * o = new bool[cf->n3]; + for(int i = 0; i < cf->n3; ++i) { + block tmp; + recv_partial_block(io, &tmp, 1); + tmp = tmp & MASK; - block ttt = key[cf->num_wire - cf-> n3 + i] ^ fpre->Delta; - ttt = ttt & MASK; - key[cf->num_wire - cf-> n3 + i] = key[cf->num_wire - cf-> n3 + i] & MASK; + block ttt = key[cf->num_wire - cf-> n3 + i] ^ fpre->Delta; + ttt = ttt & MASK; + key[cf->num_wire - cf-> n3 + i] = key[cf->num_wire - cf-> n3 + i] & MASK; - if(cmpBlock(&tmp, &key[cf->num_wire - cf-> n3 + i], 1)) - o[i] = false; - else if(cmpBlock(&tmp, &ttt, 1)) - o[i] = true; - else cout <<"no match output label!"<n3; ++i) { - output[i] = logic_xor(o[i], mask_input[cf->num_wire - cf->n3 + i]); - output[i] = logic_xor(output[i], getLSB(mac[cf->num_wire - cf->n3 + i])); - } - delete[] o; - if(alice_output) { - send_partial_block(io, mac+cf->num_wire - cf->n3, cf->n3); - send_partial_block(io, labels+cf->num_wire - cf->n3, cf->n3); - io.send_data(mask_input + cf->num_wire - cf->n3, cf->n3); - io.flush(); - } - } else {//ALICE - if(alice_output) { - block * tmp_mac = new block[cf->n3]; - block * tmp_label = new block[cf->n3]; - bool * tmp_mask_input = new bool[cf->n3]; - recv_partial_block(io, tmp_mac, cf->n3); - recv_partial_block(io, tmp_label, cf->n3); - io.recv_data(tmp_mask_input, cf->n3); - io.flush(); - for(int i = 0; i < cf->n3; ++i) { - block tmp = tmp_mac[i]; - tmp = tmp & MASK; + if(cmpBlock(&tmp, &key[cf->num_wire - cf-> n3 + i], 1)) + o[i] = false; + else if(cmpBlock(&tmp, &ttt, 1)) + o[i] = true; + else cout <<"no match output label!"<n3; ++i) { + output[i] = logic_xor(o[i], mask_input[cf->num_wire - cf->n3 + i]); + output[i] = logic_xor(output[i], getLSB(mac[cf->num_wire - cf->n3 + i])); + } + delete[] o; + if(alice_output) { + send_partial_block(io, mac+cf->num_wire - cf->n3, cf->n3); + send_partial_block(io, labels+cf->num_wire - cf->n3, cf->n3); + io.send_data(mask_input + cf->num_wire - cf->n3, cf->n3); + io.flush(); + } + } else {//ALICE + if(alice_output) { + block * tmp_mac = new block[cf->n3]; + block * tmp_label = new block[cf->n3]; + bool * tmp_mask_input = new bool[cf->n3]; + recv_partial_block(io, tmp_mac, cf->n3); + recv_partial_block(io, tmp_label, cf->n3); + io.recv_data(tmp_mask_input, cf->n3); + io.flush(); + for(int i = 0; i < cf->n3; ++i) { + block tmp = tmp_mac[i]; + tmp = tmp & MASK; - block ttt = key[cf->num_wire - cf-> n3 + i] ^ fpre->Delta; - ttt = ttt & MASK; - key[cf->num_wire - cf-> n3 + i] = key[cf->num_wire - cf-> n3 + i] & MASK; + block ttt = key[cf->num_wire - cf-> n3 + i] ^ fpre->Delta; + ttt = ttt & MASK; + key[cf->num_wire - cf-> n3 + i] = key[cf->num_wire - cf-> n3 + i] & MASK; - if(cmpBlock(&tmp, &key[cf->num_wire - cf-> n3 + i], 1)) - output[i] = false; - else if(cmpBlock(&tmp, &ttt, 1)) - output[i] = true; - else cout <<"no match output label!"<Delta; - mask_label = mask_label & MASK; - block masked_labels = labels[cf->num_wire - cf-> n3 + i] & MASK; - if(!cmpBlock(&mask_label, &masked_labels, 1)) - cout <<"no match output label2!"<num_wire - cf-> n3 + i], 1)) + output[i] = false; + else if(cmpBlock(&tmp, &ttt, 1)) + output[i] = true; + else cout <<"no match output label!"<Delta; + mask_label = mask_label & MASK; + block masked_labels = labels[cf->num_wire - cf-> n3 + i] & MASK; + if(!cmpBlock(&mask_label, &masked_labels, 1)) + cout <<"no match output label2!"<num_wire - cf->n3 + i])); - } - delete[] tmp_mac; - delete[] tmp_label; - delete[] tmp_mask_input; - } + output[i] = logic_xor(output[i], tmp_mask_input[i]); + output[i] = logic_xor(output[i], getLSB(mac[cf->num_wire - cf->n3 + i])); + } + delete[] tmp_mac; + delete[] tmp_label; + delete[] tmp_mask_input; + } - } - delete[] mask_input; + } + delete[] mask_input; - return output; - } + return output; + } - void check(block * MAC, block * KEY, bool * r, int length = 1) { - if (party == ALICE) { - io.send_data(r, length*3); - io.send_block(&fpre->Delta, 1); - io.send_block(KEY, length*3); - block DD;io.recv_block(&DD, 1); + void check(block * MAC, block * KEY, bool * r, int length = 1) { + if (party == ALICE) { + io.send_data(r, length*3); + io.send_block(&fpre->Delta, 1); + io.send_block(KEY, length*3); + block DD;io.recv_block(&DD, 1); - for(int i = 0; i < length*3; ++i) { - block tmp;io.recv_block(&tmp, 1); - if(r[i]) tmp = tmp ^ DD; - if (!cmpBlock(&tmp, &MAC[i], 1)) - cout <Delta, 1); - io.send_block(KEY, length*3); - } - io.flush(); - } + for(int i = 0; i < length*3; ++i) { + block tmp;io.recv_block(&tmp, 1); + if(r[i]) tmp = tmp ^ DD; + if (!cmpBlock(&tmp, &MAC[i], 1)) + cout <Delta, 1); + io.send_block(KEY, length*3); + } + io.flush(); + } - void check2(block & MAC, block & KEY) { - if (party == ALICE) { - io.send_block(&fpre->Delta, 1); - io.send_block(&KEY, 1); - block DD;io.recv_block(&DD, 1); - for(int i = 0; i < 1; ++i) { - block tmp;io.recv_block(&tmp, 1); - if(getLSB(MAC)) tmp = tmp ^ DD; - if (!cmpBlock(&tmp, &MAC, 1)) - cout <Delta, 1); - io.send_block(&KEY, 1); - } - io.flush(); - } + void check2(block & MAC, block & KEY) { + if (party == ALICE) { + io.send_block(&fpre->Delta, 1); + io.send_block(&KEY, 1); + block DD;io.recv_block(&DD, 1); + for(int i = 0; i < 1; ++i) { + block tmp;io.recv_block(&tmp, 1); + if(getLSB(MAC)) tmp = tmp ^ DD; + if (!cmpBlock(&tmp, &MAC, 1)) + cout <Delta, 1); + io.send_block(&KEY, 1); + } + io.flush(); + } - void Hash(block H[4][2], const block & a, const block & b, uint64_t i) { - block A[2], B[2]; - A[0] = a; A[1] = a ^ fpre->Delta; - B[0] = b; B[1] = b ^ fpre->Delta; - A[0] = sigma(A[0]); - A[1] = sigma(A[1]); - B[0] = sigma(sigma(B[0])); - B[1] = sigma(sigma(B[1])); + void Hash(block H[4][2], const block & a, const block & b, uint64_t i) { + block A[2], B[2]; + A[0] = a; A[1] = a ^ fpre->Delta; + B[0] = b; B[1] = b ^ fpre->Delta; + A[0] = sigma(A[0]); + A[1] = sigma(A[1]); + B[0] = sigma(sigma(B[0])); + B[1] = sigma(sigma(B[1])); - H[0][1] = H[0][0] = A[0] ^ B[0]; - H[1][1] = H[1][0] = A[0] ^ B[1]; - H[2][1] = H[2][0] = A[1] ^ B[0]; - H[3][1] = H[3][0] = A[1] ^ B[1]; - for(uint64_t j = 0; j < 4; ++j) { - H[j][0] = H[j][0] ^ makeBlock(4*i+j, 0); - H[j][1] = H[j][1] ^ makeBlock(4*i+j, 1); - } - prp.permute_block((block *)H, 8); - } + H[0][1] = H[0][0] = A[0] ^ B[0]; + H[1][1] = H[1][0] = A[0] ^ B[1]; + H[2][1] = H[2][0] = A[1] ^ B[0]; + H[3][1] = H[3][0] = A[1] ^ B[1]; + for(uint64_t j = 0; j < 4; ++j) { + H[j][0] = H[j][0] ^ makeBlock(4*i+j, 0); + H[j][1] = H[j][1] ^ makeBlock(4*i+j, 1); + } + prp.permute_block((block *)H, 8); + } - void Hash(block H[2], block a, block b, uint64_t i, uint64_t row) { - a = sigma(a); - b = sigma(sigma(b)); - H[0] = H[1] = a ^ b; - H[0] = H[0] ^ makeBlock(4*i+row, 0); - H[1] = H[1] ^ makeBlock(4*i+row, 1); - prp.permute_block((block *)H, 2); - } + void Hash(block H[2], block a, block b, uint64_t i, uint64_t row) { + a = sigma(a); + b = sigma(sigma(b)); + H[0] = H[1] = a ^ b; + H[0] = H[0] ^ makeBlock(4*i+row, 0); + H[1] = H[1] ^ makeBlock(4*i+row, 1); + prp.permute_block((block *)H, 2); + } - bool logic_xor(bool a, bool b) { - return a!= b; - } + bool logic_xor(bool a, bool b) { + return a!= b; + } }; } diff --git a/src/emp-ag2pc/feq.h b/src/emp-ag2pc/feq.h index 1babadd..ccba58a 100644 --- a/src/emp-ag2pc/feq.h +++ b/src/emp-ag2pc/feq.h @@ -5,43 +5,43 @@ namespace emp { class Feq { public: - Hash h; - IOChannel io; - int party; - Feq(IOChannel io, int party): io(io) { - this->party = party; - } - void add_block(const block & in) { - h.put(&in, sizeof(block)); - } + Hash h; + IOChannel io; + int party; + Feq(IOChannel io, int party): io(io) { + this->party = party; + } + void add_block(const block & in) { + h.put(&in, sizeof(block)); + } - void add_data(const void * data, int len) { - h.put(data, len); - } + void add_data(const void * data, int len) { + h.put(data, len); + } - void dgst(char * dgst) { - h.digest(dgst); - } - bool compare() { - char AR[Hash::DIGEST_SIZE+16]; - char dgst[Hash::DIGEST_SIZE]; - h.digest(AR); - if(party == ALICE) { - PRG prg; - prg.random_data(AR+Hash::DIGEST_SIZE, 16); - Hash::hash_once(dgst, AR, Hash::DIGEST_SIZE+16); - io.send_data(dgst, Hash::DIGEST_SIZE); - io.recv_data(dgst, Hash::DIGEST_SIZE); - io.send_data(AR+Hash::DIGEST_SIZE, 16); - } else { - io.recv_data(dgst, Hash::DIGEST_SIZE); - io.send_data(AR, Hash::DIGEST_SIZE); - io.recv_data(AR+Hash::DIGEST_SIZE, 16); - Hash::hash_once(AR, AR, Hash::DIGEST_SIZE+16); - } - io.flush(); - return memcmp(dgst, AR, Hash::DIGEST_SIZE)==0; - } + void dgst(char * dgst) { + h.digest(dgst); + } + bool compare() { + char AR[Hash::DIGEST_SIZE+16]; + char dgst[Hash::DIGEST_SIZE]; + h.digest(AR); + if(party == ALICE) { + PRG prg; + prg.random_data(AR+Hash::DIGEST_SIZE, 16); + Hash::hash_once(dgst, AR, Hash::DIGEST_SIZE+16); + io.send_data(dgst, Hash::DIGEST_SIZE); + io.recv_data(dgst, Hash::DIGEST_SIZE); + io.send_data(AR+Hash::DIGEST_SIZE, 16); + } else { + io.recv_data(dgst, Hash::DIGEST_SIZE); + io.send_data(AR, Hash::DIGEST_SIZE); + io.recv_data(AR+Hash::DIGEST_SIZE, 16); + Hash::hash_once(AR, AR, Hash::DIGEST_SIZE+16); + } + io.flush(); + return memcmp(dgst, AR, Hash::DIGEST_SIZE)==0; + } }; } diff --git a/src/emp-ag2pc/fpre.h b/src/emp-ag2pc/fpre.h index e516e48..41b79a7 100644 --- a/src/emp-ag2pc/fpre.h +++ b/src/emp-ag2pc/fpre.h @@ -11,350 +11,350 @@ namespace emp { //#define __debug class Fpre { - public: - IOChannel io; - int batch_size = 0, bucket_size = 0, size = 0; - int party; - block * keys = nullptr; - bool * values = nullptr; - PRG prg; - PRP prp; - PRP *prps; - LeakyDeltaOT *abit1, *abit2; - block Delta; - block ZDelta; - block one; - Feq *eq[2]; - block * MAC = nullptr, *KEY = nullptr; - block * MAC_res = nullptr, *KEY_res = nullptr; - block * pretable = nullptr; - Fpre(IOChannel io, int in_party, int bsize = 1000): io(io) { - prps = new PRP[2]; - this->party = in_party; + public: + IOChannel io; + int batch_size = 0, bucket_size = 0, size = 0; + int party; + block * keys = nullptr; + bool * values = nullptr; + PRG prg; + PRP prp; + PRP *prps; + LeakyDeltaOT *abit1, *abit2; + block Delta; + block ZDelta; + block one; + Feq *eq[2]; + block * MAC = nullptr, *KEY = nullptr; + block * MAC_res = nullptr, *KEY_res = nullptr; + block * pretable = nullptr; + Fpre(IOChannel io, int in_party, int bsize = 1000): io(io) { + prps = new PRP[2]; + this->party = in_party; - eq[0] = new Feq(io, party); - eq[1] = new Feq(io, party); + eq[0] = new Feq(io, party); + eq[1] = new Feq(io, party); - abit1 = new LeakyDeltaOT(io); - abit2 = new LeakyDeltaOT(io); + abit1 = new LeakyDeltaOT(io); + abit2 = new LeakyDeltaOT(io); - bool tmp_s[128]; - prg.random_bool(tmp_s, 128); - tmp_s[0] = true; - if(party == ALICE) { - tmp_s[1] = true; - abit1->setup_send(tmp_s); - io.flush(); - abit2->setup_recv(); - } else { - tmp_s[1] = false; - abit1->setup_recv(); - io.flush(); - abit2->setup_send(tmp_s); - } - io.flush(); + bool tmp_s[128]; + prg.random_bool(tmp_s, 128); + tmp_s[0] = true; + if(party == ALICE) { + tmp_s[1] = true; + abit1->setup_send(tmp_s); + io.flush(); + abit2->setup_recv(); + } else { + tmp_s[1] = false; + abit1->setup_recv(); + io.flush(); + abit2->setup_send(tmp_s); + } + io.flush(); - abit1 = new LeakyDeltaOT(io); - abit2 = new LeakyDeltaOT(io); - if(party == ALICE) { - abit1->setup_send(tmp_s, abit1->k0); - abit2->setup_recv(abit2->k0, abit2->k1); - } else { - abit2->setup_send(tmp_s, abit2->k0); - abit1->setup_recv(abit1->k0, abit1->k1); - } + abit1 = new LeakyDeltaOT(io); + abit2 = new LeakyDeltaOT(io); + if(party == ALICE) { + abit1->setup_send(tmp_s, abit1->k0); + abit2->setup_recv(abit2->k0, abit2->k1); + } else { + abit2->setup_send(tmp_s, abit2->k0); + abit1->setup_recv(abit1->k0, abit1->k1); + } - if(party == ALICE) Delta = abit1->Delta; - else Delta = abit2->Delta; - one = makeBlock(0, 1); - ZDelta = Delta & makeBlock(0xFFFFFFFFFFFFFFFF,0xFFFFFFFFFFFFFFFE); - set_batch_size(bsize); - } - int permute_batch_size; - void set_batch_size(int size) { - size = std::max(size, 320); - batch_size = ((size+1)/2)*2; - if(batch_size >= 280*1000) { - bucket_size = 3; - permute_batch_size = 280000; - } - else if(batch_size >= 3100) { - bucket_size = 4; - permute_batch_size = 3100; - } - else bucket_size = 5; - - delete[] MAC; - delete[] KEY; - - MAC = new block[batch_size * bucket_size * 3]; - KEY = new block[batch_size * bucket_size * 3]; - MAC_res = new block[batch_size * 3]; - KEY_res = new block[batch_size * 3]; -// cout << size<<"\t"<Delta; + else Delta = abit2->Delta; + one = makeBlock(0, 1); + ZDelta = Delta & makeBlock(0xFFFFFFFFFFFFFFFF,0xFFFFFFFFFFFFFFFE); + set_batch_size(bsize); + } + int permute_batch_size; + void set_batch_size(int size) { + size = std::max(size, 320); + batch_size = ((size+1)/2)*2; + if(batch_size >= 280*1000) { + bucket_size = 3; + permute_batch_size = 280000; + } + else if(batch_size >= 3100) { + bucket_size = 4; + permute_batch_size = 3100; + } + else bucket_size = 5; - delete abit1; - delete abit2; - delete eq[0]; - delete eq[1]; - } - void refill() { - auto start_time = clock_start(); + delete[] MAC; + delete[] KEY; - int start = 0; - int length = batch_size; + MAC = new block[batch_size * bucket_size * 3]; + KEY = new block[batch_size * bucket_size * 3]; + MAC_res = new block[batch_size * 3]; + KEY_res = new block[batch_size * 3]; +// cout << size<<"\t"< 4) { - combine(S, 0, MAC, KEY, batch_size, bucket_size, MAC_res, KEY_res); - } else { - int width = min((batch_size), permute_batch_size); - - int start = 0; - int length = min(width, batch_size); - combine(S, 0, MAC+start*bucket_size*3, KEY+start*bucket_size*3, length, bucket_size, MAC_res+start*3, KEY_res+start*3); - } - if(party == ALICE) { - cout <<"permute\t"< 4) { + combine(S, 0, MAC, KEY, batch_size, bucket_size, MAC_res, KEY_res); + } else { + int width = min((batch_size), permute_batch_size); + + int start = 0; + int length = min(width, batch_size); + combine(S, 0, MAC+start*bucket_size*3, KEY+start*bucket_size*3, length, bucket_size, MAC_res+start*3, KEY_res+start*3); + } + if(party == ALICE) { + cout <<"permute\t"<dgst(dgst); - eq[0]->add_data(dgst, Hash::DIGEST_SIZE); - } - if(!eq[0]->compare()) { - error("FEQ error\n"); - } - } + char dgst[Hash::DIGEST_SIZE]; + for(int i = 1; i < 2; ++i) { + eq[i]->dgst(dgst); + eq[0]->add_data(dgst, Hash::DIGEST_SIZE); + } + if(!eq[0]->compare()) { + error("FEQ error\n"); + } + } - void generate(block * MAC, block * KEY, int length) { - if (party == ALICE) { - abit1->send_dot(KEY, length*3); - abit2->recv_dot(MAC, length*3); - } else { - // TODO: I (Andrew) swapped the two lines below to remove a - // deadlock after I removed threading. I suspect that's fine - // but I need to check. - abit1->recv_dot(MAC, length*3); - abit2->send_dot(KEY, length*3); - } - } - - void check(block * MAC, block * KEY, int length, int I) { - block * G = new block[length]; - block * C = new block[length]; - block * GR = new block[length]; - bool * d = new bool[length]; - bool * dR = new bool[length]; - - for (int i = 0; i < length; ++i) { - C[i] = KEY[3*i+1] ^ MAC[3*i+1]; - C[i] = C[i] ^ (select_mask[getLSB(MAC[3*i+1])] & Delta); - G[i] = H2D(KEY[3*i], Delta, I); - G[i] = G[i] ^ C[i]; - } - if(party == ALICE) { - io.send_data(G, sizeof(block)*length); - io.recv_data(GR, sizeof(block)*length); - } else { - io.recv_data(GR, sizeof(block)*length); - io.send_data(G, sizeof(block)*length); - } - io.flush(); - for(int i = 0; i < length; ++i) { - block S = H2(MAC[3*i], KEY[3*i], I); - S = S ^ MAC[3*i+2] ^ KEY[3*i+2]; - S = S ^ (select_mask[getLSB(MAC[3*i])] & (GR[i] ^ C[i])); - G[i] = S ^ (select_mask[getLSB(MAC[3*i+2])] & Delta); - d[i] = getL2SB(G[i]); - } + void generate(block * MAC, block * KEY, int length) { + if (party == ALICE) { + abit1->send_dot(KEY, length*3); + abit2->recv_dot(MAC, length*3); + } else { + // TODO: I (Andrew) swapped the two lines below to remove a + // deadlock after I removed threading. I suspect that's fine + // but I need to check. + abit1->recv_dot(MAC, length*3); + abit2->send_dot(KEY, length*3); + } + } - if(party == ALICE) { - io.send_bool(d, length); - io.recv_bool(dR,length); - } else { - io.recv_bool(dR, length); - io.send_bool(d, length); - } - io.flush(); - for(int i = 0; i < length; ++i) { - d[i] = d[i] != dR[i]; - if (d[i]) { - if(party == ALICE) - MAC[3*i+2] = MAC[3*i+2] ^ one; - else - KEY[3*i+2] = KEY[3*i+2] ^ ZDelta; - - G[i] = G[i] ^ Delta; - } - eq[I]->add_block(G[i]); - } - delete[] G; - delete[] GR; - delete[] C; - delete[] d; - delete[] dR; - } - block H2D(block a, block b, int I) { - block d[2]; - d[0] = a; - d[1] = a ^ b; - prps[I].permute_block(d, 2); - d[0] = d[0] ^ d[1]; - return d[0] ^ b; - } + void check(block * MAC, block * KEY, int length, int I) { + block * G = new block[length]; + block * C = new block[length]; + block * GR = new block[length]; + bool * d = new bool[length]; + bool * dR = new bool[length]; - block H2(block a, block b, int I) { - block d[2]; - d[0] = a; - d[1] = b; - prps[I].permute_block(d, 2); - d[0] = d[0] ^ d[1]; - d[0] = d[0] ^ a; - return d[0] ^ b; - } + for (int i = 0; i < length; ++i) { + C[i] = KEY[3*i+1] ^ MAC[3*i+1]; + C[i] = C[i] ^ (select_mask[getLSB(MAC[3*i+1])] & Delta); + G[i] = H2D(KEY[3*i], Delta, I); + G[i] = G[i] ^ C[i]; + } + if(party == ALICE) { + io.send_data(G, sizeof(block)*length); + io.recv_data(GR, sizeof(block)*length); + } else { + io.recv_data(GR, sizeof(block)*length); + io.send_data(G, sizeof(block)*length); + } + io.flush(); + for(int i = 0; i < length; ++i) { + block S = H2(MAC[3*i], KEY[3*i], I); + S = S ^ MAC[3*i+2] ^ KEY[3*i+2]; + S = S ^ (select_mask[getLSB(MAC[3*i])] & (GR[i] ^ C[i])); + G[i] = S ^ (select_mask[getLSB(MAC[3*i+2])] & Delta); + d[i] = getL2SB(G[i]); + } - bool getL2SB(block b) { - unsigned char x = *((unsigned char*)&b); - return ((x >> 1) & 0x1) == 1; - } + if(party == ALICE) { + io.send_bool(d, length); + io.recv_bool(dR,length); + } else { + io.recv_bool(dR, length); + io.send_bool(d, length); + } + io.flush(); + for(int i = 0; i < length; ++i) { + d[i] = d[i] != dR[i]; + if (d[i]) { + if(party == ALICE) + MAC[3*i+2] = MAC[3*i+2] ^ one; + else + KEY[3*i+2] = KEY[3*i+2] ^ ZDelta; - void combine(block S, int I, block * MAC, block * KEY, int length, int bucket_size, block * MAC_res, block * KEY_res) { - int *location = new int[length*bucket_size]; - for(int i = 0; i < length*bucket_size; ++i) location[i] = i; - PRG prg(&S, I); - int * ind = new int[length*bucket_size]; - prg.random_data(ind, length*bucket_size*4); - for(int i = length*bucket_size-1; i>=0; --i) { - int index = ind[i]%(i+1); - index = index>0? index:(-1*index); - int tmp = location[i]; - location[i] = location[index]; - location[index] = tmp; - } - delete[] ind; + G[i] = G[i] ^ Delta; + } + eq[I]->add_block(G[i]); + } + delete[] G; + delete[] GR; + delete[] C; + delete[] d; + delete[] dR; + } + block H2D(block a, block b, int I) { + block d[2]; + d[0] = a; + d[1] = a ^ b; + prps[I].permute_block(d, 2); + d[0] = d[0] ^ d[1]; + return d[0] ^ b; + } - bool *data = new bool[length*bucket_size]; - bool *data2 = new bool[length*bucket_size]; - for(int i = 0; i < length; ++i) { - for(int j = 1; j < bucket_size; ++j) { - data[i*bucket_size+j] = getLSB(MAC[location[i*bucket_size]*3+1] ^ MAC[location[i*bucket_size+j]*3+1]); - } - } - if(party == ALICE) { - io.send_bool(data, length*bucket_size); - io.recv_bool(data2, length*bucket_size); - } else { - io.recv_bool(data2, length*bucket_size); - io.send_bool(data, length*bucket_size); - } - io.flush(); - for(int i = 0; i < length; ++i) { - for(int j = 1; j < bucket_size; ++j) { - data[i*bucket_size+j] = (data[i*bucket_size+j] != data2[i*bucket_size+j]); - } - } - for(int i = 0; i < length; ++i) { - for(int j = 0; j < 3; ++j) { - MAC_res[i*3+j] = MAC[location[i*bucket_size]*3+j]; - KEY_res[i*3+j] = KEY[location[i*bucket_size]*3+j]; - } - for(int j = 1; j < bucket_size; ++j) { - MAC_res[3*i] = MAC_res[3*i] ^ MAC[location[i*bucket_size+j]*3]; - KEY_res[3*i] = KEY_res[3*i] ^ KEY[location[i*bucket_size+j]*3]; + block H2(block a, block b, int I) { + block d[2]; + d[0] = a; + d[1] = b; + prps[I].permute_block(d, 2); + d[0] = d[0] ^ d[1]; + d[0] = d[0] ^ a; + return d[0] ^ b; + } - MAC_res[i*3+2] = MAC_res[i*3+2] ^ MAC[location[i*bucket_size+j]*3+2]; - KEY_res[i*3+2] = KEY_res[i*3+2] ^ KEY[location[i*bucket_size+j]*3+2]; + bool getL2SB(block b) { + unsigned char x = *((unsigned char*)&b); + return ((x >> 1) & 0x1) == 1; + } - if(data[i*bucket_size+j]) { - KEY_res[i*3+2] = KEY_res[i*3+2] ^ KEY[location[i*bucket_size+j]*3]; - MAC_res[i*3+2] = MAC_res[i*3+2] ^ MAC[location[i*bucket_size+j]*3]; - } - } - } + void combine(block S, int I, block * MAC, block * KEY, int length, int bucket_size, block * MAC_res, block * KEY_res) { + int *location = new int[length*bucket_size]; + for(int i = 0; i < length*bucket_size; ++i) location[i] = i; + PRG prg(&S, I); + int * ind = new int[length*bucket_size]; + prg.random_data(ind, length*bucket_size*4); + for(int i = length*bucket_size-1; i>=0; --i) { + int index = ind[i]%(i+1); + index = index>0? index:(-1*index); + int tmp = location[i]; + location[i] = location[index]; + location[index] = tmp; + } + delete[] ind; - delete[] data; - delete[] location; - delete[] data2; - } + bool *data = new bool[length*bucket_size]; + bool *data2 = new bool[length*bucket_size]; + for(int i = 0; i < length; ++i) { + for(int j = 1; j < bucket_size; ++j) { + data[i*bucket_size+j] = getLSB(MAC[location[i*bucket_size]*3+1] ^ MAC[location[i*bucket_size+j]*3+1]); + } + } + if(party == ALICE) { + io.send_bool(data, length*bucket_size); + io.recv_bool(data2, length*bucket_size); + } else { + io.recv_bool(data2, length*bucket_size); + io.send_bool(data, length*bucket_size); + } + io.flush(); + for(int i = 0; i < length; ++i) { + for(int j = 1; j < bucket_size; ++j) { + data[i*bucket_size+j] = (data[i*bucket_size+j] != data2[i*bucket_size+j]); + } + } + for(int i = 0; i < length; ++i) { + for(int j = 0; j < 3; ++j) { + MAC_res[i*3+j] = MAC[location[i*bucket_size]*3+j]; + KEY_res[i*3+j] = KEY[location[i*bucket_size]*3+j]; + } + for(int j = 1; j < bucket_size; ++j) { + MAC_res[3*i] = MAC_res[3*i] ^ MAC[location[i*bucket_size+j]*3]; + KEY_res[3*i] = KEY_res[3*i] ^ KEY[location[i*bucket_size+j]*3]; + + MAC_res[i*3+2] = MAC_res[i*3+2] ^ MAC[location[i*bucket_size+j]*3+2]; + KEY_res[i*3+2] = KEY_res[i*3+2] ^ KEY[location[i*bucket_size+j]*3+2]; + + if(data[i*bucket_size+j]) { + KEY_res[i*3+2] = KEY_res[i*3+2] ^ KEY[location[i*bucket_size+j]*3]; + MAC_res[i*3+2] = MAC_res[i*3+2] ^ MAC[location[i*bucket_size+j]*3]; + } + } + } + + delete[] data; + delete[] location; + delete[] data2; + } //for debug - void check_correctness(block * MAC, block * KEY, int length) { - if (party == ALICE) { - for(int i = 0; i < length*3; ++i) { - bool tmp = getLSB(MAC[i]); - io.send_data(&tmp, 1); - } - io.send_block(&Delta, 1); - io.send_block(KEY, length*3); - block DD; - io.recv_block(&DD, 1); + void check_correctness(block * MAC, block * KEY, int length) { + if (party == ALICE) { + for(int i = 0; i < length*3; ++i) { + bool tmp = getLSB(MAC[i]); + io.send_data(&tmp, 1); + } + io.send_block(&Delta, 1); + io.send_block(KEY, length*3); + block DD; + io.recv_block(&DD, 1); - for(int i = 0; i < length*3; ++i) { - block tmp; - io.recv_block(&tmp, 1); - if(getLSB(MAC[i])) tmp = tmp ^ DD; - if (!cmpBlock(&tmp, &MAC[i], 1)) - cout < void send_partial_block(IOChannel& io, const block * data, int length) { - for(int i = 0; i < length; ++i) { - io.send_data(&(data[i]), B); - } + for(int i = 0; i < length; ++i) { + io.send_data(&(data[i]), B); + } } template void recv_partial_block(IOChannel& io, block * data, int length) { - for(int i = 0; i < length; ++i) { - io.recv_data(&(data[i]), B); - } + for(int i = 0; i < length; ++i) { + io.recv_data(&(data[i]), B); + } } block coin_tossing(PRG prg, IOChannel& io, int party) { - block S, S2; - char dgst[Hash::DIGEST_SIZE]; - prg.random_block(&S, 1); - if(party == ALICE) { - Hash::hash_once(dgst, &S, sizeof(block)); - io.send_data(dgst, Hash::DIGEST_SIZE); - io.recv_block(&S2, 1); - io.send_block(&S, 1); - } else { - char dgst2[Hash::DIGEST_SIZE]; - io.recv_data(dgst2, Hash::DIGEST_SIZE); - io.send_block(&S, 1); - io.recv_block(&S2, 1); - Hash::hash_once(dgst, &S2, sizeof(block)); - if (memcmp(dgst, dgst2, Hash::DIGEST_SIZE)!= 0) - error("cheat CT!"); - } - io.flush(); - return S ^ S2; + block S, S2; + char dgst[Hash::DIGEST_SIZE]; + prg.random_block(&S, 1); + if(party == ALICE) { + Hash::hash_once(dgst, &S, sizeof(block)); + io.send_data(dgst, Hash::DIGEST_SIZE); + io.recv_block(&S2, 1); + io.send_block(&S, 1); + } else { + char dgst2[Hash::DIGEST_SIZE]; + io.recv_data(dgst2, Hash::DIGEST_SIZE); + io.send_block(&S, 1); + io.recv_block(&S2, 1); + Hash::hash_once(dgst, &S2, sizeof(block)); + if (memcmp(dgst, dgst2, Hash::DIGEST_SIZE)!= 0) + error("cheat CT!"); + } + io.flush(); + return S ^ S2; } } diff --git a/src/emp-ag2pc/leaky_deltaot.h b/src/emp-ag2pc/leaky_deltaot.h index 22be3a6..0034a85 100644 --- a/src/emp-ag2pc/leaky_deltaot.h +++ b/src/emp-ag2pc/leaky_deltaot.h @@ -3,45 +3,45 @@ #include namespace emp { #ifdef __GNUC__ - #ifndef __clang__ - #pragma GCC push_options - #pragma GCC optimize ("unroll-loops") - #endif + #ifndef __clang__ + #pragma GCC push_options + #pragma GCC optimize ("unroll-loops") + #endif #endif class LeakyDeltaOT: public IKNP { public: - LeakyDeltaOT(IOChannel io): IKNP(io, false) {} - - void send_dot(block * data, int length) { - this->send_cot(data, length); - this->io.flush(); - block one = makeBlock(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFE); - for (int i = 0; i < length; ++i) { - data[i] = data[i] & one; - } - } - void recv_dot(block* data, int length) { - bool * b = new bool[length]; - this->prg.random_bool(b, length); - this->recv_cot(data, b, length); - this->io.flush(); + LeakyDeltaOT(IOChannel io): IKNP(io, false) {} - block ch[2]; - ch[0] = zero_block; - ch[1] = makeBlock(0, 1); - block one = makeBlock(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFE); - for (int i = 0; i < length; ++i) { - data[i] = (data[i] & one) ^ ch[b[i]]; - } - delete[] b; - } + void send_dot(block * data, int length) { + this->send_cot(data, length); + this->io.flush(); + block one = makeBlock(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFE); + for (int i = 0; i < length; ++i) { + data[i] = data[i] & one; + } + } + void recv_dot(block* data, int length) { + bool * b = new bool[length]; + this->prg.random_bool(b, length); + this->recv_cot(data, b, length); + this->io.flush(); + + block ch[2]; + ch[0] = zero_block; + ch[1] = makeBlock(0, 1); + block one = makeBlock(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFE); + for (int i = 0; i < length; ++i) { + data[i] = (data[i] & one) ^ ch[b[i]]; + } + delete[] b; + } }; #ifdef __GNUC_ - #ifndef __clang___ - #pragma GCC pop_options - #endif + #ifndef __clang___ + #pragma GCC pop_options + #endif #endif } #endif// LEAKY_DELTA_OT_H diff --git a/src/emp-ot/co.h b/src/emp-ot/co.h index 8d744fa..3ef399c 100644 --- a/src/emp-ot/co.h +++ b/src/emp-ot/co.h @@ -11,88 +11,88 @@ namespace emp { * https://eprint.iacr.org/2015/267.pdf */ class OTCO: public OT { public: - IOChannel io; - Group *G = nullptr; - bool delete_G = true; - OTCO(IOChannel io, Group * _G = nullptr): io(io) { - if (_G == nullptr) - G = new Group(); - else { - G = _G; - delete_G = false; - } - } - ~OTCO() { - if (delete_G) - delete G; - } + IOChannel io; + Group *G = nullptr; + bool delete_G = true; + OTCO(IOChannel io, Group * _G = nullptr): io(io) { + if (_G == nullptr) + G = new Group(); + else { + G = _G; + delete_G = false; + } + } + ~OTCO() { + if (delete_G) + delete G; + } - void send(const block* data0, const block* data1, int64_t length) override { - BigInt a; - Point A, AaInv; - block res[2]; - Point * B = new Point[length]; - Point * BA = new Point[length]; + void send(const block* data0, const block* data1, int64_t length) override { + BigInt a; + Point A, AaInv; + block res[2]; + Point * B = new Point[length]; + Point * BA = new Point[length]; - G->get_rand_bn(a); - A = G->mul_gen(a); - io.send_pt(&A); - AaInv = A.mul(a); - AaInv = AaInv.inv(); + G->get_rand_bn(a); + A = G->mul_gen(a); + io.send_pt(&A); + AaInv = A.mul(a); + AaInv = AaInv.inv(); - for(int64_t i = 0; i < length; ++i) { - io.recv_pt(G, &B[i]); - B[i] = B[i].mul(a); - BA[i] = B[i].add(AaInv); - } - io.flush(); + for(int64_t i = 0; i < length; ++i) { + io.recv_pt(G, &B[i]); + B[i] = B[i].mul(a); + BA[i] = B[i].add(AaInv); + } + io.flush(); - for(int64_t i = 0; i < length; ++i) { - res[0] = Hash::KDF(B[i], i) ^ data0[i]; - res[1] = Hash::KDF(BA[i], i) ^ data1[i]; - io.send_data(res, 2*sizeof(block)); - } + for(int64_t i = 0; i < length; ++i) { + res[0] = Hash::KDF(B[i], i) ^ data0[i]; + res[1] = Hash::KDF(BA[i], i) ^ data1[i]; + io.send_data(res, 2*sizeof(block)); + } - delete[] BA; - delete[] B; - } + delete[] BA; + delete[] B; + } - void recv(block* data, const bool* b, int64_t length) override { - BigInt * bb = new BigInt[length]; - Point * B = new Point[length], - * As = new Point[length], - A; + void recv(block* data, const bool* b, int64_t length) override { + BigInt * bb = new BigInt[length]; + Point * B = new Point[length], + * As = new Point[length], + A; - for(int64_t i = 0; i < length; ++i) - G->get_rand_bn(bb[i]); + for(int64_t i = 0; i < length; ++i) + G->get_rand_bn(bb[i]); - io.recv_pt(G, &A); + io.recv_pt(G, &A); - for(int64_t i = 0; i < length; ++i) { - B[i] = G->mul_gen(bb[i]); - if (b[i]) - B[i] = B[i].add(A); - io.send_pt(&B[i]); - } - io.flush(); + for(int64_t i = 0; i < length; ++i) { + B[i] = G->mul_gen(bb[i]); + if (b[i]) + B[i] = B[i].add(A); + io.send_pt(&B[i]); + } + io.flush(); - for(int64_t i = 0; i < length; ++i) - As[i] = A.mul(bb[i]); + for(int64_t i = 0; i < length; ++i) + As[i] = A.mul(bb[i]); - block res[2]; - for(int64_t i = 0; i < length; ++i) { - io.recv_data(res, 2*sizeof(block)); - data[i] = Hash::KDF(As[i], i); - if(b[i]) - data[i] = data[i] ^ res[1]; - else - data[i] = data[i] ^ res[0]; - } - - delete[] bb; - delete[] B; - delete[] As; - } + block res[2]; + for(int64_t i = 0; i < length; ++i) { + io.recv_data(res, 2*sizeof(block)); + data[i] = Hash::KDF(As[i], i); + if(b[i]) + data[i] = data[i] ^ res[1]; + else + data[i] = data[i] ^ res[0]; + } + + delete[] bb; + delete[] B; + delete[] As; + } }; }//namespace diff --git a/src/emp-ot/iknp.h b/src/emp-ot/iknp.h index e4646fd..0989d84 100644 --- a/src/emp-ot/iknp.h +++ b/src/emp-ot/iknp.h @@ -18,280 +18,280 @@ const static int64_t ot_bsize = 8; */ class IKNP : public OT { public: - IOChannel io; + IOChannel io; - MITCCRH mitccrh; - block Delta; - PRG cot_prg; + MITCCRH mitccrh; + block Delta; + PRG cot_prg; - OTCO * base_ot = nullptr; - bool setup = false, *extended_r = nullptr; + OTCO * base_ot = nullptr; + bool setup = false, *extended_r = nullptr; - const static int64_t block_size = 1024*2; - block local_out[block_size]; - bool s[128], local_r[256]; - PRG prg, G0[128], G1[128]; - bool malicious = false; - block k0[128], k1[128]; + const static int64_t block_size = 1024*2; + block local_out[block_size]; + bool s[128], local_r[256]; + PRG prg, G0[128], G1[128]; + bool malicious = false; + block k0[128], k1[128]; - IKNP(IOChannel io, bool malicious = false): io(io), malicious(malicious) {} + IKNP(IOChannel io, bool malicious = false): io(io), malicious(malicious) {} - ~IKNP() { - delete_array_null(extended_r); - } + ~IKNP() { + delete_array_null(extended_r); + } - void setup_send(const bool* in_s = nullptr, block * in_k0 = nullptr) { - setup = true; - if(in_s == nullptr) - prg.random_bool(s, 128); - else - memcpy(s, in_s, 128); - - if(in_k0 != nullptr) { - memcpy(k0, in_k0, 128*sizeof(block)); - } else { - this->base_ot = new OTCO(io); - base_ot->recv(k0, s, 128); - delete base_ot; - } - for(int64_t i = 0; i < 128; ++i) - G0[i].reseed(&k0[i]); + void setup_send(const bool* in_s = nullptr, block * in_k0 = nullptr) { + setup = true; + if(in_s == nullptr) + prg.random_bool(s, 128); + else + memcpy(s, in_s, 128); - Delta = bool_to_block(s); - } + if(in_k0 != nullptr) { + memcpy(k0, in_k0, 128*sizeof(block)); + } else { + this->base_ot = new OTCO(io); + base_ot->recv(k0, s, 128); + delete base_ot; + } + for(int64_t i = 0; i < 128; ++i) + G0[i].reseed(&k0[i]); - void setup_recv(block * in_k0 = nullptr, block * in_k1 =nullptr) { - setup = true; - if(in_k0 !=nullptr) { - memcpy(k0, in_k0, 128*sizeof(block)); - memcpy(k1, in_k1, 128*sizeof(block)); - } else { - this->base_ot = new OTCO(io); - prg.random_block(k0, 128); - prg.random_block(k1, 128); - base_ot->send(k0, k1, 128); - delete base_ot; - } - for(int64_t i = 0; i < 128; ++i) { - G0[i].reseed(&k0[i]); - G1[i].reseed(&k1[i]); - } - } + Delta = bool_to_block(s); + } - void send_pre(block * out, int64_t length) { - if(not setup) - setup_send(); - int64_t j = 0; - for (; j < length/block_size; ++j) - send_pre_block(out + j*block_size, block_size); - int64_t remain = length % block_size; - if (remain > 0) { - send_pre_block(local_out, remain); - memcpy(out+j*block_size, local_out, sizeof(block)*remain); - } - if(malicious) - send_pre_block(local_out, 256); - } + void setup_recv(block * in_k0 = nullptr, block * in_k1 =nullptr) { + setup = true; + if(in_k0 !=nullptr) { + memcpy(k0, in_k0, 128*sizeof(block)); + memcpy(k1, in_k1, 128*sizeof(block)); + } else { + this->base_ot = new OTCO(io); + prg.random_block(k0, 128); + prg.random_block(k1, 128); + base_ot->send(k0, k1, 128); + delete base_ot; + } + for(int64_t i = 0; i < 128; ++i) { + G0[i].reseed(&k0[i]); + G1[i].reseed(&k1[i]); + } + } - void send_pre_block(block * out, int64_t len) { - block t[block_size]; - block tmp[block_size]; - int64_t local_block_size = (len+127)/128*128; - io.recv_block(tmp, local_block_size); - for(int64_t i = 0; i < 128; ++i) { - G0[i].random_data(t+(i*block_size/128), local_block_size/8); - if (s[i]) - xorBlocks_arr(t+(i*block_size/128), t+(i*block_size/128), tmp+(i*local_block_size/128), local_block_size/128); - } - sse_trans((uint8_t *)(out), (uint8_t*)t, 128, block_size); - } + void send_pre(block * out, int64_t length) { + if(not setup) + setup_send(); + int64_t j = 0; + for (; j < length/block_size; ++j) + send_pre_block(out + j*block_size, block_size); + int64_t remain = length % block_size; + if (remain > 0) { + send_pre_block(local_out, remain); + memcpy(out+j*block_size, local_out, sizeof(block)*remain); + } + if(malicious) + send_pre_block(local_out, 256); + } - void recv_pre(block * out, const bool* r, int64_t length) { - if(not setup) - setup_recv(); + void send_pre_block(block * out, int64_t len) { + block t[block_size]; + block tmp[block_size]; + int64_t local_block_size = (len+127)/128*128; + io.recv_block(tmp, local_block_size); + for(int64_t i = 0; i < 128; ++i) { + G0[i].random_data(t+(i*block_size/128), local_block_size/8); + if (s[i]) + xorBlocks_arr(t+(i*block_size/128), t+(i*block_size/128), tmp+(i*local_block_size/128), local_block_size/128); + } + sse_trans((uint8_t *)(out), (uint8_t*)t, 128, block_size); + } - block *block_r = new block[(length+127)/128]; - for(int64_t i = 0; i < length/128; ++i) - block_r[i] = bool_to_block(r+i*128); - if (length%128 != 0) { - bool tmp_bool_array[128]; - memset(tmp_bool_array, 0, 128); - int64_t start_point = (length / 128)*128; - memcpy(tmp_bool_array, r+start_point, length % 128); - block_r[length/128] = bool_to_block(tmp_bool_array); - } - - int64_t j = 0; - for (; j < length/block_size; ++j) - recv_pre_block(out+j*block_size, block_r + (j*block_size/128), block_size); - int64_t remain = length % block_size; - if (remain > 0) { - recv_pre_block(local_out, block_r + (j*block_size/128), remain); - memcpy(out+j*block_size, local_out, sizeof(block)*remain); - } - if(malicious) { - block local_r_block[2]; - prg.random_bool(local_r, 256); - local_r_block[0] = bool_to_block(local_r); - local_r_block[1] = bool_to_block(local_r + 128); - recv_pre_block(local_out, local_r_block, 256); - } - delete[] block_r; - } + void recv_pre(block * out, const bool* r, int64_t length) { + if(not setup) + setup_recv(); - void recv_pre_block(block * out, block * r, int64_t len) { - block t[block_size]; - block tmp[block_size]; - int64_t local_block_size = (len+127)/128 * 128; - for(int64_t i = 0; i < 128; ++i) { - G0[i].random_data(t+(i*block_size/128), local_block_size/8); - G1[i].random_data(tmp, local_block_size/8); - xorBlocks_arr(tmp, t+(i*block_size/128), tmp, local_block_size/128); - xorBlocks_arr(tmp, r, tmp, local_block_size/128); - io.send_data(tmp, local_block_size/8); - } + block *block_r = new block[(length+127)/128]; + for(int64_t i = 0; i < length/128; ++i) + block_r[i] = bool_to_block(r+i*128); + if (length%128 != 0) { + bool tmp_bool_array[128]; + memset(tmp_bool_array, 0, 128); + int64_t start_point = (length / 128)*128; + memcpy(tmp_bool_array, r+start_point, length % 128); + block_r[length/128] = bool_to_block(tmp_bool_array); + } - sse_trans((uint8_t *)(out), (uint8_t*)t, 128, block_size); - } + int64_t j = 0; + for (; j < length/block_size; ++j) + recv_pre_block(out+j*block_size, block_r + (j*block_size/128), block_size); + int64_t remain = length % block_size; + if (remain > 0) { + recv_pre_block(local_out, block_r + (j*block_size/128), remain); + memcpy(out+j*block_size, local_out, sizeof(block)*remain); + } + if(malicious) { + block local_r_block[2]; + prg.random_bool(local_r, 256); + local_r_block[0] = bool_to_block(local_r); + local_r_block[1] = bool_to_block(local_r + 128); + recv_pre_block(local_out, local_r_block, 256); + } + delete[] block_r; + } - void send(const block* data0, const block* data1, int64_t length) override { - block * data = new block[length]; - send_cot(data, length); - block s; - cot_prg.random_block(&s, 1); - io.send_block(&s,1); - mitccrh.setS(s); - io.flush(); - block pad[2*ot_bsize]; - for(int64_t i = 0; i < length; i+=ot_bsize) { - for(int64_t j = i; j < min(i+ot_bsize, length); ++j) { - pad[2*(j-i)] = data[j]; - pad[2*(j-i)+1] = data[j] ^ Delta; - } - mitccrh.hash(pad); - for(int64_t j = i; j < min(i+ot_bsize, length); ++j) { - pad[2*(j-i)] = pad[2*(j-i)] ^ data0[j]; - pad[2*(j-i)+1] = pad[2*(j-i)+1] ^ data1[j]; - } - io.send_data(pad, 2*sizeof(block)*min(ot_bsize,length-i)); - } - delete[] data; - } + void recv_pre_block(block * out, block * r, int64_t len) { + block t[block_size]; + block tmp[block_size]; + int64_t local_block_size = (len+127)/128 * 128; + for(int64_t i = 0; i < 128; ++i) { + G0[i].random_data(t+(i*block_size/128), local_block_size/8); + G1[i].random_data(tmp, local_block_size/8); + xorBlocks_arr(tmp, t+(i*block_size/128), tmp, local_block_size/128); + xorBlocks_arr(tmp, r, tmp, local_block_size/128); + io.send_data(tmp, local_block_size/8); + } - void recv(block* data, const bool* r, int64_t length) override { - recv_cot(data, r, length); - block s; - io.recv_block(&s,1); - mitccrh.setS(s); - io.flush(); + sse_trans((uint8_t *)(out), (uint8_t*)t, 128, block_size); + } - block res[2*ot_bsize]; - block pad[ot_bsize]; - for(int64_t i = 0; i < length; i+=ot_bsize) { - memcpy(pad, data+i, min(ot_bsize,length-i)*sizeof(block)); - mitccrh.hash(pad); - io.recv_data(res, 2*sizeof(block)*min(ot_bsize,length-i)); - for(int64_t j = 0; j < ot_bsize and j < length-i; ++j) { - data[i+j] = res[2*j+r[i+j]] ^ pad[j]; - } - } - } + void send(const block* data0, const block* data1, int64_t length) override { + block * data = new block[length]; + send_cot(data, length); + block s; + cot_prg.random_block(&s, 1); + io.send_block(&s,1); + mitccrh.setS(s); + io.flush(); + block pad[2*ot_bsize]; + for(int64_t i = 0; i < length; i+=ot_bsize) { + for(int64_t j = i; j < min(i+ot_bsize, length); ++j) { + pad[2*(j-i)] = data[j]; + pad[2*(j-i)+1] = data[j] ^ Delta; + } + mitccrh.hash(pad); + for(int64_t j = i; j < min(i+ot_bsize, length); ++j) { + pad[2*(j-i)] = pad[2*(j-i)] ^ data0[j]; + pad[2*(j-i)+1] = pad[2*(j-i)+1] ^ data1[j]; + } + io.send_data(pad, 2*sizeof(block)*min(ot_bsize,length-i)); + } + delete[] data; + } - void send_cot(block * data, int64_t length) { - send_pre(data, length); + void recv(block* data, const bool* r, int64_t length) override { + recv_cot(data, r, length); + block s; + io.recv_block(&s,1); + mitccrh.setS(s); + io.flush(); - if(malicious) - if(!send_check(data, length)) - error("OT Extension check failed"); - } + block res[2*ot_bsize]; + block pad[ot_bsize]; + for(int64_t i = 0; i < length; i+=ot_bsize) { + memcpy(pad, data+i, min(ot_bsize,length-i)*sizeof(block)); + mitccrh.hash(pad); + io.recv_data(res, 2*sizeof(block)*min(ot_bsize,length-i)); + for(int64_t j = 0; j < ot_bsize and j < length-i; ++j) { + data[i+j] = res[2*j+r[i+j]] ^ pad[j]; + } + } + } - void recv_cot(block* data, const bool * b, int64_t length) { - recv_pre(data, b, length); - if(malicious) - recv_check(data, b, length); - } + void send_cot(block * data, int64_t length) { + send_pre(data, length); + + if(malicious) + if(!send_check(data, length)) + error("OT Extension check failed"); + } + + void recv_cot(block* data, const bool * b, int64_t length) { + recv_pre(data, b, length); + if(malicious) + recv_check(data, b, length); + } /* * * [REF] Implementation of "Actively Secure OT Extension with Optimal Overhead" * https://eprint.iacr.org/2015/546.pdf */ - bool send_check(block * out, int64_t length) { - block seed2, x, t[2], q[2], tmp[2]; - block chi[block_size]; - q[0] = q[1] = makeBlock(0, 0); - io.recv_block(&seed2, 1); - io.flush(); - PRG chiPRG(&seed2); + bool send_check(block * out, int64_t length) { + block seed2, x, t[2], q[2], tmp[2]; + block chi[block_size]; + q[0] = q[1] = makeBlock(0, 0); + io.recv_block(&seed2, 1); + io.flush(); + PRG chiPRG(&seed2); - for(int64_t i = 0; i < length/block_size; ++i) { - chiPRG.random_block(chi, block_size); - vector_inn_prdt_sum_no_red(tmp, chi, out+i*block_size); - q[0] = q[0] ^ tmp[0]; - q[1] = q[1] ^ tmp[1]; - } - int64_t remain = length % block_size; - if(remain != 0) { - chiPRG.random_block(chi, block_size); - vector_inn_prdt_sum_no_red(tmp, chi, out + length - remain, remain); - q[0] = q[0] ^ tmp[0]; - q[1] = q[1] ^ tmp[1]; - } - { - chiPRG.random_block(chi, 256); - vector_inn_prdt_sum_no_red<256>(tmp, chi, local_out); - q[0] = q[0] ^ tmp[0]; - q[1] = q[1] ^ tmp[1]; - } + for(int64_t i = 0; i < length/block_size; ++i) { + chiPRG.random_block(chi, block_size); + vector_inn_prdt_sum_no_red(tmp, chi, out+i*block_size); + q[0] = q[0] ^ tmp[0]; + q[1] = q[1] ^ tmp[1]; + } + int64_t remain = length % block_size; + if(remain != 0) { + chiPRG.random_block(chi, block_size); + vector_inn_prdt_sum_no_red(tmp, chi, out + length - remain, remain); + q[0] = q[0] ^ tmp[0]; + q[1] = q[1] ^ tmp[1]; + } + { + chiPRG.random_block(chi, 256); + vector_inn_prdt_sum_no_red<256>(tmp, chi, local_out); + q[0] = q[0] ^ tmp[0]; + q[1] = q[1] ^ tmp[1]; + } - io.recv_block(&x, 1); - io.recv_block(t, 2); - mul128(x, Delta, tmp, tmp+1); - q[0] = q[0] ^ tmp[0]; - q[1] = q[1] ^ tmp[1]; + io.recv_block(&x, 1); + io.recv_block(t, 2); + mul128(x, Delta, tmp, tmp+1); + q[0] = q[0] ^ tmp[0]; + q[1] = q[1] ^ tmp[1]; - return cmpBlock(q, t, 2); - } - void recv_check(block * out, const bool* r, int64_t length) { - block select[2] = {zero_block, all_one_block}; - block seed2, x = makeBlock(0,0), t[2], tmp[2]; - prg.random_block(&seed2,1); - io.send_block(&seed2, 1); - io.flush(); - block chi[block_size]; - t[0] = t[1] = makeBlock(0, 0); - PRG chiPRG(&seed2); + return cmpBlock(q, t, 2); + } + void recv_check(block * out, const bool* r, int64_t length) { + block select[2] = {zero_block, all_one_block}; + block seed2, x = makeBlock(0,0), t[2], tmp[2]; + prg.random_block(&seed2,1); + io.send_block(&seed2, 1); + io.flush(); + block chi[block_size]; + t[0] = t[1] = makeBlock(0, 0); + PRG chiPRG(&seed2); - for(int64_t i = 0; i < length/block_size; ++i) { - chiPRG.random_block(chi, block_size); - vector_inn_prdt_sum_no_red(tmp, chi, out+i*block_size); - t[0] = t[0] ^ tmp[0]; - t[1] = t[1] ^ tmp[1]; - for(int64_t j = 0; j < block_size; ++j) - x = x ^ (chi[j] & select[r[i*block_size+j]]); - } - int64_t remain = length % block_size; - if(remain != 0) { - chiPRG.random_block(chi, block_size); - vector_inn_prdt_sum_no_red(tmp, chi, out+length - remain, remain); - t[0] = t[0] ^ tmp[0]; - t[1] = t[1] ^ tmp[1]; - for(int64_t j = 0; j < remain; ++j) - x = x ^ (chi[j] & select[r[length - remain + j]]); - } - - { - chiPRG.random_block(chi, 256); - vector_inn_prdt_sum_no_red<256>(tmp, chi, local_out); - t[0] = t[0] ^ tmp[0]; - t[1] = t[1] ^ tmp[1]; - for(int64_t j = 0; j < 256; ++j) - x = x ^ (chi[j] & select[local_r[j]]); - } + for(int64_t i = 0; i < length/block_size; ++i) { + chiPRG.random_block(chi, block_size); + vector_inn_prdt_sum_no_red(tmp, chi, out+i*block_size); + t[0] = t[0] ^ tmp[0]; + t[1] = t[1] ^ tmp[1]; + for(int64_t j = 0; j < block_size; ++j) + x = x ^ (chi[j] & select[r[i*block_size+j]]); + } + int64_t remain = length % block_size; + if(remain != 0) { + chiPRG.random_block(chi, block_size); + vector_inn_prdt_sum_no_red(tmp, chi, out+length - remain, remain); + t[0] = t[0] ^ tmp[0]; + t[1] = t[1] ^ tmp[1]; + for(int64_t j = 0; j < remain; ++j) + x = x ^ (chi[j] & select[r[length - remain + j]]); + } - io.send_block(&x, 1); - io.send_block(t, 2); - } + { + chiPRG.random_block(chi, 256); + vector_inn_prdt_sum_no_red<256>(tmp, chi, local_out); + t[0] = t[0] ^ tmp[0]; + t[1] = t[1] ^ tmp[1]; + for(int64_t j = 0; j < 256; ++j) + x = x ^ (chi[j] & select[local_r[j]]); + } + + io.send_block(&x, 1); + io.send_block(t, 2); + } }; }//namespace diff --git a/src/emp-ot/ot.h b/src/emp-ot/ot.h index 5b1b315..3976c79 100644 --- a/src/emp-ot/ot.h +++ b/src/emp-ot/ot.h @@ -5,10 +5,10 @@ namespace emp { class OT { public: - virtual void send(const block* data0, const block* data1, int64_t length) = 0; - virtual void recv(block* data, const bool* b, int64_t length) = 0; + virtual void send(const block* data0, const block* data1, int64_t length) = 0; + virtual void recv(block* data, const bool* b, int64_t length) = 0; - virtual ~OT() {} + virtual ~OT() {} }; } diff --git a/src/emp-tool/circuits/bit.h b/src/emp-tool/circuits/bit.h index 1354040..48fd1a7 100644 --- a/src/emp-tool/circuits/bit.h +++ b/src/emp-tool/circuits/bit.h @@ -8,36 +8,36 @@ namespace emp { class Bit : public Swappable{ public: - block bit; + block bit; - Bit(bool _b = false, int party = PUBLIC); - Bit(const block& a) { - memcpy(&bit, &a, sizeof(block)); - } + Bit(bool _b = false, int party = PUBLIC); + Bit(const block& a) { + memcpy(&bit, &a, sizeof(block)); + } - template - O reveal(int party = PUBLIC) const; + template + O reveal(int party = PUBLIC) const; - Bit operator!=(const Bit& rhs) const; - Bit operator==(const Bit& rhs) const; - Bit operator &(const Bit& rhs) const; - Bit operator |(const Bit& rhs) const; - Bit operator !() const; + Bit operator!=(const Bit& rhs) const; + Bit operator==(const Bit& rhs) const; + Bit operator &(const Bit& rhs) const; + Bit operator |(const Bit& rhs) const; + Bit operator !() const; - //swappable - Bit select(const Bit & select, const Bit & new_v)const ; - Bit operator ^(const Bit& rhs) const; - Bit operator ^=(const Bit& rhs); + //swappable + Bit select(const Bit & select, const Bit & new_v)const ; + Bit operator ^(const Bit& rhs) const; + Bit operator ^=(const Bit& rhs); - //batcher - template - static size_t bool_size(Args&&... args) { - return 1; - } + //batcher + template + static size_t bool_size(Args&&... args) { + return 1; + } - static void bool_data(bool *b, bool data) { - b[0] = data; - } + static void bool_data(bool *b, bool data) { + b[0] = data; + } }; #include "emp-tool/circuits/bit.hpp" } diff --git a/src/emp-tool/circuits/bit.hpp b/src/emp-tool/circuits/bit.hpp index 003f25d..87c24b1 100644 --- a/src/emp-tool/circuits/bit.hpp +++ b/src/emp-tool/circuits/bit.hpp @@ -1,57 +1,57 @@ inline Bit::Bit(bool b, int party) { - if (party == PUBLIC) - bit = CircuitExecution::circ_exec->public_label(b); - else ProtocolExecution::prot_exec->feed(&bit, party, &b, 1); + if (party == PUBLIC) + bit = CircuitExecution::circ_exec->public_label(b); + else ProtocolExecution::prot_exec->feed(&bit, party, &b, 1); } inline Bit Bit::select(const Bit & select, const Bit & new_v) const{ - Bit tmp = *this; - tmp = tmp ^ new_v; - tmp = tmp & select; - return *this ^ tmp; + Bit tmp = *this; + tmp = tmp ^ new_v; + tmp = tmp & select; + return *this ^ tmp; } template inline O Bit::reveal(int party) const { - O res; - ProtocolExecution::prot_exec->reveal(&res, party, &bit, 1); - return res; + O res; + ProtocolExecution::prot_exec->reveal(&res, party, &bit, 1); + return res; } template<> inline string Bit::reveal(int party) const { - bool res; - ProtocolExecution::prot_exec->reveal(&res, party, &bit, 1); - return res ? "true" : "false"; + bool res; + ProtocolExecution::prot_exec->reveal(&res, party, &bit, 1); + return res ? "true" : "false"; } inline Bit Bit::operator==(const Bit& rhs) const { - return !(*this ^ rhs); + return !(*this ^ rhs); } inline Bit Bit::operator!=(const Bit& rhs) const { - return (*this) ^ rhs; + return (*this) ^ rhs; } inline Bit Bit::operator &(const Bit& rhs) const{ - Bit res; - res.bit = CircuitExecution::circ_exec->and_gate(bit, rhs.bit); - return res; + Bit res; + res.bit = CircuitExecution::circ_exec->and_gate(bit, rhs.bit); + return res; } inline Bit Bit::operator ^(const Bit& rhs) const{ - Bit res; - res.bit = CircuitExecution::circ_exec->xor_gate(bit, rhs.bit); - return res; + Bit res; + res.bit = CircuitExecution::circ_exec->xor_gate(bit, rhs.bit); + return res; } inline Bit Bit::operator ^=(const Bit& rhs) { - this->bit = CircuitExecution::circ_exec->xor_gate(bit, rhs.bit); - return (*this); + this->bit = CircuitExecution::circ_exec->xor_gate(bit, rhs.bit); + return (*this); } inline Bit Bit::operator |(const Bit& rhs) const{ - return (*this ^ rhs) ^ (*this & rhs); + return (*this ^ rhs) ^ (*this & rhs); } inline Bit Bit::operator!() const { - return CircuitExecution::circ_exec->not_gate(bit); + return CircuitExecution::circ_exec->not_gate(bit); } diff --git a/src/emp-tool/circuits/circuit_file.h b/src/emp-tool/circuits/circuit_file.h index 627a5ff..8852711 100644 --- a/src/emp-tool/circuits/circuit_file.h +++ b/src/emp-tool/circuits/circuit_file.h @@ -17,186 +17,186 @@ namespace emp { template void execute_circuit(block * wires, const T * gates, size_t num_gate) { - for(size_t i = 0; i < num_gate; ++i) { - if(gates[4*i+3] == AND_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else if (gates[4*i+3] == XOR_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else if (gates[4*i+3] == NOT_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); - } else { - block tmp = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - block tmp2 = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(tmp, tmp2); - } - } + for(size_t i = 0; i < num_gate; ++i) { + if(gates[4*i+3] == AND_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else if (gates[4*i+3] == XOR_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else if (gates[4*i+3] == NOT_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); + } else { + block tmp = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + block tmp2 = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(tmp, tmp2); + } + } } class BristolFormat { public: - int num_gate, num_wire, n1, n2, n3; - vector gates; - vector wires; - std::ofstream fout; + int num_gate, num_wire, n1, n2, n3; + vector gates; + vector wires; + std::ofstream fout; - BristolFormat(int num_gate, int num_wire, int n1, int n2, int n3, int * gate_arr) { - this->num_gate = num_gate; - this->num_wire = num_wire; - this->n1 = n1; - this->n2 = n2; - this->n3 = n3; - gates.resize(num_gate*4); - wires.resize(num_wire); - memcpy(gates.data(), gate_arr, num_gate*4*sizeof(int)); - } + BristolFormat(int num_gate, int num_wire, int n1, int n2, int n3, int * gate_arr) { + this->num_gate = num_gate; + this->num_wire = num_wire; + this->n1 = n1; + this->n2 = n2; + this->n3 = n3; + gates.resize(num_gate*4); + wires.resize(num_wire); + memcpy(gates.data(), gate_arr, num_gate*4*sizeof(int)); + } - BristolFormat(FILE * file) { - this->from_file(file); - } + BristolFormat(FILE * file) { + this->from_file(file); + } - BristolFormat(const char * file) { - this->from_file(file); - } + BristolFormat(const char * file) { + this->from_file(file); + } - void to_file(const char * filename, const char * prefix) { - fout.open(filename); - fout << "int "<from_file(f); - fclose(f); - } + void from_file(const char * file) { + FILE * f = fopen(file, "r"); + this->from_file(f); + fclose(f); + } - void compute(Bit * out, const Bit * in1, const Bit * in2) { - compute((block*)out, (block *)in1, (block*)in2); - } + void compute(Bit * out, const Bit * in1, const Bit * in2) { + compute((block*)out, (block *)in1, (block*)in2); + } - void compute(block * out, const block * in1, const block * in2) { - memcpy(wires.data(), in1, n1*sizeof(block)); - memcpy(wires.data()+n1, in2, n2*sizeof(block)); - for(int i = 0; i < num_gate; ++i) { - if(gates[4*i+3] == AND_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else if (gates[4*i+3] == XOR_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else - wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); - } - memcpy(out, wires.data()+(num_wire-n3), n3*sizeof(block)); - } + void compute(block * out, const block * in1, const block * in2) { + memcpy(wires.data(), in1, n1*sizeof(block)); + memcpy(wires.data()+n1, in2, n2*sizeof(block)); + for(int i = 0; i < num_gate; ++i) { + if(gates[4*i+3] == AND_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else if (gates[4*i+3] == XOR_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else + wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); + } + memcpy(out, wires.data()+(num_wire-n3), n3*sizeof(block)); + } }; class BristolFashion { public: - int num_gate = 0, num_wire = 0, - num_input = 0, num_output = 0; - vector gates; - vector wires; + int num_gate = 0, num_wire = 0, + num_input = 0, num_output = 0; + vector gates; + vector wires; - BristolFashion(FILE * file) { - this->from_file(file); - } + BristolFashion(FILE * file) { + this->from_file(file); + } - BristolFashion(const char * file) { - this->from_file(file); - } + BristolFashion(const char * file) { + this->from_file(file); + } - void from_file(FILE * f) { - int tmp; - (void)fscanf(f, "%d%d\n", &num_gate, &num_wire); - int niov = 0; - (void)fscanf(f, "%d", &niov); - for(int i = 0; i < niov; ++i) { - (void)fscanf(f, "%d", &tmp); - num_input += tmp; - } - (void)fscanf(f, "%d", &niov); - for(int i = 0; i < niov; ++i) { - (void)fscanf(f, "%d", &tmp); - num_output += tmp; - } + void from_file(FILE * f) { + int tmp; + (void)fscanf(f, "%d%d\n", &num_gate, &num_wire); + int niov = 0; + (void)fscanf(f, "%d", &niov); + for(int i = 0; i < niov; ++i) { + (void)fscanf(f, "%d", &tmp); + num_input += tmp; + } + (void)fscanf(f, "%d", &niov); + for(int i = 0; i < niov; ++i) { + (void)fscanf(f, "%d", &tmp); + num_output += tmp; + } - char str[10]; - gates.resize(num_gate*4); - wires.resize(num_wire); - for(int i = 0; i < num_gate; ++i) { - (void)fscanf(f, "%d", &tmp); - if (tmp == 2) { - (void)fscanf(f, "%d%d%d%d%s", &tmp, &gates[4*i], &gates[4*i+1], &gates[4*i+2], str); - if (str[0] == 'A') gates[4*i+3] = AND_GATE; - else if (str[0] == 'X') gates[4*i+3] = XOR_GATE; - } - else if (tmp == 1) { - (void)fscanf(f, "%d%d%d%s", &tmp, &gates[4*i], &gates[4*i+2], str); - gates[4*i+3] = NOT_GATE; - } - } - } + char str[10]; + gates.resize(num_gate*4); + wires.resize(num_wire); + for(int i = 0; i < num_gate; ++i) { + (void)fscanf(f, "%d", &tmp); + if (tmp == 2) { + (void)fscanf(f, "%d%d%d%d%s", &tmp, &gates[4*i], &gates[4*i+1], &gates[4*i+2], str); + if (str[0] == 'A') gates[4*i+3] = AND_GATE; + else if (str[0] == 'X') gates[4*i+3] = XOR_GATE; + } + else if (tmp == 1) { + (void)fscanf(f, "%d%d%d%s", &tmp, &gates[4*i], &gates[4*i+2], str); + gates[4*i+3] = NOT_GATE; + } + } + } - void from_file(const char * file) { - FILE * f = fopen(file, "r"); - this->from_file(f); - fclose(f); - } + void from_file(const char * file) { + FILE * f = fopen(file, "r"); + this->from_file(f); + fclose(f); + } - void compute(Bit * out, const Bit * in) { - compute((block*)out, (block *)in); - } - void compute(block * out, const block * in) { - memcpy(wires.data(), in, num_input*sizeof(block)); - for(int i = 0; i < num_gate; ++i) { - if(gates[4*i+3] == AND_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else if (gates[4*i+3] == XOR_GATE) { - wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); - } - else - wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); - } - memcpy(out, wires.data()+(num_wire-num_output), num_output*sizeof(block)); - } + void compute(Bit * out, const Bit * in) { + compute((block*)out, (block *)in); + } + void compute(block * out, const block * in) { + memcpy(wires.data(), in, num_input*sizeof(block)); + for(int i = 0; i < num_gate; ++i) { + if(gates[4*i+3] == AND_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->and_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else if (gates[4*i+3] == XOR_GATE) { + wires[gates[4*i+2]] = CircuitExecution::circ_exec->xor_gate(wires[gates[4*i]], wires[gates[4*i+1]]); + } + else + wires[gates[4*i+2]] = CircuitExecution::circ_exec->not_gate(wires[gates[4*i]]); + } + memcpy(out, wires.data()+(num_wire-num_output), num_output*sizeof(block)); + } }; } diff --git a/src/emp-tool/circuits/swappable.h b/src/emp-tool/circuits/swappable.h index d88e1a2..868f0e6 100644 --- a/src/emp-tool/circuits/swappable.h +++ b/src/emp-tool/circuits/swappable.h @@ -5,21 +5,21 @@ namespace emp { class Bit; template class Swappable { public: - T If(const Bit & sel, const T& rhs) const { - return static_cast(this)->select(sel, rhs); - } + T If(const Bit & sel, const T& rhs) const { + return static_cast(this)->select(sel, rhs); + } }; template inline T If(const Bit & select, const T & o1, const T & o2) { - T res = o2; - return res.If(select, o1); + T res = o2; + return res.If(select, o1); } template inline void swap(const Bit & swap, T & o1, T & o2) { - T o = If(swap, o1, o2); - o ^= o2; - o1 ^= o; - o2 ^= o; + T o = If(swap, o1, o2); + o ^= o2; + o1 ^= o; + o2 ^= o; } } #endif diff --git a/src/emp-tool/execution/circuit_execution.h b/src/emp-tool/execution/circuit_execution.h index 37f5596..4422c63 100644 --- a/src/emp-tool/execution/circuit_execution.h +++ b/src/emp-tool/execution/circuit_execution.h @@ -5,24 +5,24 @@ namespace emp { -/* Circuit Pipelining +/* Circuit Pipelining * [REF] Implementation of "Faster Secure Two-Party Computation Using Garbled Circuit" * https://www.usenix.org/legacy/event/sec11/tech/full_papers/Huang.pdf */ class CircuitExecution { public: #ifndef THREADING - static CircuitExecution * circ_exec; + static CircuitExecution * circ_exec; #else - static __thread CircuitExecution * circ_exec; + static __thread CircuitExecution * circ_exec; #endif - virtual block and_gate(const block& in1, const block& in2) = 0; - virtual block xor_gate(const block&in1, const block&in2) = 0; - virtual block not_gate(const block& in1) = 0; - virtual block public_label(bool b) = 0; - virtual uint64_t num_and() { - return -1; - } - virtual ~CircuitExecution (){ } + virtual block and_gate(const block& in1, const block& in2) = 0; + virtual block xor_gate(const block&in1, const block&in2) = 0; + virtual block not_gate(const block& in1) = 0; + virtual block public_label(bool b) = 0; + virtual uint64_t num_and() { + return -1; + } + virtual ~CircuitExecution (){ } }; enum RTCktOpt{on, off}; } diff --git a/src/emp-tool/execution/protocol_execution.h b/src/emp-tool/execution/protocol_execution.h index 8bfd439..ab78e5f 100644 --- a/src/emp-tool/execution/protocol_execution.h +++ b/src/emp-tool/execution/protocol_execution.h @@ -1,24 +1,24 @@ #ifndef EMP_PROTOCOL_EXECUTION_H #define EMP_PROTOCOL_EXECUTION_H -#include +#include #include "emp-tool/utils/block.h" #include "emp-tool/utils/constants.h" namespace emp { -class ProtocolExecution { +class ProtocolExecution { public: - int cur_party; + int cur_party; #ifndef THREADING - static ProtocolExecution * prot_exec; + static ProtocolExecution * prot_exec; #else - static __thread ProtocolExecution * prot_exec; + static __thread ProtocolExecution * prot_exec; #endif - ProtocolExecution(int party = PUBLIC): cur_party (party) {} - virtual ~ProtocolExecution() {} - virtual void feed(block * lbls, int party, const bool* b, int nel) = 0; - virtual void reveal(bool*out, int party, const block *lbls, int nel) = 0; - virtual void finalize() {} + ProtocolExecution(int party = PUBLIC): cur_party (party) {} + virtual ~ProtocolExecution() {} + virtual void feed(block * lbls, int party, const bool* b, int nel) = 0; + virtual void reveal(bool*out, int party, const block *lbls, int nel) = 0; + virtual void finalize() {} }; } #endif diff --git a/src/emp-tool/gc/halfgate_eva.h b/src/emp-tool/gc/halfgate_eva.h index 0f3a1c2..d116779 100644 --- a/src/emp-tool/gc/halfgate_eva.h +++ b/src/emp-tool/gc/halfgate_eva.h @@ -7,57 +7,57 @@ namespace emp { inline block halfgates_eval(block A, block B, const block *table, MITCCRH<8> *mitccrh) { - block HA, HB, W; - int sa, sb; + block HA, HB, W; + int sa, sb; - sa = getLSB(A); - sb = getLSB(B); + sa = getLSB(A); + sb = getLSB(B); - block H[2]; - H[0] = A; - H[1] = B; - mitccrh->hash_cir<2,1>(H); - HA = H[0]; - HB = H[1]; + block H[2]; + H[0] = A; + H[1] = B; + mitccrh->hash_cir<2,1>(H); + HA = H[0]; + HB = H[1]; - W = HA ^ HB; - W = W ^ (select_mask[sa] & table[0]); - W = W ^ (select_mask[sb] & table[1]); - W = W ^ (select_mask[sb] & A); - return W; + W = HA ^ HB; + W = W ^ (select_mask[sa] & table[0]); + W = W ^ (select_mask[sb] & table[1]); + W = W ^ (select_mask[sb] & A); + return W; } class HalfGateEva:public CircuitExecution { public: - IOChannel io; - block constant[2]; - MITCCRH<8> mitccrh; - HalfGateEva(IOChannel io): io(io) { - set_delta(); - block tmp; - io.recv_block(&tmp, 1); - mitccrh.setS(tmp); - } - void set_delta() { - io.recv_block(constant, 2); - } - block public_label(bool b) override { - return constant[b]; - } - block and_gate(const block& a, const block& b) override { - block table[2]; - io.recv_block(table, 2); - return halfgates_eval(a, b, table, &mitccrh); - } - block xor_gate(const block& a, const block& b) override { - return a ^ b; - } - block not_gate(const block&a) override { - return xor_gate(a, public_label(true)); - } - uint64_t num_and() override { - return mitccrh.gid/2; - } + IOChannel io; + block constant[2]; + MITCCRH<8> mitccrh; + HalfGateEva(IOChannel io): io(io) { + set_delta(); + block tmp; + io.recv_block(&tmp, 1); + mitccrh.setS(tmp); + } + void set_delta() { + io.recv_block(constant, 2); + } + block public_label(bool b) override { + return constant[b]; + } + block and_gate(const block& a, const block& b) override { + block table[2]; + io.recv_block(table, 2); + return halfgates_eval(a, b, table, &mitccrh); + } + block xor_gate(const block& a, const block& b) override { + return a ^ b; + } + block not_gate(const block&a) override { + return xor_gate(a, public_label(true)); + } + uint64_t num_and() override { + return mitccrh.gid/2; + } }; } #endif// HALFGATE_EVA_H diff --git a/src/emp-tool/gc/halfgate_gen.h b/src/emp-tool/gc/halfgate_gen.h index 5f63d35..2aff150 100644 --- a/src/emp-tool/gc/halfgate_gen.h +++ b/src/emp-tool/gc/halfgate_gen.h @@ -12,71 +12,71 @@ namespace emp { * https://eprint.iacr.org/2014/756.pdf */ inline block halfgates_garble(block LA0, block A1, block LB0, block B1, block delta, block *table, MITCCRH<8> *mitccrh) { - bool pa = getLSB(LA0); - bool pb = getLSB(LB0); - block HLA0, HA1, HLB0, HB1; - block tmp, W0; + bool pa = getLSB(LA0); + bool pb = getLSB(LB0); + block HLA0, HA1, HLB0, HB1; + block tmp, W0; - block H[4]; - H[0] = LA0; - H[1] = A1; - H[2] = LB0; - H[3] = B1; - mitccrh->hash_cir<2,2>(H); - HLA0 = H[0]; - HA1 = H[1]; - HLB0 = H[2]; - HB1 = H[3]; + block H[4]; + H[0] = LA0; + H[1] = A1; + H[2] = LB0; + H[3] = B1; + mitccrh->hash_cir<2,2>(H); + HLA0 = H[0]; + HA1 = H[1]; + HLB0 = H[2]; + HB1 = H[3]; - table[0] = HLA0 ^ HA1; - table[0] = table[0] ^ (select_mask[pb] & delta); - W0 = HLA0; - W0 = W0 ^ (select_mask[pa] & table[0]); - tmp = HLB0 ^ HB1; - table[1] = tmp ^ LA0; - W0 = W0 ^ HLB0; - W0 = W0 ^ (select_mask[pb] & tmp); + table[0] = HLA0 ^ HA1; + table[0] = table[0] ^ (select_mask[pb] & delta); + W0 = HLA0; + W0 = W0 ^ (select_mask[pa] & table[0]); + tmp = HLB0 ^ HB1; + table[1] = tmp ^ LA0; + W0 = W0 ^ HLB0; + W0 = W0 ^ (select_mask[pb] & tmp); - return W0; + return W0; } class HalfGateGen:public CircuitExecution { public: - block delta; - IOChannel io; - block constant[2]; - MITCCRH<8> mitccrh; - HalfGateGen(IOChannel io) :io(io) { - block tmp[2]; - PRG().random_block(tmp, 2); - set_delta(tmp[0]); - io.send_block(tmp+1, 1); - mitccrh.setS(tmp[1]); - } - void set_delta(const block & _delta) { - delta = set_bit(_delta, 0); - PRG().random_block(constant, 2); - io.send_block(constant, 2); - constant[1] = constant[1] ^ delta; - } - block public_label(bool b) override { - return constant[b]; - } - block and_gate(const block& a, const block& b) override { - block table[2]; - block res = halfgates_garble(a, a^delta, b, b^delta, delta, table, &mitccrh); - io.send_block(table, 2); - return res; - } - block xor_gate(const block&a, const block& b) override { - return a ^ b; - } - block not_gate(const block&a) override { - return xor_gate(a, public_label(true)); - } - uint64_t num_and() override { - return mitccrh.gid/2; - } + block delta; + IOChannel io; + block constant[2]; + MITCCRH<8> mitccrh; + HalfGateGen(IOChannel io) :io(io) { + block tmp[2]; + PRG().random_block(tmp, 2); + set_delta(tmp[0]); + io.send_block(tmp+1, 1); + mitccrh.setS(tmp[1]); + } + void set_delta(const block & _delta) { + delta = set_bit(_delta, 0); + PRG().random_block(constant, 2); + io.send_block(constant, 2); + constant[1] = constant[1] ^ delta; + } + block public_label(bool b) override { + return constant[b]; + } + block and_gate(const block& a, const block& b) override { + block table[2]; + block res = halfgates_garble(a, a^delta, b, b^delta, delta, table, &mitccrh); + io.send_block(table, 2); + return res; + } + block xor_gate(const block&a, const block& b) override { + return a ^ b; + } + block not_gate(const block&a) override { + return xor_gate(a, public_label(true)); + } + uint64_t num_and() override { + return mitccrh.gid/2; + } }; } #endif// HALFGATE_GEN_H diff --git a/src/emp-tool/io/io_channel.h b/src/emp-tool/io/io_channel.h index 86ae2c1..f819065 100644 --- a/src/emp-tool/io/io_channel.h +++ b/src/emp-tool/io/io_channel.h @@ -4,136 +4,136 @@ #include "emp-tool/utils/prg.h" #include "emp-tool/utils/group.h" #include -#include +#include namespace emp { class IOChannel { private: - std::shared_ptr raw_io; - std::shared_ptr counter = std::make_shared(0); + std::shared_ptr raw_io; + std::shared_ptr counter = std::make_shared(0); public: - IOChannel(std::shared_ptr raw_io): raw_io(raw_io) {} + IOChannel(std::shared_ptr raw_io): raw_io(raw_io) {} - void send_data(const void * data, size_t nbyte) { - *counter += nbyte; - raw_io->send(data, nbyte); - } + void send_data(const void * data, size_t nbyte) { + *counter += nbyte; + raw_io->send(data, nbyte); + } - void recv_data(void * data, size_t nbyte) { - raw_io->recv(data, nbyte); - } + void recv_data(void * data, size_t nbyte) { + raw_io->recv(data, nbyte); + } - void flush() { - raw_io->flush(); - } + void flush() { + raw_io->flush(); + } - void send_block(const block* data, size_t nblock) { - send_data(data, nblock*sizeof(block)); - } + void send_block(const block* data, size_t nblock) { + send_data(data, nblock*sizeof(block)); + } - void recv_block(block* data, size_t nblock) { - recv_data(data, nblock*sizeof(block)); - } + void recv_block(block* data, size_t nblock) { + recv_data(data, nblock*sizeof(block)); + } - void send_pt(Point *A, size_t num_pts = 1) { - for(size_t i = 0; i < num_pts; ++i) { - size_t len = A[i].size(); - A[i].group->resize_scratch(len); - unsigned char * tmp = A[i].group->scratch; - send_data(&len, 4); - A[i].to_bin(tmp, len); - send_data(tmp, len); - } - } + void send_pt(Point *A, size_t num_pts = 1) { + for(size_t i = 0; i < num_pts; ++i) { + size_t len = A[i].size(); + A[i].group->resize_scratch(len); + unsigned char * tmp = A[i].group->scratch; + send_data(&len, 4); + A[i].to_bin(tmp, len); + send_data(tmp, len); + } + } - void recv_pt(Group * g, Point *A, size_t num_pts = 1) { - size_t len = 0; - for(size_t i = 0; i < num_pts; ++i) { - recv_data(&len, 4); - assert(len <= 2048); - g->resize_scratch(len); - unsigned char * tmp = g->scratch; - recv_data(tmp, len); - A[i].from_bin(g, tmp, len); - } - } + void recv_pt(Group * g, Point *A, size_t num_pts = 1) { + size_t len = 0; + for(size_t i = 0; i < num_pts; ++i) { + recv_data(&len, 4); + assert(len <= 2048); + g->resize_scratch(len); + unsigned char * tmp = g->scratch; + recv_data(tmp, len); + A[i].from_bin(g, tmp, len); + } + } - void send_bool(bool * data, size_t length) { - void * ptr = (void *)data; - size_t space = length; - const void * aligned = std::align(alignof(uint64_t), sizeof(uint64_t), ptr, space); - if(aligned == nullptr) - send_data(data, length); - else{ - size_t diff = length - space; - send_data(data, diff); - send_bool_aligned((const bool*)aligned, length - diff); - } - } + void send_bool(bool * data, size_t length) { + void * ptr = (void *)data; + size_t space = length; + const void * aligned = std::align(alignof(uint64_t), sizeof(uint64_t), ptr, space); + if(aligned == nullptr) + send_data(data, length); + else{ + size_t diff = length - space; + send_data(data, diff); + send_bool_aligned((const bool*)aligned, length - diff); + } + } - void recv_bool(bool * data, size_t length) { - void * ptr = (void *)data; - size_t space = length; - void * aligned = std::align(alignof(uint64_t), sizeof(uint64_t), ptr, space); - if(aligned == nullptr) - recv_data(data, length); - else{ - size_t diff = length - space; - recv_data(data, diff); - recv_bool_aligned((bool*)aligned, length - diff); - } - } + void recv_bool(bool * data, size_t length) { + void * ptr = (void *)data; + size_t space = length; + void * aligned = std::align(alignof(uint64_t), sizeof(uint64_t), ptr, space); + if(aligned == nullptr) + recv_data(data, length); + else{ + size_t diff = length - space; + recv_data(data, diff); + recv_bool_aligned((bool*)aligned, length - diff); + } + } - void send_bool_aligned(const bool * data, size_t length) { - const bool * data64 = data; - size_t i = 0; + void send_bool_aligned(const bool * data, size_t length) { + const bool * data64 = data; + size_t i = 0; unsigned long long unpack; - for(; i < length/8; ++i) { - unsigned long long mask = 0x0101010101010101ULL; - unsigned long long tmp = 0; + for(; i < length/8; ++i) { + unsigned long long mask = 0x0101010101010101ULL; + unsigned long long tmp = 0; memcpy(&unpack, data64, sizeof(unpack)); data64 += sizeof(unpack); #if defined(__BMI2__) - tmp = _pext_u64(unpack, mask); + tmp = _pext_u64(unpack, mask); #else - // https://github.com/Forceflow/libmorton/issues/6 - for (unsigned long long bb = 1; mask != 0; bb += bb) { - if (unpack & mask & -mask) { tmp |= bb; } - mask &= (mask - 1); - } + // https://github.com/Forceflow/libmorton/issues/6 + for (unsigned long long bb = 1; mask != 0; bb += bb) { + if (unpack & mask & -mask) { tmp |= bb; } + mask &= (mask - 1); + } #endif - send_data(&tmp, 1); - } - if (8*i != length) - send_data(data + 8*i, length - 8*i); - } - void recv_bool_aligned(bool * data, size_t length) { - bool * data64 = data; - size_t i = 0; + send_data(&tmp, 1); + } + if (8*i != length) + send_data(data + 8*i, length - 8*i); + } + void recv_bool_aligned(bool * data, size_t length) { + bool * data64 = data; + size_t i = 0; unsigned long long unpack; - for(; i < length/8; ++i) { - unsigned long long mask = 0x0101010101010101ULL; - unsigned long long tmp = 0; - recv_data(&tmp, 1); + for(; i < length/8; ++i) { + unsigned long long mask = 0x0101010101010101ULL; + unsigned long long tmp = 0; + recv_data(&tmp, 1); #if defined(__BMI2__) - unpack = _pdep_u64(tmp, mask); + unpack = _pdep_u64(tmp, mask); #else - unpack = 0; - for (unsigned long long bb = 1; mask != 0; bb += bb) { - if (tmp & bb) {unpack |= mask & (-mask); } - mask &= (mask - 1); - } + unpack = 0; + for (unsigned long long bb = 1; mask != 0; bb += bb) { + if (tmp & bb) {unpack |= mask & (-mask); } + mask &= (mask - 1); + } #endif memcpy(data64, &unpack, sizeof(unpack)); data64 += sizeof(unpack); - - } - if (8*i != length) - recv_data(data + 8*i, length - 8*i); - } + + } + if (8*i != length) + recv_data(data + 8*i, length - 8*i); + } }; } #endif diff --git a/src/emp-tool/io/net_io.h b/src/emp-tool/io/net_io.h index 8ffa1e9..5d56bca 100644 --- a/src/emp-tool/io/net_io.h +++ b/src/emp-tool/io/net_io.h @@ -20,129 +20,129 @@ namespace emp { class NetIO: public IRawIO { public: - bool is_server; - int mysocket = -1; - int consocket = -1; - FILE * stream = nullptr; - char * buffer = nullptr; - bool has_sent = false; - string addr; - int port; - NetIO(const char * address, int port) { - if (port <0 || port > 65535) { - throw std::runtime_error("Invalid port number!"); - } + bool is_server; + int mysocket = -1; + int consocket = -1; + FILE * stream = nullptr; + char * buffer = nullptr; + bool has_sent = false; + string addr; + int port; + NetIO(const char * address, int port) { + if (port <0 || port > 65535) { + throw std::runtime_error("Invalid port number!"); + } - this->port = port; - is_server = (address == nullptr); - if (address == nullptr) { - struct sockaddr_in dest; - struct sockaddr_in serv; - socklen_t socksize = sizeof(struct sockaddr_in); - memset(&serv, 0, sizeof(serv)); - serv.sin_family = AF_INET; - serv.sin_addr.s_addr = htonl(INADDR_ANY); /* set our address to any interface */ - serv.sin_port = htons(port); /* set the server port number */ - mysocket = socket(AF_INET, SOCK_STREAM, 0); - int reuse = 1; - setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse)); - if(bind(mysocket, (struct sockaddr *)&serv, sizeof(struct sockaddr)) < 0) { - perror("error: bind"); - exit(1); - } - if(listen(mysocket, 1) < 0) { - perror("error: listen"); - exit(1); - } - consocket = accept(mysocket, (struct sockaddr *)&dest, &socksize); - close(mysocket); - } - else { - addr = string(address); + this->port = port; + is_server = (address == nullptr); + if (address == nullptr) { + struct sockaddr_in dest; + struct sockaddr_in serv; + socklen_t socksize = sizeof(struct sockaddr_in); + memset(&serv, 0, sizeof(serv)); + serv.sin_family = AF_INET; + serv.sin_addr.s_addr = htonl(INADDR_ANY); /* set our address to any interface */ + serv.sin_port = htons(port); /* set the server port number */ + mysocket = socket(AF_INET, SOCK_STREAM, 0); + int reuse = 1; + setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(reuse)); + if(bind(mysocket, (struct sockaddr *)&serv, sizeof(struct sockaddr)) < 0) { + perror("error: bind"); + exit(1); + } + if(listen(mysocket, 1) < 0) { + perror("error: listen"); + exit(1); + } + consocket = accept(mysocket, (struct sockaddr *)&dest, &socksize); + close(mysocket); + } + else { + addr = string(address); - struct sockaddr_in dest; - memset(&dest, 0, sizeof(dest)); - dest.sin_family = AF_INET; - dest.sin_addr.s_addr = inet_addr(address); - dest.sin_port = htons(port); + struct sockaddr_in dest; + memset(&dest, 0, sizeof(dest)); + dest.sin_family = AF_INET; + dest.sin_addr.s_addr = inet_addr(address); + dest.sin_port = htons(port); - while(1) { - consocket = socket(AF_INET, SOCK_STREAM, 0); + while(1) { + consocket = socket(AF_INET, SOCK_STREAM, 0); - if (connect(consocket, (struct sockaddr *)&dest, sizeof(struct sockaddr)) == 0) { - break; - } + if (connect(consocket, (struct sockaddr *)&dest, sizeof(struct sockaddr)) == 0) { + break; + } - close(consocket); - usleep(1000); - } - } - set_nodelay(); - stream = fdopen(consocket, "wb+"); - buffer = new char[NETWORK_BUFFER_SIZE]; - memset(buffer, 0, NETWORK_BUFFER_SIZE); - setvbuf(stream, buffer, _IOFBF, NETWORK_BUFFER_SIZE); + close(consocket); + usleep(1000); + } + } + set_nodelay(); + stream = fdopen(consocket, "wb+"); + buffer = new char[NETWORK_BUFFER_SIZE]; + memset(buffer, 0, NETWORK_BUFFER_SIZE); + setvbuf(stream, buffer, _IOFBF, NETWORK_BUFFER_SIZE); - std::cout << "connected\n"; - } + std::cout << "connected\n"; + } - void sync() { - int tmp = 0; - if(is_server) { - send(&tmp, 1); - recv(&tmp, 1); - } else { - recv(&tmp, 1); - send(&tmp, 1); - flush(); - } - } + void sync() { + int tmp = 0; + if(is_server) { + send(&tmp, 1); + recv(&tmp, 1); + } else { + recv(&tmp, 1); + send(&tmp, 1); + flush(); + } + } - ~NetIO(){ - flush(); - fclose(stream); - delete[] buffer; - } + ~NetIO(){ + flush(); + fclose(stream); + delete[] buffer; + } - void set_nodelay() { - const int one=1; - setsockopt(consocket,IPPROTO_TCP,TCP_NODELAY,&one,sizeof(one)); - } + void set_nodelay() { + const int one=1; + setsockopt(consocket,IPPROTO_TCP,TCP_NODELAY,&one,sizeof(one)); + } - void set_delay() { - const int zero = 0; - setsockopt(consocket,IPPROTO_TCP,TCP_NODELAY,&zero,sizeof(zero)); - } + void set_delay() { + const int zero = 0; + setsockopt(consocket,IPPROTO_TCP,TCP_NODELAY,&zero,sizeof(zero)); + } - void flush() { - fflush(stream); - } + void flush() { + fflush(stream); + } - void send(const void * data, size_t len) { - size_t sent = 0; - while(sent < len) { - size_t res = fwrite(sent + (char*)data, 1, len - sent, stream); - if (res > 0) - sent+=res; - else - error("net_send_data\n"); - } - has_sent = true; - } + void send(const void * data, size_t len) { + size_t sent = 0; + while(sent < len) { + size_t res = fwrite(sent + (char*)data, 1, len - sent, stream); + if (res > 0) + sent+=res; + else + error("net_send_data\n"); + } + has_sent = true; + } - void recv(void * data, size_t len) { - if(has_sent) - fflush(stream); - has_sent = false; - size_t sent = 0; - while(sent < len) { - size_t res = fread(sent + (char*)data, 1, len - sent, stream); - if (res > 0) - sent += res; - else - error("net_recv_data\n"); - } - } + void recv(void * data, size_t len) { + if(has_sent) + fflush(stream); + has_sent = false; + size_t sent = 0; + while(sent < len) { + size_t res = fread(sent + (char*)data, 1, len - sent, stream); + if (res > 0) + sent += res; + else + error("net_recv_data\n"); + } + } }; } diff --git a/src/emp-tool/utils/aes.h b/src/emp-tool/utils/aes.h index 0235544..53202b9 100644 --- a/src/emp-tool/utils/aes.h +++ b/src/emp-tool/utils/aes.h @@ -118,31 +118,31 @@ inline void AES_ecb_encrypt_blks(block *_blks, unsigned int nblks, const AES_KEY uint8x16_t * keys = (uint8x16_t*)(key->rd_key); auto * first = blks; for (unsigned int j = 0; j < key->rounds-1; ++j) { - uint8x16_t key_j = (uint8x16_t)keys[j]; + uint8x16_t key_j = (uint8x16_t)keys[j]; blks = first; for (unsigned int i = 0; i < nblks; ++i, ++blks) - *blks = vaesmcq_u8(vaeseq_u8(*blks, key_j)); + *blks = vaesmcq_u8(vaeseq_u8(*blks, key_j)); } - uint8x16_t last_key = (uint8x16_t)keys[key->rounds-1]; - for (unsigned int i = 0; i < nblks; ++i, ++first) - *first = vaeseq_u8(*first, last_key) ^ (uint8x16_t)keys[key->rounds]; + uint8x16_t last_key = (uint8x16_t)keys[key->rounds-1]; + for (unsigned int i = 0; i < nblks; ++i, ++first) + *first = vaeseq_u8(*first, last_key) ^ (uint8x16_t)keys[key->rounds]; } #endif #ifdef __GNUC__ - #ifndef __clang__ - #pragma GCC push_options - #pragma GCC optimize ("unroll-loops") - #endif + #ifndef __clang__ + #pragma GCC push_options + #pragma GCC optimize ("unroll-loops") + #endif #endif template inline void AES_ecb_encrypt_blks(block *blks, const AES_KEY *key) { - AES_ecb_encrypt_blks(blks, N, key); + AES_ecb_encrypt_blks(blks, N, key); } #ifdef __GNUC_ - #ifndef __clang___ - #pragma GCC pop_options - #endif + #ifndef __clang___ + #pragma GCC pop_options + #endif #endif inline void diff --git a/src/emp-tool/utils/aes_opt.h b/src/emp-tool/utils/aes_opt.h index 294a1b2..42c98f6 100644 --- a/src/emp-tool/utils/aes_opt.h +++ b/src/emp-tool/utils/aes_opt.h @@ -6,17 +6,17 @@ namespace emp { template static inline void ks_rounds(AES_KEY * keys, block con, block con3, block mask, int r) { - for (int i = 0; i < NumKeys; ++i) { - block key = keys[i].rd_key[r-1]; - block x2 =_mm_shuffle_epi8(key, mask); - block aux = _mm_aesenclast_si128 (x2, con); + for (int i = 0; i < NumKeys; ++i) { + block key = keys[i].rd_key[r-1]; + block x2 =_mm_shuffle_epi8(key, mask); + block aux = _mm_aesenclast_si128 (x2, con); - block globAux=_mm_slli_epi64(key, 32); - key=_mm_xor_si128(globAux, key); - globAux=_mm_shuffle_epi8(key, con3); - key=_mm_xor_si128(globAux, key); - keys[i].rd_key[r] = _mm_xor_si128(aux, key); - } + block globAux=_mm_slli_epi64(key, 32); + key=_mm_xor_si128(globAux, key); + globAux=_mm_shuffle_epi8(key, con3); + key=_mm_xor_si128(globAux, key); + keys[i].rd_key[r] = _mm_xor_si128(aux, key); + } } /* * AES key scheduling for 8 keys @@ -25,34 +25,34 @@ static inline void ks_rounds(AES_KEY * keys, block con, block con3, block mask, */ template static inline void AES_opt_key_schedule(block* user_key, AES_KEY *keys) { - block con = _mm_set_epi32(1,1,1,1); - block con2 = _mm_set_epi32(0x1b,0x1b,0x1b,0x1b); - block con3 = _mm_set_epi32(0x07060504,0x07060504,0x0ffffffff,0x0ffffffff); - block mask = _mm_set_epi32(0x0c0f0e0d,0x0c0f0e0d,0x0c0f0e0d,0x0c0f0e0d); + block con = _mm_set_epi32(1,1,1,1); + block con2 = _mm_set_epi32(0x1b,0x1b,0x1b,0x1b); + block con3 = _mm_set_epi32(0x07060504,0x07060504,0x0ffffffff,0x0ffffffff); + block mask = _mm_set_epi32(0x0c0f0e0d,0x0c0f0e0d,0x0c0f0e0d,0x0c0f0e0d); - for(int i = 0; i < NumKeys; ++i) { - keys[i].rounds=10; - keys[i].rd_key[0] = user_key[i]; - } + for(int i = 0; i < NumKeys; ++i) { + keys[i].rounds=10; + keys[i].rd_key[0] = user_key[i]; + } - ks_rounds(keys, con, con3, mask, 1); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 2); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 3); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 4); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 5); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 6); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 7); - con=_mm_slli_epi32(con, 1); - ks_rounds(keys, con, con3, mask, 8); - ks_rounds(keys, con2, con3, mask, 9); - con2=_mm_slli_epi32(con2, 1); - ks_rounds(keys, con2, con3, mask, 10); + ks_rounds(keys, con, con3, mask, 1); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 2); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 3); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 4); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 5); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 6); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 7); + con=_mm_slli_epi32(con, 1); + ks_rounds(keys, con, con3, mask, 8); + ks_rounds(keys, con2, con3, mask, 9); + con2=_mm_slli_epi32(con2, 1); + ks_rounds(keys, con2, con3, mask, 10); } /* @@ -61,56 +61,56 @@ static inline void AES_opt_key_schedule(block* user_key, AES_KEY *keys) { #ifdef __x86_64__ template static inline void ParaEnc(block *blks, AES_KEY *keys) { - block * first = blks; - for(size_t i = 0; i < numKeys; ++i) { - block K = keys[i].rd_key[0]; - for(size_t j = 0; j < numEncs; ++j) { - *blks = *blks ^ K; - ++blks; - } - } + block * first = blks; + for(size_t i = 0; i < numKeys; ++i) { + block K = keys[i].rd_key[0]; + for(size_t j = 0; j < numEncs; ++j) { + *blks = *blks ^ K; + ++blks; + } + } - for (unsigned int r = 1; r < 10; ++r) { - blks = first; - for(size_t i = 0; i < numKeys; ++i) { - block K = keys[i].rd_key[r]; - for(size_t j = 0; j < numEncs; ++j) { - *blks = _mm_aesenc_si128(*blks, K); - ++blks; - } - } - } + for (unsigned int r = 1; r < 10; ++r) { + blks = first; + for(size_t i = 0; i < numKeys; ++i) { + block K = keys[i].rd_key[r]; + for(size_t j = 0; j < numEncs; ++j) { + *blks = _mm_aesenc_si128(*blks, K); + ++blks; + } + } + } - blks = first; - for(size_t i = 0; i < numKeys; ++i) { - block K = keys[i].rd_key[10]; - for(size_t j = 0; j < numEncs; ++j) { - *blks = _mm_aesenclast_si128(*blks, K); - ++blks; - } - } + blks = first; + for(size_t i = 0; i < numKeys; ++i) { + block K = keys[i].rd_key[10]; + for(size_t j = 0; j < numEncs; ++j) { + *blks = _mm_aesenclast_si128(*blks, K); + ++blks; + } + } } #elif __aarch64__ template static inline void ParaEnc(block *_blks, AES_KEY *keys) { - uint8x16_t * first = (uint8x16_t*)(_blks); + uint8x16_t * first = (uint8x16_t*)(_blks); - for (unsigned int r = 0; r < 9; ++r) { - auto blks = first; - for(size_t i = 0; i < numKeys; ++i) { - uint8x16_t K = vreinterpretq_u8_m128i(keys[i].rd_key[r]); - for(size_t j = 0; j < numEncs; ++j, ++blks) - *blks = vaesmcq_u8(vaeseq_u8(*blks, K)); - } - } - - auto blks = first; - for(size_t i = 0; i < numKeys; ++i) { - uint8x16_t K = vreinterpretq_u8_m128i(keys[i].rd_key[9]); - uint8x16_t K2 = vreinterpretq_u8_m128i(keys[i].rd_key[10]); - for(size_t j = 0; j < numEncs; ++j, ++blks) - *blks = vaeseq_u8(*blks, K) ^ K2; - } + for (unsigned int r = 0; r < 9; ++r) { + auto blks = first; + for(size_t i = 0; i < numKeys; ++i) { + uint8x16_t K = vreinterpretq_u8_m128i(keys[i].rd_key[r]); + for(size_t j = 0; j < numEncs; ++j, ++blks) + *blks = vaesmcq_u8(vaeseq_u8(*blks, K)); + } + } + + auto blks = first; + for(size_t i = 0; i < numKeys; ++i) { + uint8x16_t K = vreinterpretq_u8_m128i(keys[i].rd_key[9]); + uint8x16_t K2 = vreinterpretq_u8_m128i(keys[i].rd_key[10]); + for(size_t j = 0; j < numEncs; ++j, ++blks) + *blks = vaeseq_u8(*blks, K) ^ K2; + } } #endif diff --git a/src/emp-tool/utils/block.h b/src/emp-tool/utils/block.h index ad7af1b..4781820 100644 --- a/src/emp-tool/utils/block.h +++ b/src/emp-tool/utils/block.h @@ -6,7 +6,7 @@ #elif __aarch64__ #include "sse2neon.h" inline __m128i _mm_aesimc_si128(__m128i a) { - return vreinterpretq_m128i_u8(vaesimcq_u8(vreinterpretq_u8_m128i(a))); + return vreinterpretq_m128i_u8(vaesimcq_u8(vreinterpretq_u8_m128i(a))); } inline __m128i _mm_aesdeclast_si128 (__m128i a, __m128i RoundKey) @@ -26,17 +26,17 @@ namespace emp { using block = __m128i; inline bool getLSB(const block & x) { - return (x[0] & 1) == 1; + return (x[0] & 1) == 1; } #ifdef __x86_64__ __attribute__((target("sse2"))) inline block makeBlock(uint64_t high, uint64_t low) { - return _mm_set_epi64x(high, low); + return _mm_set_epi64x(high, low); } #elif __aarch64__ inline block makeBlock(uint64_t high, uint64_t low) { - return (block)vcombine_u64((uint64x1_t)low, (uint64x1_t)high); + return (block)vcombine_u64((uint64x1_t)low, (uint64x1_t)high); } #endif @@ -49,7 +49,7 @@ inline block makeBlock(uint64_t high, uint64_t low) { __attribute__((target("sse2"))) #endif inline block sigma(block a) { - return _mm_shuffle_epi32(a, 78) ^ (a & makeBlock(0xFFFFFFFFFFFFFFFF, 0x00)); + return _mm_shuffle_epi32(a, 78) ^ (a & makeBlock(0xFFFFFFFFFFFFFFFF, 0x00)); } const block zero_block = makeBlock(0, 0); @@ -57,48 +57,48 @@ const block all_one_block = makeBlock(0xFFFFFFFFFFFFFFFF,0xFFFFFFFFFFFFFFFF); const block select_mask[2] = {zero_block, all_one_block}; inline block set_bit(const block & a, int i) { - if(i < 64) - return makeBlock(0L, 1ULL< inline void vector_inn_prdt_sum_red(block *res, block const *a, const block *b) { - vector_inn_prdt_sum_red(res, a, b, N); + vector_inn_prdt_sum_red(res, a, b, N); } /* inner product of two galois field vectors without reduction */ inline void vector_inn_prdt_sum_no_red(block *res, const block *a, const block *b, int sz) { - block r1 = zero_block, r2 = zero_block; - block r11, r12; - for(int i = 0; i < sz; i++) { - mul128(a[i], b[i], &r11, &r12); - r1 = r1 ^ r11; - r2 = r2 ^ r12; - } - res[0] = r1; - res[1] = r2; + block r1 = zero_block, r2 = zero_block; + block r11, r12; + for(int i = 0; i < sz; i++) { + mul128(a[i], b[i], &r11, &r12); + r1 = r1 ^ r11; + r2 = r2 ^ r12; + } + res[0] = r1; + res[1] = r2; } /* inner product of two galois field vectors without reduction */ template inline void vector_inn_prdt_sum_no_red(block *res, const block *a, const block *b) { - vector_inn_prdt_sum_no_red(res, a, b, N); + vector_inn_prdt_sum_no_red(res, a, b, N); } /* coefficients of almost universal hash function */ inline void uni_hash_coeff_gen(block* coeff, block seed, int sz) { - // Handle the case with small `sz` - coeff[0] = seed; - if(sz == 1) return; + // Handle the case with small `sz` + coeff[0] = seed; + if(sz == 1) return; - gfmul(seed, seed, &coeff[1]); - if(sz == 2) return; + gfmul(seed, seed, &coeff[1]); + if(sz == 2) return; - gfmul(coeff[1], seed, &coeff[2]); - if(sz == 3) return; + gfmul(coeff[1], seed, &coeff[2]); + if(sz == 3) return; - block multiplier; - gfmul(coeff[2], seed, &multiplier); - coeff[3] = multiplier; - if(sz == 4) return; + block multiplier; + gfmul(coeff[2], seed, &multiplier); + coeff[3] = multiplier; + if(sz == 4) return; - // Computing the rest with a batch of 4 - int i = 4; - for(; i < sz - 3; i += 4) { - gfmul(coeff[i - 4], multiplier, &coeff[i]); - gfmul(coeff[i - 3], multiplier, &coeff[i + 1]); - gfmul(coeff[i - 2], multiplier, &coeff[i + 2]); - gfmul(coeff[i - 1], multiplier, &coeff[i + 3]); - } + // Computing the rest with a batch of 4 + int i = 4; + for(; i < sz - 3; i += 4) { + gfmul(coeff[i - 4], multiplier, &coeff[i]); + gfmul(coeff[i - 3], multiplier, &coeff[i + 1]); + gfmul(coeff[i - 2], multiplier, &coeff[i + 2]); + gfmul(coeff[i - 1], multiplier, &coeff[i + 3]); + } - // Cleaning up with the rest - int remainder = sz % 4; - if(remainder != 0) { - i = sz - remainder; - for(; i < sz; ++i) - gfmul(coeff[i - 1], seed, &coeff[i]); - } + // Cleaning up with the rest + int remainder = sz % 4; + if(remainder != 0) { + i = sz - remainder; + for(; i < sz; ++i) + gfmul(coeff[i - 1], seed, &coeff[i]); + } } /* coefficients of almost universal hash function */ template inline void uni_hash_coeff_gen(block* coeff, block seed) { - uni_hash_coeff_gen(coeff, seed, N); + uni_hash_coeff_gen(coeff, seed, N); } /* packing in Galois field (v[i] * X^i for v of size 128) */ class GaloisFieldPacking { - public: - block base[128]; + public: + block base[128]; - GaloisFieldPacking() { - packing_base_gen(); - } + GaloisFieldPacking() { + packing_base_gen(); + } - ~GaloisFieldPacking() { + ~GaloisFieldPacking() { - } + } - void packing_base_gen() { - uint64_t a = 0, b = 1; - for(int i = 0; i < 64; i+=4) { - base[i] = _mm_set_epi64x(a, b); - base[i+1] = _mm_set_epi64x(a, b<<1); - base[i+2] = _mm_set_epi64x(a, b<<2); - base[i+3] = _mm_set_epi64x(a, b<<3); - b <<= 4; - } - a = 1, b = 0; - for(int i = 64; i < 128; i+=4) { - base[i] = _mm_set_epi64x(a, b); - base[i+1] = _mm_set_epi64x(a<<1, b); - base[i+2] = _mm_set_epi64x(a<<2, b); - base[i+3] = _mm_set_epi64x(a<<3, b); - a <<= 4; - } - } + void packing_base_gen() { + uint64_t a = 0, b = 1; + for(int i = 0; i < 64; i+=4) { + base[i] = _mm_set_epi64x(a, b); + base[i+1] = _mm_set_epi64x(a, b<<1); + base[i+2] = _mm_set_epi64x(a, b<<2); + base[i+3] = _mm_set_epi64x(a, b<<3); + b <<= 4; + } + a = 1, b = 0; + for(int i = 64; i < 128; i+=4) { + base[i] = _mm_set_epi64x(a, b); + base[i+1] = _mm_set_epi64x(a<<1, b); + base[i+2] = _mm_set_epi64x(a<<2, b); + base[i+3] = _mm_set_epi64x(a<<3, b); + a <<= 4; + } + } - void packing(block *res, block *data) { - vector_inn_prdt_sum_red(res, data, base, 128); - } + void packing(block *res, block *data) { + vector_inn_prdt_sum_red(res, data, base, 128); + } }; /* XOR of all elements in a vector */ inline void vector_self_xor(block *sum, block *data, int sz) { - block res[4]; - res[0] = zero_block; - res[1] = zero_block; - res[2] = zero_block; - res[3] = zero_block; - for(int i = 0; i < (sz/4)*4; i+=4) { - res[0] = data[i] ^ res[0]; - res[1] = data[i+1] ^ res[1]; - res[2] = data[i+2] ^ res[2]; - res[3] = data[i+3] ^ res[3]; - } - for(int i = (sz/4)*4, j = 0; i < sz; ++i, ++j) - res[j] = data[i] ^ res[j]; - res[0] = res[0] ^ res[1]; - res[2] = res[2] ^ res[3]; - *sum = res[0] ^ res[2]; + block res[4]; + res[0] = zero_block; + res[1] = zero_block; + res[2] = zero_block; + res[3] = zero_block; + for(int i = 0; i < (sz/4)*4; i+=4) { + res[0] = data[i] ^ res[0]; + res[1] = data[i+1] ^ res[1]; + res[2] = data[i+2] ^ res[2]; + res[3] = data[i+3] ^ res[3]; + } + for(int i = (sz/4)*4, j = 0; i < sz; ++i, ++j) + res[j] = data[i] ^ res[j]; + res[0] = res[0] ^ res[1]; + res[2] = res[2] ^ res[3]; + *sum = res[0] ^ res[2]; } /* XOR of all elements in a vector */ template inline void vector_self_xor(block *sum, block *data) { - vector_self_xor(sum, data, N); + vector_self_xor(sum, data, N); } } #endif diff --git a/src/emp-tool/utils/group.h b/src/emp-tool/utils/group.h index adbd67f..ecebb71 100644 --- a/src/emp-tool/utils/group.h +++ b/src/emp-tool/utils/group.h @@ -14,57 +14,57 @@ //#endif namespace emp { class BigInt { public: - BIGNUM *n = nullptr; - BigInt(); - BigInt(const BigInt &oth); - BigInt &operator=(BigInt oth); - ~BigInt(); + BIGNUM *n = nullptr; + BigInt(); + BigInt(const BigInt &oth); + BigInt &operator=(BigInt oth); + ~BigInt(); - int size(); - void to_bin(unsigned char * in); - void from_bin(const unsigned char * in, int length); + int size(); + void to_bin(unsigned char * in); + void from_bin(const unsigned char * in, int length); - BigInt add(const BigInt &oth); - BigInt mul(const BigInt &oth, BN_CTX *ctx); - BigInt mod(const BigInt &oth, BN_CTX *ctx); - BigInt add_mod(const BigInt & b, const BigInt& m, BN_CTX *ctx); - BigInt mul_mod(const BigInt & b, const BigInt& m, BN_CTX *ctx); + BigInt add(const BigInt &oth); + BigInt mul(const BigInt &oth, BN_CTX *ctx); + BigInt mod(const BigInt &oth, BN_CTX *ctx); + BigInt add_mod(const BigInt & b, const BigInt& m, BN_CTX *ctx); + BigInt mul_mod(const BigInt & b, const BigInt& m, BN_CTX *ctx); }; class Group; class Point { - public: - EC_POINT *point = nullptr; - Group * group = nullptr; - Point (Group * g = nullptr); - ~Point(); - Point(const Point & p); - Point& operator=(Point p); + public: + EC_POINT *point = nullptr; + Group * group = nullptr; + Point (Group * g = nullptr); + ~Point(); + Point(const Point & p); + Point& operator=(Point p); - void to_bin(unsigned char * buf, size_t buf_len); - size_t size(); - void from_bin(Group * g, const unsigned char * buf, size_t buf_len); + void to_bin(unsigned char * buf, size_t buf_len); + size_t size(); + void from_bin(Group * g, const unsigned char * buf, size_t buf_len); - Point add(Point & rhs); -// Point sub(Point & rhs); -// bool is_at_infinity(); -// bool is_on_curve(); - Point mul(const BigInt &m); - Point inv(); - bool operator==(Point & rhs); + Point add(Point & rhs); +// Point sub(Point & rhs); +// bool is_at_infinity(); +// bool is_on_curve(); + Point mul(const BigInt &m); + Point inv(); + bool operator==(Point & rhs); }; class Group { public: - EC_GROUP *ec_group = nullptr; - BN_CTX * bn_ctx = nullptr; - BigInt order; - unsigned char * scratch; - size_t scratch_size = 256; - Group(); - ~Group(); - void resize_scratch(size_t size); - void get_rand_bn(BigInt & n); - Point get_generator(); - Point mul_gen(const BigInt &m); + EC_GROUP *ec_group = nullptr; + BN_CTX * bn_ctx = nullptr; + BigInt order; + unsigned char * scratch; + size_t scratch_size = 256; + Group(); + ~Group(); + void resize_scratch(size_t size); + void get_rand_bn(BigInt & n); + Point get_generator(); + Point mul_gen(const BigInt &m); }; } diff --git a/src/emp-tool/utils/group_openssl.h b/src/emp-tool/utils/group_openssl.h index eaf389f..d3d30fc 100644 --- a/src/emp-tool/utils/group_openssl.h +++ b/src/emp-tool/utils/group_openssl.h @@ -3,178 +3,178 @@ namespace emp { inline BigInt::BigInt() { - n = BN_new(); + n = BN_new(); } inline BigInt::BigInt(const BigInt &oth) { - n = BN_new(); - BN_copy(n, oth.n); + n = BN_new(); + BN_copy(n, oth.n); } inline BigInt& BigInt::operator=(BigInt oth) { - std::swap(n, oth.n); - return *this; + std::swap(n, oth.n); + return *this; } inline BigInt::~BigInt() { - if (n != nullptr) - BN_free(n); + if (n != nullptr) + BN_free(n); } inline int BigInt::size() { - return BN_num_bytes(n); + return BN_num_bytes(n); } inline void BigInt::to_bin(unsigned char * in) { - BN_bn2bin(n, in); + BN_bn2bin(n, in); } inline void BigInt::from_bin(const unsigned char * in, int length) { - BN_free(n); - n = BN_bin2bn(in, length, nullptr); + BN_free(n); + n = BN_bin2bn(in, length, nullptr); } inline BigInt BigInt::add(const BigInt &oth) { - BigInt ret; - BN_add(ret.n, n, oth.n); - return ret; + BigInt ret; + BN_add(ret.n, n, oth.n); + return ret; } inline BigInt BigInt::mul_mod(const BigInt & b, const BigInt &m, BN_CTX *ctx) { - BigInt ret; - BN_mod_mul(ret.n, n, b.n, m.n, ctx); - return ret; + BigInt ret; + BN_mod_mul(ret.n, n, b.n, m.n, ctx); + return ret; } inline BigInt BigInt::add_mod(const BigInt & b, const BigInt &m, BN_CTX *ctx) { - BigInt ret; - BN_mod_add(ret.n, n, b.n, m.n, ctx); - return ret; + BigInt ret; + BN_mod_add(ret.n, n, b.n, m.n, ctx); + return ret; } inline BigInt BigInt::mul(const BigInt &oth, BN_CTX *ctx) { - BigInt ret; - BN_mul(ret.n, n, oth.n, ctx); - return ret; + BigInt ret; + BN_mul(ret.n, n, oth.n, ctx); + return ret; } inline BigInt BigInt::mod(const BigInt &oth, BN_CTX *ctx) { - BigInt ret; - BN_mod(ret.n, n, oth.n, ctx); - return ret; + BigInt ret; + BN_mod(ret.n, n, oth.n, ctx); + return ret; } inline Point::Point (Group * g) { - if (g == nullptr) return; - this->group = g; - point = EC_POINT_new(group->ec_group); + if (g == nullptr) return; + this->group = g; + point = EC_POINT_new(group->ec_group); } inline Point::~Point() { - if(point != nullptr) - EC_POINT_free(point); + if(point != nullptr) + EC_POINT_free(point); } inline Point::Point(const Point & p) { - if (p.group == nullptr) return; - this->group = p.group; - point = EC_POINT_new(group->ec_group); - int ret = EC_POINT_copy(point, p.point); - if(ret == 0) error("ECC COPY"); + if (p.group == nullptr) return; + this->group = p.group; + point = EC_POINT_new(group->ec_group); + int ret = EC_POINT_copy(point, p.point); + if(ret == 0) error("ECC COPY"); } inline Point& Point::operator=(Point p) { - std::swap(p.point, point); - std::swap(p.group, group); - return *this; + std::swap(p.point, point); + std::swap(p.group, group); + return *this; } inline void Point::to_bin(unsigned char * buf, size_t buf_len) { - int ret = EC_POINT_point2oct(group->ec_group, point, POINT_CONVERSION_UNCOMPRESSED, buf, buf_len, group->bn_ctx); - if(ret == 0) error("ECC TO_BIN"); + int ret = EC_POINT_point2oct(group->ec_group, point, POINT_CONVERSION_UNCOMPRESSED, buf, buf_len, group->bn_ctx); + if(ret == 0) error("ECC TO_BIN"); } inline size_t Point::size() { - size_t ret = EC_POINT_point2oct(group->ec_group, point, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, group->bn_ctx); - if(ret == 0) error("ECC SIZE_BIN"); - return ret; + size_t ret = EC_POINT_point2oct(group->ec_group, point, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, group->bn_ctx); + if(ret == 0) error("ECC SIZE_BIN"); + return ret; } inline void Point::from_bin(Group * g, const unsigned char * buf, size_t buf_len) { - if (point == nullptr) { - group = g; - point = EC_POINT_new(group->ec_group); - } - int ret = EC_POINT_oct2point(group->ec_group, point, buf, buf_len, group->bn_ctx); - if(ret == 0) error("ECC FROM_BIN"); + if (point == nullptr) { + group = g; + point = EC_POINT_new(group->ec_group); + } + int ret = EC_POINT_oct2point(group->ec_group, point, buf, buf_len, group->bn_ctx); + if(ret == 0) error("ECC FROM_BIN"); } inline Point Point::add(Point & rhs) { - Point ret(group); - int res = EC_POINT_add(group->ec_group, ret.point, point, rhs.point, group->bn_ctx); - if(res == 0) error("ECC ADD"); - return ret; + Point ret(group); + int res = EC_POINT_add(group->ec_group, ret.point, point, rhs.point, group->bn_ctx); + if(res == 0) error("ECC ADD"); + return ret; } inline Point Point::mul(const BigInt &m) { - Point ret (group); - int res = EC_POINT_mul(group->ec_group, ret.point, NULL, point, m.n, group->bn_ctx); - if(res == 0) error("ECC MUL"); - return ret; + Point ret (group); + int res = EC_POINT_mul(group->ec_group, ret.point, NULL, point, m.n, group->bn_ctx); + if(res == 0) error("ECC MUL"); + return ret; } inline Point Point::inv() { - Point ret (*this); - int res = EC_POINT_invert(group->ec_group, ret.point, group->bn_ctx); - if(res == 0) error("ECC INV"); - return ret; + Point ret (*this); + int res = EC_POINT_invert(group->ec_group, ret.point, group->bn_ctx); + if(res == 0) error("ECC INV"); + return ret; } inline bool Point::operator==(Point & rhs) { - int ret = EC_POINT_cmp(group->ec_group, point, rhs.point, group->bn_ctx); - if(ret == -1) error("ECC CMP"); - return (ret == 0); + int ret = EC_POINT_cmp(group->ec_group, point, rhs.point, group->bn_ctx); + if(ret == -1) error("ECC CMP"); + return (ret == 0); } inline Group::Group() { - ec_group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);//NIST P-256 - bn_ctx = BN_CTX_new(); - EC_GROUP_get_order(ec_group, order.n, bn_ctx); - scratch = new unsigned char[scratch_size]; + ec_group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);//NIST P-256 + bn_ctx = BN_CTX_new(); + EC_GROUP_get_order(ec_group, order.n, bn_ctx); + scratch = new unsigned char[scratch_size]; } inline Group::~Group(){ - if(ec_group != nullptr) - EC_GROUP_free(ec_group); + if(ec_group != nullptr) + EC_GROUP_free(ec_group); - if(bn_ctx != nullptr) - BN_CTX_free(bn_ctx); + if(bn_ctx != nullptr) + BN_CTX_free(bn_ctx); - if(scratch != nullptr) - delete[] scratch; + if(scratch != nullptr) + delete[] scratch; } inline void Group::resize_scratch(size_t size) { - if (size > scratch_size) { - delete[] scratch; - scratch_size = size; - scratch = new unsigned char[scratch_size]; - } + if (size > scratch_size) { + delete[] scratch; + scratch_size = size; + scratch = new unsigned char[scratch_size]; + } } inline void Group::get_rand_bn(BigInt & n) { - BN_rand_range(n.n, order.n); + BN_rand_range(n.n, order.n); } inline Point Group::get_generator() { - Point res(this); - int ret = EC_POINT_copy(res.point, EC_GROUP_get0_generator(ec_group)); - if(ret == 0) error("ECC GEN"); - return res; + Point res(this); + int ret = EC_POINT_copy(res.point, EC_GROUP_get0_generator(ec_group)); + if(ret == 0) error("ECC GEN"); + return res; } inline Point Group::mul_gen(const BigInt &m) { - Point res(this); - int ret = EC_POINT_mul(ec_group, res.point, m.n ,NULL, NULL, bn_ctx); - if(ret == 0) error("ECC GEN MUL"); - return res; + Point res(this); + int ret = EC_POINT_mul(ec_group, res.point, m.n ,NULL, NULL, bn_ctx); + if(ret == 0) error("ECC GEN MUL"); + return res; } } #endif \ No newline at end of file diff --git a/src/emp-tool/utils/hash.h b/src/emp-tool/utils/hash.h index ec03f36..9220379 100644 --- a/src/emp-tool/utils/hash.h +++ b/src/emp-tool/utils/hash.h @@ -9,70 +9,70 @@ namespace emp { class Hash { public: - EVP_MD_CTX *mdctx; - char buffer[HASH_BUFFER_SIZE]; - int size = 0; - static const int DIGEST_SIZE = 32; - Hash() { - if((mdctx = EVP_MD_CTX_create()) == NULL) - error("Hash function setup error!"); - if(1 != EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL)) - error("Hash function setup error!"); - } - ~Hash() { - EVP_MD_CTX_destroy(mdctx); - } - void put(const void * data, int nbyte) { - if (nbyte >= HASH_BUFFER_SIZE) - EVP_DigestUpdate(mdctx, data, nbyte); - else if(size + nbyte < HASH_BUFFER_SIZE) { - memcpy(buffer+size, data, nbyte); - size+=nbyte; - } else { - EVP_DigestUpdate(mdctx, buffer, size); - memcpy(buffer, data, nbyte); - size = nbyte; - } - } - void put_block(const block* blk, int nblock=1){ - put(blk, sizeof(block)*nblock); - } - void digest(void * a) { - if(size > 0) { - EVP_DigestUpdate(mdctx, buffer, size); - size=0; - } - uint32_t len = 0; - EVP_DigestFinal_ex(mdctx, (unsigned char *)a, &len); - reset(); - } - void reset() { - EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL); - size=0; - } - static void hash_once(void * dgst, const void * data, int nbyte) { - Hash hash; - hash.put(data, nbyte); - hash.digest(dgst); - } - #ifdef __x86_64__ - __attribute__((target("sse2"))) - #endif - static block hash_for_block(const void * data, int nbyte) { - char digest[DIGEST_SIZE]; - hash_once(digest, data, nbyte); - return _mm_load_si128((__m128i*)&digest[0]); - } + EVP_MD_CTX *mdctx; + char buffer[HASH_BUFFER_SIZE]; + int size = 0; + static const int DIGEST_SIZE = 32; + Hash() { + if((mdctx = EVP_MD_CTX_create()) == NULL) + error("Hash function setup error!"); + if(1 != EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL)) + error("Hash function setup error!"); + } + ~Hash() { + EVP_MD_CTX_destroy(mdctx); + } + void put(const void * data, int nbyte) { + if (nbyte >= HASH_BUFFER_SIZE) + EVP_DigestUpdate(mdctx, data, nbyte); + else if(size + nbyte < HASH_BUFFER_SIZE) { + memcpy(buffer+size, data, nbyte); + size+=nbyte; + } else { + EVP_DigestUpdate(mdctx, buffer, size); + memcpy(buffer, data, nbyte); + size = nbyte; + } + } + void put_block(const block* blk, int nblock=1){ + put(blk, sizeof(block)*nblock); + } + void digest(void * a) { + if(size > 0) { + EVP_DigestUpdate(mdctx, buffer, size); + size=0; + } + uint32_t len = 0; + EVP_DigestFinal_ex(mdctx, (unsigned char *)a, &len); + reset(); + } + void reset() { + EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL); + size=0; + } + static void hash_once(void * dgst, const void * data, int nbyte) { + Hash hash; + hash.put(data, nbyte); + hash.digest(dgst); + } + #ifdef __x86_64__ + __attribute__((target("sse2"))) + #endif + static block hash_for_block(const void * data, int nbyte) { + char digest[DIGEST_SIZE]; + hash_once(digest, data, nbyte); + return _mm_load_si128((__m128i*)&digest[0]); + } - static block KDF(Point &in, uint64_t id = 1) { - size_t len = in.size(); - in.group->resize_scratch(len+8); - unsigned char * tmp = in.group->scratch; - in.to_bin(tmp, len); - memcpy(tmp+len, &id, 8); - block ret = hash_for_block(tmp, len+8); - return ret; - } + static block KDF(Point &in, uint64_t id = 1) { + size_t len = in.size(); + in.group->resize_scratch(len+8); + unsigned char * tmp = in.group->scratch; + in.to_bin(tmp, len); + memcpy(tmp+len, &id, 8); + block ret = hash_for_block(tmp, len+8); + return ret; + } }; } #endif// HASH_H diff --git a/src/emp-tool/utils/mitccrh.h b/src/emp-tool/utils/mitccrh.h index 2507dee..782d18e 100644 --- a/src/emp-tool/utils/mitccrh.h +++ b/src/emp-tool/utils/mitccrh.h @@ -12,51 +12,51 @@ namespace emp { template class MITCCRH { public: - AES_KEY scheduled_key[BatchSize]; - block keys[BatchSize]; - int key_used = BatchSize; - block start_point; - uint64_t gid = 0; + AES_KEY scheduled_key[BatchSize]; + block keys[BatchSize]; + int key_used = BatchSize; + block start_point; + uint64_t gid = 0; - void setS(block sin) { - this->start_point = sin; - } + void setS(block sin) { + this->start_point = sin; + } - void renew_ks(uint64_t gid) { - this->gid = gid; - renew_ks(); - } + void renew_ks(uint64_t gid) { + this->gid = gid; + renew_ks(); + } - void renew_ks() { - for(int i = 0; i < BatchSize; ++i) - keys[i] = start_point ^ makeBlock(gid++, 0); - AES_opt_key_schedule(keys, scheduled_key); - key_used = 0; - } + void renew_ks() { + for(int i = 0; i < BatchSize; ++i) + keys[i] = start_point ^ makeBlock(gid++, 0); + AES_opt_key_schedule(keys, scheduled_key); + key_used = 0; + } - template - void hash_cir(block * blks) { - for(int i = 0; i < K*H; ++i) - blks[i] = sigma(blks[i]); - hash(blks); - } + template + void hash_cir(block * blks) { + for(int i = 0; i < K*H; ++i) + blks[i] = sigma(blks[i]); + hash(blks); + } - template - void hash(block * blks) { - assert(K <= BatchSize); - assert(BatchSize % K == 0); - if(key_used == BatchSize) renew_ks(); + template + void hash(block * blks) { + assert(K <= BatchSize); + assert(BatchSize % K == 0); + if(key_used == BatchSize) renew_ks(); - block tmp[K*H]; - for(int i = 0; i < K*H; ++i) - tmp[i] = blks[i]; - - ParaEnc(tmp, scheduled_key+key_used); - key_used += K; - - for(int i = 0; i < K*H; ++i) - blks[i] = blks[i] ^ tmp[i]; - } + block tmp[K*H]; + for(int i = 0; i < K*H; ++i) + tmp[i] = blks[i]; + + ParaEnc(tmp, scheduled_key+key_used); + key_used += K; + + for(int i = 0; i < K*H; ++i) + blks[i] = blks[i] ^ tmp[i]; + } }; } diff --git a/src/emp-tool/utils/prg.h b/src/emp-tool/utils/prg.h index bb2c2cd..5a2bf19 100644 --- a/src/emp-tool/utils/prg.h +++ b/src/emp-tool/utils/prg.h @@ -16,107 +16,107 @@ namespace emp { class PRG { public: - uint64_t counter = 0; - AES_KEY aes; - block key; - PRG(const void * seed = nullptr, int id = 0) { - if (seed != nullptr) { - reseed((const block *)seed, id); - } else { - block v; + uint64_t counter = 0; + AES_KEY aes; + block key; + PRG(const void * seed = nullptr, int id = 0) { + if (seed != nullptr) { + reseed((const block *)seed, id); + } else { + block v; #ifndef ENABLE_RDSEED - uint32_t * data = (uint32_t *)(&v); - std::random_device rand_div("/dev/urandom"); - for (size_t i = 0; i < sizeof(block) / sizeof(uint32_t); ++i) - data[i] = rand_div(); + uint32_t * data = (uint32_t *)(&v); + std::random_device rand_div("/dev/urandom"); + for (size_t i = 0; i < sizeof(block) / sizeof(uint32_t); ++i) + data[i] = rand_div(); #else - unsigned long long r0, r1; - int i = 0; - // To prevent an AMD CPU bug. (PR #156) - for(; i < 10; ++i) - if((_rdseed64_step(&r0) == 1) && (r0 != ULLONG_MAX) && (r0 != 0)) break; - if(i == 10)error("RDSEED FAILURE"); + unsigned long long r0, r1; + int i = 0; + // To prevent an AMD CPU bug. (PR #156) + for(; i < 10; ++i) + if((_rdseed64_step(&r0) == 1) && (r0 != ULLONG_MAX) && (r0 != 0)) break; + if(i == 10)error("RDSEED FAILURE"); - for(i = 0; i < 10; ++i) - if((_rdseed64_step(&r1) == 1) && (r1 != ULLONG_MAX) && (r1 != 0)) break; - if(i == 10)error("RDSEED FAILURE"); + for(i = 0; i < 10; ++i) + if((_rdseed64_step(&r1) == 1) && (r1 != ULLONG_MAX) && (r1 != 0)) break; + if(i == 10)error("RDSEED FAILURE"); - v = makeBlock(r0, r1); + v = makeBlock(r0, r1); #endif - reseed(&v, id); - } - } - void reseed(const block* seed, uint64_t id = 0) { - block v = _mm_loadu_si128(seed); - v ^= makeBlock(0LL, id); - key = v; - AES_set_encrypt_key(v, &aes); - counter = 0; - } + reseed(&v, id); + } + } + void reseed(const block* seed, uint64_t id = 0) { + block v = _mm_loadu_si128(seed); + v ^= makeBlock(0LL, id); + key = v; + AES_set_encrypt_key(v, &aes); + counter = 0; + } - void random_data(void *data, int nbytes) { - random_block((block *)data, nbytes/16); - if (nbytes % 16 != 0) { - block extra; - random_block(&extra, 1); - memcpy((nbytes/16*16)+(char *) data, &extra, nbytes%16); - } - } + void random_data(void *data, int nbytes) { + random_block((block *)data, nbytes/16); + if (nbytes % 16 != 0) { + block extra; + random_block(&extra, 1); + memcpy((nbytes/16*16)+(char *) data, &extra, nbytes%16); + } + } - void random_bool(bool * data, int length) { - uint8_t * uint_data = (uint8_t*)data; - random_data_unaligned(uint_data, length); - for(int i = 0; i < length; ++i) - data[i] = uint_data[i] & 1; - } + void random_bool(bool * data, int length) { + uint8_t * uint_data = (uint8_t*)data; + random_data_unaligned(uint_data, length); + for(int i = 0; i < length; ++i) + data[i] = uint_data[i] & 1; + } - void random_data_unaligned(void *data, int nbytes) { - size_t size = nbytes; - void *aligned_data = data; - if(std::align(sizeof(block), sizeof(block), aligned_data, size)) { - int chopped = nbytes - size; - random_data(aligned_data, nbytes - chopped); - block tmp[1]; - random_block(tmp, 1); - memcpy(data, tmp, chopped); - } else { - block tmp[2]; - random_block(tmp, 2); - memcpy(data, tmp, nbytes); - } - } + void random_data_unaligned(void *data, int nbytes) { + size_t size = nbytes; + void *aligned_data = data; + if(std::align(sizeof(block), sizeof(block), aligned_data, size)) { + int chopped = nbytes - size; + random_data(aligned_data, nbytes - chopped); + block tmp[1]; + random_block(tmp, 1); + memcpy(data, tmp, chopped); + } else { + block tmp[2]; + random_block(tmp, 2); + memcpy(data, tmp, nbytes); + } + } - void random_block(block * data, int nblocks=1) { - block tmp[AES_BATCH_SIZE]; - for(int i = 0; i < nblocks/AES_BATCH_SIZE; ++i) { - for (int j = 0; j < AES_BATCH_SIZE; ++j) - tmp[j] = makeBlock(0LL, counter++); - AES_ecb_encrypt_blks(tmp, &aes); - memcpy(data + i*AES_BATCH_SIZE, tmp, AES_BATCH_SIZE*sizeof(block)); - } - int remain = nblocks % AES_BATCH_SIZE; - for (int j = 0; j < remain; ++j) - tmp[j] = makeBlock(0LL, counter++); - AES_ecb_encrypt_blks(tmp, remain, &aes); - memcpy(data + (nblocks/AES_BATCH_SIZE)*AES_BATCH_SIZE, tmp, remain*sizeof(block)); - } + void random_block(block * data, int nblocks=1) { + block tmp[AES_BATCH_SIZE]; + for(int i = 0; i < nblocks/AES_BATCH_SIZE; ++i) { + for (int j = 0; j < AES_BATCH_SIZE; ++j) + tmp[j] = makeBlock(0LL, counter++); + AES_ecb_encrypt_blks(tmp, &aes); + memcpy(data + i*AES_BATCH_SIZE, tmp, AES_BATCH_SIZE*sizeof(block)); + } + int remain = nblocks % AES_BATCH_SIZE; + for (int j = 0; j < remain; ++j) + tmp[j] = makeBlock(0LL, counter++); + AES_ecb_encrypt_blks(tmp, remain, &aes); + memcpy(data + (nblocks/AES_BATCH_SIZE)*AES_BATCH_SIZE, tmp, remain*sizeof(block)); + } - typedef uint64_t result_type; - result_type buffer[32]; - size_t ptr = 32; - static constexpr result_type min() { - return 0; - } - static constexpr result_type max() { - return 0xFFFFFFFFFFFFFFFFULL; - } - result_type operator()() { - if(ptr == 32) { - random_block((block*)buffer, 16); - ptr = 0; - } - return buffer[ptr++]; - } + typedef uint64_t result_type; + result_type buffer[32]; + size_t ptr = 32; + static constexpr result_type min() { + return 0; + } + static constexpr result_type max() { + return 0xFFFFFFFFFFFFFFFFULL; + } + result_type operator()() { + if(ptr == 32) { + random_block((block*)buffer, 16); + ptr = 0; + } + return buffer[ptr++]; + } }; } diff --git a/src/emp-tool/utils/prp.h b/src/emp-tool/utils/prp.h index adf0cc9..8c6eee2 100644 --- a/src/emp-tool/utils/prp.h +++ b/src/emp-tool/utils/prp.h @@ -13,28 +13,28 @@ namespace emp { * https://eprint.iacr.org/2013/426.pdf */ class PRP { public: - AES_KEY aes; + AES_KEY aes; - PRP(const char * key = nullptr) { - if(key == nullptr) - aes_set_key(zero_block); - } + PRP(const char * key = nullptr) { + if(key == nullptr) + aes_set_key(zero_block); + } - PRP(const block& key) { - aes_set_key(key); - } + PRP(const block& key) { + aes_set_key(key); + } - void aes_set_key(const block& v) { - AES_set_encrypt_key(v, &aes); - } + void aes_set_key(const block& v) { + AES_set_encrypt_key(v, &aes); + } - void permute_block(block *data, int nblocks) { - for(int i = 0; i < nblocks/AES_BATCH_SIZE; ++i) { - AES_ecb_encrypt_blks(data + i*AES_BATCH_SIZE, &aes); - } - int remain = nblocks % AES_BATCH_SIZE; - AES_ecb_encrypt_blks(data + nblocks - remain, remain, &aes); - } + void permute_block(block *data, int nblocks) { + for(int i = 0; i < nblocks/AES_BATCH_SIZE; ++i) { + AES_ecb_encrypt_blks(data + i*AES_BATCH_SIZE, &aes); + } + int remain = nblocks % AES_BATCH_SIZE; + AES_ecb_encrypt_blks(data + nblocks - remain, remain, &aes); + } }; } #endif// PRP_H diff --git a/src/emp-tool/utils/utils.hpp b/src/emp-tool/utils/utils.hpp index 13e2cef..78fefce 100644 --- a/src/emp-tool/utils/utils.hpp +++ b/src/emp-tool/utils/utils.hpp @@ -1,35 +1,35 @@ template -void run_function(void *function, const Ts&... args) { - reinterpret_cast(function)(args...); +void run_function(void *function, const Ts&... args) { + reinterpret_cast(function)(args...); } template void inline delete_array_null(T * ptr){ - if(ptr != nullptr) { - delete[] ptr; - ptr = nullptr; - } + if(ptr != nullptr) { + delete[] ptr; + ptr = nullptr; + } } -inline time_point clock_start() { - return high_resolution_clock::now(); +inline time_point clock_start() { + return high_resolution_clock::now(); } inline double time_from(const time_point& s) { - return std::chrono::duration_cast(high_resolution_clock::now() - s).count(); + return std::chrono::duration_cast(high_resolution_clock::now() - s).count(); } inline void error(const char * s, int line, const char * file) { - fprintf(stderr, s, "\n"); - if(file != nullptr) { - fprintf(stderr, "at %d, %s\n", line, file); - } - exit(1); + fprintf(stderr, s, "\n"); + if(file != nullptr) { + fprintf(stderr, "at %d, %s\n", line, file); + } + exit(1); } inline void parse_party_and_port(const char *const * arg, int * party, int * port) { - *party = atoi (arg[1]); - *port = atoi (arg[2]); + *party = atoi (arg[1]); + *port = atoi (arg[2]); } template @@ -45,10 +45,10 @@ inline T bool_to_int(const bool *data) { template inline void int_to_bool(bool * data, T input, int len) { - for (int i = 0; i < len; ++i) { - data[i] = (input & 1)==1; - input >>= 1; - } + for (int i = 0; i < len; ++i) { + data[i] = (input & 1)==1; + input >>= 1; + } } @@ -57,9 +57,9 @@ inline void int_to_bool(bool * data, T input, int len) { // (does not mutate the memory to which input points) template inline void to_bool(bool * data, const T * input, const int len, const bool reverse) { - for (int i = 0; i < len; ++i) { - data[reverse ? len - i : i] = (bool) ((((uint8_t *) input)[i / 8] & (((uint8_t) 1) << (i % 8))) != 0); - } + for (int i = 0; i < len; ++i) { + data[reverse ? len - i : i] = (bool) ((((uint8_t *) input)[i / 8] & (((uint8_t) 1) << (i % 8))) != 0); + } } // Set the first len bits wherever output points to to be the first len bools from array data. @@ -68,29 +68,29 @@ inline void to_bool(bool * data, const T * input, const int len, const bool reve // assumes that if x is a bool, then ((uint8_t) x) is either 1 or 0. template inline void from_bool(const bool * data, T * output, const int len, const bool reverse) { - for (int i = 0; i < len; ++i) { + for (int i = 0; i < len; ++i) { ((uint8_t *) output)[i / 8] &= (~(((uint8_t) 1) << (i % 8))); // sets bit to 0 ((uint8_t *) output)[i / 8] |= (((uint8_t) data[reverse ? len - i : i]) << (i % 8)); // sets bit to bool[i] - } + } } inline block bool_to_block(const bool * data) { - return makeBlock(bool_to_int(data+64), bool_to_int(data)); + return makeBlock(bool_to_int(data+64), bool_to_int(data)); } inline void block_to_bool(bool * data, block b) { - uint64_t* ptr = (uint64_t*)(&b); - int_to_bool(data, ptr[0], 64); - int_to_bool(data+64, ptr[1], 64); + uint64_t* ptr = (uint64_t*)(&b); + int_to_bool(data, ptr[0], 64); + int_to_bool(data+64, ptr[1], 64); } inline bool file_exists(const std::string &name) { - if (FILE *file = fopen(name.c_str(), "r")) { - fclose(file); - return true; - }else return false; + if (FILE *file = fopen(name.c_str(), "r")) { + fclose(file); + return true; + }else return false; }