mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
962 lines
24 KiB
C++
962 lines
24 KiB
C++
/*
|
|
* Party.cpp
|
|
*
|
|
*/
|
|
|
|
#include "Party.h"
|
|
|
|
|
|
#include <sys/types.h>
|
|
#include <sys/stat.h>
|
|
#include <fcntl.h>
|
|
#include <stdlib.h>
|
|
|
|
#include "Tools/callgrind.h"
|
|
|
|
#include "proto_utils.h"
|
|
#include "msg_types.h"
|
|
#include "prf.h"
|
|
#include "BooleanCircuit.h"
|
|
#include "Math/Setup.h"
|
|
|
|
#ifdef __PURE_SHE__
|
|
#include "mpirxx.h"
|
|
#endif
|
|
|
|
ProgramParty* ProgramParty::singleton = 0;
|
|
|
|
|
|
BaseParty::BaseParty(party_id_t id) : _id(id)
|
|
{
|
|
#ifdef DEBUG_PRNG_PARTY
|
|
octet seed[SEED_SIZE];
|
|
memset(seed, 0, sizeof(seed));
|
|
seed[0] = id;
|
|
prng.SetSeed(seed);
|
|
#endif
|
|
}
|
|
|
|
BaseParty::~BaseParty()
|
|
{
|
|
}
|
|
|
|
Party::Party(const char* netmap_file, // required to init Node
|
|
const char* circuit_file, // required to init BooleanCircuit
|
|
party_id_t id,
|
|
const std::string input,
|
|
int numthreads,
|
|
int numtries
|
|
) :BaseParty(id),
|
|
_all_input(input),
|
|
_NUMTHREADS(numthreads),
|
|
_NUMTRIES(numtries)
|
|
{
|
|
_circuit = new BooleanCircuit( circuit_file );
|
|
_circuit->party = this;
|
|
_G = _circuit->NumGates();
|
|
#ifndef N_PARTIES
|
|
_N = _circuit->NumParties();
|
|
#endif
|
|
_W = _circuit->NumWires();
|
|
_OW = _circuit->NumOutWires();
|
|
_IO = _circuit->_num_input_wires;
|
|
resize_registers();
|
|
#ifdef __PURE_SHE__
|
|
init_modulos();
|
|
#endif
|
|
reset();
|
|
_generate_prf_inputs();
|
|
init(netmap_file, id, _N);
|
|
_num_externals_msg_received = {0};//ATOMIC_VAR_INIT(0);
|
|
_num_inputkeys_msg_received = {0};//ATOMIC_VAR_INIT(0);
|
|
}
|
|
|
|
void Party::_initialize_input()
|
|
{
|
|
party_t me = _circuit->_parties[_id];
|
|
std::string my_input = _all_input.substr(me.wires, me.n_wires);
|
|
_input.resize(me.n_wires);
|
|
const char* input = my_input.c_str();
|
|
memcpy(_input.data(), input, me.n_wires);
|
|
|
|
for(size_t i=0; i<me.n_wires; i++) {
|
|
_input[i]-=0x30;
|
|
}
|
|
#ifdef DEBUG
|
|
printf("inputs:\n");
|
|
phex(_input.data(), me.n_wires);
|
|
#endif
|
|
}
|
|
|
|
Party::~Party() {
|
|
}
|
|
|
|
void BaseParty::NodeReady()
|
|
{
|
|
printf("Node is ready\n");
|
|
}
|
|
|
|
void BaseParty::NewMessage(int from, ReceivedMsg& msg)
|
|
{
|
|
char* message = msg.data();
|
|
size_t len = msg.size();
|
|
// printf("got message of len %u from %d\n", len, from);
|
|
MSG_TYPE message_type;
|
|
msg.unserialize(message_type);
|
|
#ifdef DEBUG_STEPS
|
|
cout << "processing " << message_type_names[message_type] << " from " << dec
|
|
<< from << " of length " << msg.size() << endl;
|
|
#endif
|
|
unique_lock<mutex> locker(global_lock);
|
|
switch(message_type) {
|
|
case TYPE_KEYS:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_KEYS\n");
|
|
#endif
|
|
#ifdef DEBUG2
|
|
cout << "received keys" << endl;
|
|
phex(message, len);
|
|
#endif
|
|
get_buffer(TYPE_PRF_OUTPUTS);
|
|
_compute_prfs_outputs((Key*)(message + MSG_KEYS_HEADER_SZ));
|
|
_send_prfs();
|
|
// printf("sent prfs\n");
|
|
break;
|
|
}
|
|
case TYPE_MASK_INPUTS:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_MASK_INPUTS\n");
|
|
#endif
|
|
#ifdef DEBUG_INPUT
|
|
cout << "received " << msg.left() << " input masks" << endl;
|
|
#endif
|
|
mask_input(msg);
|
|
break;
|
|
}
|
|
case TYPE_MASK_OUTPUT:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_MASK_OUTPUT\n");
|
|
#endif
|
|
#ifdef DEBUG_OUTPUT_MASKS
|
|
cout << "receiving " << msg.left() << " output masks" << endl;
|
|
#endif
|
|
mask_output(msg);
|
|
break;
|
|
}
|
|
case TYPE_GARBLED_CIRCUIT:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_GARBLED_CIRCUIT\n");
|
|
#endif
|
|
#ifdef DEBUG2
|
|
phex(message, len);
|
|
#endif
|
|
#ifdef DEBUG_COMM
|
|
cout << "got " << len << " bytes for " << get_garbled_tbl_size() << " gates" << endl;
|
|
#endif
|
|
if ((len - 4) != 4 * _N * sizeof(Key) * get_garbled_tbl_size())
|
|
throw runtime_error("wrong size of garbled table");
|
|
store_garbled_circuit(msg);
|
|
|
|
// printf("\nGarbled Table\n\n");
|
|
// _printf_garbled_table();
|
|
// char garbled_circuit_cs = cs((char*)_garbled_tbl , garbled_tbl_sz);
|
|
// printf ("\ngarbled_circuit_cs = %d\n", garbled_circuit_cs);
|
|
|
|
_node->Send(SERVER_ID, get_buffer(TYPE_RECEIVED_GC));
|
|
|
|
break;
|
|
}
|
|
case TYPE_EXTERNAL_VALUES:
|
|
{
|
|
//this is done by party 1 only
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_EXTERNAL_VALUES from %d\n",from);
|
|
#endif
|
|
_process_external_received(message+sizeof(MSG_TYPE), from);
|
|
break;
|
|
}
|
|
case TYPE_ALL_EXTERNAL_VALUES:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_ALL_EXTERNAL_VALUES\n");
|
|
#endif
|
|
_process_all_external_received(message+sizeof(MSG_TYPE));
|
|
break;
|
|
}
|
|
case TYPE_KEY_PER_IN_WIRE:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_KEY_PER_IN_WIRE from %d\n", from);
|
|
#endif
|
|
_process_input_keys((Key*)(message+INPUT_KEYS_MSG_TYPE_SIZE), from);
|
|
break;
|
|
}
|
|
case TYPE_ALL_KEYS_PER_IN_WIRE:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_ALL_KEYS_PER_IN_WIRE\n");
|
|
#endif
|
|
_process_all_input_keys(message+INPUT_KEYS_MSG_TYPE_SIZE);
|
|
break;
|
|
}
|
|
case TYPE_LAUNCH_ONLINE:
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
printf("TYPE_LAUNCH_ONLINE\n");
|
|
#endif
|
|
cout << "Launching online phase at " << timer.elapsed() << endl;
|
|
_node->print_waiting();
|
|
// store a token item in case it's needed just before ending the program
|
|
ReceivedMsg msg;
|
|
// to avoid copying from address 0
|
|
msg.resize(1);
|
|
// for valgrind
|
|
msg.data()[0] = 0;
|
|
// token input round
|
|
store_garbled_circuit(msg);
|
|
online_timer.start();
|
|
_start_online_net = GET_TIME();
|
|
start_online_round();
|
|
break;
|
|
}
|
|
case TYPE_CHECKSUM:
|
|
{
|
|
printf("TYPE_CHECKSUM\n");
|
|
printf("got checksum = %d\n", message[sizeof(MSG_TYPE)]);
|
|
break;
|
|
}
|
|
case TYPE_SPDZ_WIRES:
|
|
receive_spdz_wires(msg);
|
|
break;
|
|
case TYPE_DELTA:
|
|
msg.unserialize(delta);
|
|
break;
|
|
default:
|
|
{
|
|
printf("UNDEFINED\n");
|
|
printf("got undefined message\n");
|
|
phex(message, len);
|
|
break;
|
|
}
|
|
}
|
|
|
|
#ifdef DEBUG_STEPS
|
|
cout << "done with " << message_type_names[message_type] << " from " << from << endl;
|
|
#endif
|
|
}
|
|
|
|
void Party::start_online_round()
|
|
{
|
|
_generate_external_values_msg();
|
|
// _node->Broadcast2(_external_values_msg, _external_values_msg_sz);
|
|
if(_id!=1) {
|
|
// printf("sending my externals:\n");
|
|
// phex(_external_values_msg+sizeof(MSG_TYPE), _circuit->Party(_id).n_wires);
|
|
_node->Send(1,_external_values_msg);
|
|
}
|
|
}
|
|
|
|
void Party::store_garbled_circuit(ReceivedMsg& msg)
|
|
{
|
|
auto end = get_garbled_tbl_end();
|
|
for (auto it = _garbled_tbl.begin(); it != end; it++)
|
|
{
|
|
it->unserialize(msg, _N);
|
|
#ifdef DEBUG
|
|
it->print();
|
|
#endif
|
|
}
|
|
}
|
|
|
|
void Party::_check_evaluate()
|
|
{
|
|
load_garbled_circuit();
|
|
_end_online_net = GET_TIME();
|
|
printf("Network time: ");
|
|
PRINT_DIFF(&_start_online_net, &_end_online_net);
|
|
#ifdef __PRIME_FIELD__
|
|
printf("this implementation uses PRIME FIELD\n");
|
|
#endif
|
|
|
|
printf("\n\npress for EVALUATING\n\n");
|
|
getchar();
|
|
|
|
vector<exec_props_t> execs(_NUMTHREADS+1);
|
|
unsigned long diff;
|
|
for (int ntry=1; ntry<=_NUMTRIES; ntry++) {
|
|
for(int nthreads=1; nthreads<=_NUMTHREADS; nthreads+=1) {
|
|
// _garbled_tbl = _garbled_tbl_copy;
|
|
// printf("num threads = %d\n", nthreads);
|
|
|
|
// CALLGRIND_START_INSTRUMENTATION;
|
|
struct timeval b;
|
|
if(nthreads == 1) {
|
|
b = GET_TIME();
|
|
// printf("Linear evaluation\n");
|
|
_circuit->EvaluateByLayerLinearly(_id);
|
|
} else {
|
|
b = GET_TIME();
|
|
// printf("Non linear evaluation\n");
|
|
_circuit->EvaluateByLayer(nthreads, _id);
|
|
}
|
|
// CALLGRIND_STOP_INSTRUMENTATION;
|
|
CALLGRIND_DUMP_STATS;
|
|
struct timeval a = GET_TIME();
|
|
// _circuit->Output();
|
|
|
|
// printf("Eval time: ");
|
|
diff = GET_DIFF(&b, &a);
|
|
// printf("#threads = %d, time = %lu\n", nthreads, diff);
|
|
execs[nthreads].acc += diff;
|
|
if (diff < execs[nthreads].min || execs[nthreads].min==0) {
|
|
execs[nthreads].min = diff;
|
|
}
|
|
}
|
|
}
|
|
|
|
#ifdef DEBUG
|
|
_circuit->Output();
|
|
#endif
|
|
|
|
//getting the best results
|
|
unsigned long minmin = execs[1].min;
|
|
unsigned long long avgmin = execs[1].acc;
|
|
int best_minmin = 1;
|
|
int best_avgmin = 1;
|
|
for(int nthreads=1; nthreads<=_NUMTHREADS; nthreads++) {
|
|
printf("#nthreads = %d, min = %lu, acc = %llu\n",nthreads, execs[nthreads].min, execs[nthreads].acc);
|
|
if(execs[nthreads].min < minmin) {
|
|
minmin = execs[nthreads].min;
|
|
best_minmin = nthreads;
|
|
}
|
|
if(execs[nthreads].acc < avgmin) {
|
|
avgmin = execs[nthreads].acc;
|
|
best_avgmin = nthreads;
|
|
}
|
|
}
|
|
printf("\nRESULTS SUM:\n");
|
|
printf("Got minimal time with %d threads: %ld\n", best_minmin, minmin);
|
|
printf("Got minimal AVERAGE time with %d threads: %llu\n", best_avgmin, avgmin/_NUMTRIES);
|
|
fflush(0);
|
|
done();
|
|
}
|
|
|
|
void BaseParty::Start()
|
|
{
|
|
_node->Start();
|
|
}
|
|
|
|
|
|
void Party::_generate_prf_inputs()
|
|
{
|
|
resize_garbled_tbl();
|
|
for (gate_id_t g=1; g<=_G; g++)
|
|
{
|
|
_garbled_tbl[g-1].init_inputs(g, _N);
|
|
}
|
|
}
|
|
|
|
void Party::_compute_prfs_outputs(Key* keys)
|
|
{
|
|
receive_keys(keys);
|
|
for(gate_id_t g=1; g<=_G; g++) {
|
|
const Register* in_wires[2] = {®isters[_circuit->_gates[g]._left], ®isters[_circuit->_gates[g]._right]};
|
|
_garbled_tbl[g-1].compute_prfs_outputs(in_wires, _id, buffers[TYPE_PRF_OUTPUTS], g);
|
|
}
|
|
printf("\n\n");
|
|
}
|
|
|
|
void BaseParty::_send_prfs() {
|
|
_node->Send(SERVER_ID, buffers[TYPE_PRF_OUTPUTS]);
|
|
#ifdef DEBUG2
|
|
printf("Send PRFs:\n");
|
|
phex(buffers[TYPE_PRF_OUTPUTS]);
|
|
#endif
|
|
}
|
|
|
|
void Party::_print_keys()
|
|
{
|
|
for (wire_id_t w=0; w<_IO; w++) {
|
|
registers[w].keys.print(w, _id);
|
|
}
|
|
}
|
|
|
|
|
|
void Party::_printf_garbled_table()
|
|
{
|
|
for (gate_id_t g=1; g<=get_garbled_tbl_size(); g++) {
|
|
std::cout << "gate " << g << std::endl;
|
|
for(int entry=0; entry<4; entry++) {
|
|
for(party_id_t i=1; i<=_N; i++) {
|
|
KeyVector& k = _garbled_entry(g, entry);
|
|
std::cout << k[i-1] << " ";
|
|
}
|
|
std::cout << std::endl;
|
|
}
|
|
std::cout << std::endl;
|
|
}
|
|
}
|
|
|
|
void Party::_print_keys_of_party(Key *keys, int id)
|
|
{
|
|
printf("\nkeys for party %d", id);
|
|
for(wire_id_t w=0; w<_IO; w++) {
|
|
printf("k^%d_%lu: ", id, w);
|
|
std::cout << keys[w] << std::endl;
|
|
}
|
|
}
|
|
|
|
void Party::_print_input_keys_msg()
|
|
{
|
|
for(wire_id_t w=0; w<_IO; w++) {
|
|
std::cout << *(Key*)(_input_wire_keys_msg.data()+INPUT_KEYS_MSG_TYPE_SIZE+w*sizeof(Key)) << std::endl;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void Party::_generate_external_values_msg()
|
|
{
|
|
_initialize_input();
|
|
prepare_input_regs(_id);
|
|
_external_values_msg.clear();
|
|
fill_message_type(_external_values_msg, TYPE_EXTERNAL_VALUES);
|
|
unsigned int n = (*input_regs)[_id].size();
|
|
#ifdef DEBUG_ROUNDS
|
|
cout << dec << _input.size() << "/" << n << " inputs" << " in round "
|
|
<< input_regs.get_i() << endl;
|
|
#endif
|
|
if (n != _input.size())
|
|
throw runtime_error("number of inputs doesn't match");
|
|
for(unsigned int i=0; i<n; i++) {
|
|
if (input_mask == input_masks.end())
|
|
throw runtime_error("not enough masks");
|
|
char val = *input_mask^_input[i];
|
|
input_mask++;
|
|
_external_values_msg.push_back(val);
|
|
get_reg((*input_regs)[_id][i]).set_external(val);
|
|
}
|
|
#ifdef DEBUG_VALUES
|
|
cout << "on registers:" << endl;
|
|
for(unsigned int i=0; i<n; i++)
|
|
cout << (*input_regs)[_id][i] << " ";
|
|
cout << endl;
|
|
cout << "inputs:\t\t\t";
|
|
print_bit_array(_input.data(), n);
|
|
cout << "masks:\t\t\t";
|
|
print_bit_array(&*(input_mask - n), n);
|
|
cout << "externals:\t\t";
|
|
print_bit_array(&_external_values_msg[4], n);
|
|
#endif
|
|
|
|
#ifdef DEBUG2
|
|
phex(&*(input_mask - n), n);
|
|
phex(_input.data(), n);
|
|
phex(_external_values_msg.data()+sizeof(MSG_TYPE), n);
|
|
#endif
|
|
}
|
|
|
|
void Party::_process_external_received(char* externals, party_id_t from)
|
|
{
|
|
prepare_input_regs(from);
|
|
|
|
// printf("received externals from %d\n", from);
|
|
// phex(externals, n);
|
|
for(unsigned int i=0; i<(*input_regs)[from].size(); i++) {
|
|
get_reg((*input_regs)[from][i]).set_external(externals[i]);
|
|
}
|
|
#ifdef DEBUG_VALUES
|
|
cout << "externals from " << from << ":\t";
|
|
print_bit_array(externals, (*input_regs)[from].size());
|
|
cout << "masks:\t\t\t";
|
|
print_masks((*input_regs)[from]);
|
|
cout << "outputs:\t\t";
|
|
print_masks((*input_regs)[from]);
|
|
cout << "on registers:" << endl;
|
|
print_indices((*input_regs)[from]);
|
|
#endif
|
|
|
|
size_t num_received;
|
|
{
|
|
std::unique_lock<std::mutex> locker(_process_externals_mx);
|
|
num_received = ++_num_externals_msg_received;
|
|
}
|
|
if(num_received == _N-1) {
|
|
_num_externals_msg_received = 0;
|
|
SendBuffer& buffer = get_buffer(TYPE_ALL_EXTERNAL_VALUES);
|
|
int w = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
vector<int>& regs = (*input_regs)[i+1];
|
|
for (unsigned int j = 0; j < regs.size(); j++)
|
|
{
|
|
buffer.push_back(registers[regs[j]].get_external());
|
|
w++;
|
|
}
|
|
}
|
|
_node->Broadcast2(buffer);
|
|
// printf("sending all externals\n");
|
|
// phex(all_externals+sizeof(MSG_TYPE), _IO);
|
|
}
|
|
}
|
|
|
|
void Party::_process_all_external_received(char* externals)
|
|
{
|
|
// phex(message+sizeof(MSG_TYPE), _IO);
|
|
_input_wire_keys_msg.clear();
|
|
fill_message_type(_input_wire_keys_msg, TYPE_KEY_PER_IN_WIRE);
|
|
_input_wire_keys_msg.resize(INPUT_KEYS_MSG_TYPE_SIZE);
|
|
|
|
int w = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
vector<int>& regs = (*input_regs)[i+1];
|
|
for (unsigned int j = 0; j < regs.size(); j++)
|
|
{
|
|
get_reg(regs[j]).set_external(externals[w]);
|
|
get_reg(regs[j]).external_key(_id).serialize(_input_wire_keys_msg);
|
|
#ifdef DEBUG
|
|
printf("k^%d_{%u,%d}=", _id,w,registers[regs[j]].get_external());
|
|
std::cout << registers[regs[j]].external_key(_id) << std::endl;
|
|
#endif
|
|
w++;
|
|
}
|
|
#ifdef DEBUG_VALUES
|
|
cout << "on registers:" << endl;
|
|
for (unsigned j = 0; j < (*input_regs)[i + 1].size(); j++)
|
|
cout << (*input_regs)[i + 1][j] << " ";
|
|
cout << endl;
|
|
cout << "externals from " << (i + 1) << ":\t";
|
|
print_bit_array(externals, regs.size());
|
|
cout << "masks:\t\t\t";
|
|
print_masks(output_regs);
|
|
cout << "outputs:\t\t";
|
|
print_outputs(output_regs);
|
|
#endif
|
|
}
|
|
|
|
_node->Send(1, _input_wire_keys_msg);
|
|
#ifdef DEBUG2
|
|
printf("input wire keys:\n");
|
|
phex(_input_wire_keys_msg);
|
|
#endif
|
|
}
|
|
|
|
//void Party::_process_input_keys(char* keys, party_id_t from)
|
|
//{
|
|
//// printf("keys from party: %d\n",from);
|
|
// for(wire_id_t w=0; w<_IO; w++) {
|
|
// unsigned int offset = w*2*_N+_circuit->_externals[w]*_N+from-1;
|
|
// _circuit->_keys[offset] = *(Key*)(keys+w*sizeof(Key));
|
|
//// printf("received key k^%d_{%u,%d}=", from,w,_circuit->_externals[w]);
|
|
//// std::cout <<_circuit->_keys[offset] << std::endl;
|
|
// }
|
|
//}
|
|
|
|
void Party::_process_input_keys(Key* keys, party_id_t from)
|
|
{
|
|
// printf("keys from party: %d\n",from);
|
|
// for(wire_id_t w=0; w<_IO; w++) {
|
|
// printf("k^%d_{%u,%d}=", from,w,_circuit->_externals[w]);
|
|
// std::cout << keys[w] << std::endl;
|
|
// }
|
|
|
|
int w = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
for (unsigned int j = 0; j < (*input_regs)[i+1].size(); j++)
|
|
{
|
|
get_reg((*input_regs)[i+1][j]).set_external_key(from, keys[w]);
|
|
w++;
|
|
}
|
|
}
|
|
|
|
size_t num_received;
|
|
{
|
|
std::unique_lock<std::mutex> locker(_process_keys_mx);
|
|
num_received = ++_num_inputkeys_msg_received;
|
|
}
|
|
// printf("num_received = %d\n",num_received);
|
|
if(num_received == _N-1)
|
|
{
|
|
_num_inputkeys_msg_received = 0;
|
|
#ifdef DEBUG_STEPS
|
|
printf("received input keys from everyone\n");
|
|
#endif
|
|
SendBuffer& buffer = get_buffer(TYPE_ALL_KEYS_PER_IN_WIRE);
|
|
buffer.resize(INPUT_KEYS_MSG_TYPE_SIZE);
|
|
int w = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
for (unsigned int j = 0; j < (*input_regs)[i+1].size(); j++)
|
|
{
|
|
get_reg((*input_regs)[i+1][j]).keys.serialize(buffer);
|
|
w++;
|
|
}
|
|
}
|
|
#ifdef DEBUG2
|
|
printf("all input keys:\n");
|
|
_print_input_keys_checksum();
|
|
phex(buffer);
|
|
#endif
|
|
_node->Broadcast2(buffer);
|
|
_check_evaluate();
|
|
}
|
|
|
|
}
|
|
|
|
|
|
void Party::_process_all_input_keys(char* keys)
|
|
{
|
|
/* keys: a block containing 2*_N keys for each input wire so the
|
|
* receiver only has to copy one time.
|
|
*/
|
|
int w = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
for (unsigned int j = 0; j < (*input_regs)[i+1].size(); j++)
|
|
{
|
|
get_reg((*input_regs)[i+1][j]).set_eval_keys((Key*)keys+w*2*_N, _N, _id - 1);
|
|
w++;
|
|
}
|
|
}
|
|
|
|
#ifdef DEBUG
|
|
printf("all input keys:\n");
|
|
// phex(message+INPUT_KEYS_MSG_TYPE_SIZE, 2*_N*_IO*sizeof(Key));
|
|
_print_input_keys_checksum();
|
|
#endif
|
|
_check_evaluate();
|
|
}
|
|
|
|
void Party::_print_input_keys_checksum()
|
|
{
|
|
int w = 0;
|
|
for (unsigned int i = 0; i < (*input_regs).size(); i++)
|
|
for (unsigned int j = 0; j < (*input_regs)[i].size(); j++)
|
|
registers[(*input_regs)[i][j]].print_input(w++);
|
|
}
|
|
|
|
void Party::mask_output(ReceivedMsg& msg)
|
|
{
|
|
char* masks = msg.consume(0);
|
|
size_t n_masks = msg.left();
|
|
prepare_output_regs();
|
|
if (n_masks != output_regs.size())
|
|
throw runtime_error("number of masks doesn't match");
|
|
for (unsigned int i = 0; i < output_regs.size(); i++)
|
|
get_reg(output_regs[i]).set_mask(masks[i]);
|
|
#ifdef DEBUG_VALUES
|
|
cout << "received output masks in registers:" << endl;
|
|
vector<char> externals(output_regs.size()), outputs(output_regs.size());
|
|
for (unsigned int i = 0; i < output_regs.size(); i++)
|
|
{
|
|
externals[i] = get_reg(output_regs[i]).get_external_no_check();
|
|
outputs[i] = get_reg(output_regs[i]).get_output_no_check();
|
|
cout << output_regs[i] << " ";
|
|
}
|
|
cout << endl;
|
|
cout << "masks:\t\t\t";
|
|
print_bit_array(masks, output_regs.size());
|
|
cout << "externals:\t\t";
|
|
print_bit_array(externals);
|
|
cout << "outputs:\t\t";
|
|
print_bit_array(outputs);
|
|
#endif
|
|
}
|
|
|
|
void Party::mask_input(ReceivedMsg& msg)
|
|
{
|
|
input_masks.insert(input_masks.end(), msg.data() + sizeof(MSG_TYPE), msg.data() + msg.size());
|
|
#ifdef DEBUG_COMM
|
|
cout << "got " << dec << input_masks.size() << " input masks" << endl;
|
|
#endif
|
|
input_mask = input_masks.begin();
|
|
}
|
|
|
|
void Party::receive_keys(Key* keys) {
|
|
resize_registers();
|
|
for (size_t w = 0; w < _W; w++)
|
|
{
|
|
registers[w].init(_N);
|
|
for (int i = 0; i < 2; i++)
|
|
registers[w].keys[i][_id - 1] = *(keys + w * 2 + i);
|
|
}
|
|
#ifdef DEBUG
|
|
_print_keys();
|
|
#endif
|
|
}
|
|
|
|
int Party::get_n_inputs() {
|
|
int res = 0;
|
|
for (size_t i = 0; i < _N; i++)
|
|
{
|
|
prepare_input_regs(i + 1);
|
|
vector<int>& regs = (*input_regs)[i+1];
|
|
res += regs.size();
|
|
}
|
|
return res;
|
|
}
|
|
|
|
void BaseParty::done() {
|
|
cout << "Online phase took " << online_timer.elapsed() << " seconds" << endl;
|
|
_node->Send(SERVER_ID, get_buffer(TYPE_DONE));
|
|
_node->Stop();
|
|
}
|
|
|
|
ProgramParty::ProgramParty(int argc, char** argv) :
|
|
BaseParty(-1), keys_for_prf(0),
|
|
spdz_storage(0), garbled_storage(0), spdz_counters(SPDZ_OP_N),
|
|
machine(dynamic_memory),
|
|
processor(machine), prf_machine(dynamic_memory),
|
|
prf_processor(prf_machine),
|
|
MC(0)
|
|
{
|
|
if (argc < 3)
|
|
{
|
|
cerr << "Usage: " << argv[0] << " <id> <program> [netmap]" << endl;
|
|
exit(1);
|
|
}
|
|
|
|
_id = atoi(argv[1]);
|
|
program.parse(string(argv[2]) + "-0");
|
|
machine.reset(program);
|
|
processor.reset(program);
|
|
processor.open_input_file("user_inputs/user_" + to_string(_id - 1) + "_input.txt");
|
|
prf_machine.reset(*reinterpret_cast<GC::Program<GC::Secret<PRFRegister> >* >(&program));
|
|
prf_processor.reset(*reinterpret_cast<GC::Program<GC::Secret<PRFRegister> >* >(&program));
|
|
if (singleton)
|
|
throw runtime_error("there can only be one");
|
|
singleton = this;
|
|
if (argc > 3)
|
|
{
|
|
int n_parties = init(argv[3], _id);
|
|
ifstream netmap(argv[3]);
|
|
int tmp;
|
|
string tmp2;
|
|
netmap >> tmp >> tmp2 >> tmp;
|
|
vector<string> hostnames(n_parties);
|
|
for (int i = 0; i < n_parties; i++)
|
|
{
|
|
netmap >> hostnames[i];
|
|
netmap >> tmp;
|
|
}
|
|
N.init(_id - 1, 5000, hostnames);
|
|
}
|
|
else
|
|
{
|
|
int n_parties = init("LOOPBACK", _id);
|
|
N.init(_id - 1, 5000, vector<string>(n_parties, "localhost"));
|
|
}
|
|
prf_output = (char*)new __m128i[PAD_TO_8(get_n_parties())];
|
|
mac_key = prng.get_word() & ((1ULL << GC::Secret<EvalRegister>::default_length) - 1);
|
|
cout << "MAC key: " << hex << mac_key << endl;
|
|
ifstream schfile((string("Programs/Schedules/") + argv[2] + ".sch").c_str());
|
|
string curr, prev;
|
|
while (schfile.good())
|
|
{
|
|
prev = curr;
|
|
getline(schfile, curr);
|
|
}
|
|
cout << "Compiler: " << prev << endl;
|
|
P = new PlainPlayer(N, 0);
|
|
if (argc > 4)
|
|
threshold = atoi(argv[4]);
|
|
else
|
|
threshold = 128;
|
|
cout << "Threshold for multi-threaded evaluation: " << threshold << endl;
|
|
eval_threads = new Worker<AndJob>[N_EVAL_THREADS];
|
|
and_jobs.resize(N_EVAL_THREADS);
|
|
}
|
|
|
|
ProgramParty::~ProgramParty()
|
|
{
|
|
reset();
|
|
delete[] prf_output;
|
|
delete P;
|
|
if (MC)
|
|
delete MC;
|
|
delete[] eval_threads;
|
|
cout << "SPDZ loading: " << spdz_counters[SPDZ_LOAD] << endl;
|
|
cout << "SPDZ storing: " << spdz_counters[SPDZ_STORE] << endl;
|
|
cout << "SPDZ wire storage: " << 1e-9 * spdz_storage << " GB" << endl;
|
|
cout << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity() *
|
|
sizeof(GC::Secret<EvalRegister>::DynamicType) << " GB" << endl;
|
|
cout << "Maximum circuit storage: " << 1e-9 * garbled_storage << " GB" << endl;
|
|
}
|
|
|
|
void ProgramParty::_compute_prfs_outputs(Key* keys)
|
|
{
|
|
keys_for_prf = keys;
|
|
first_phase(program, prf_processor, prf_machine);
|
|
}
|
|
|
|
void ProgramParty::reset()
|
|
{
|
|
CommonParty::reset();
|
|
}
|
|
|
|
void ProgramParty::store_garbled_circuit(ReceivedMsg& msg)
|
|
{
|
|
garbled_storage = max(msg.size(), garbled_storage);
|
|
garbled_circuits.push(msg);
|
|
}
|
|
|
|
void ProgramParty::load_garbled_circuit()
|
|
{
|
|
if (not garbled_circuits.pop(garbled_circuit))
|
|
throw runtime_error("no garbled circuit available");
|
|
if (not output_masks_store.pop(output_masks))
|
|
throw runtime_error("no output masks available");
|
|
if (not input_masks_store.pop(input_masks))
|
|
throw runtime_error("no input masks available");
|
|
#ifdef DEBUG_OUTPUT_MASKS
|
|
cout << "loaded " << output_masks.left() << " output masks" << endl;
|
|
#endif
|
|
#ifdef DEBUG_INPUT
|
|
cout << "loaded " << input_masks.left() << " input masks" << endl;
|
|
#endif
|
|
}
|
|
|
|
void ProgramParty::start_online_round()
|
|
{
|
|
machine.reset_timer();
|
|
_check_evaluate();
|
|
}
|
|
|
|
void ProgramParty::_check_evaluate()
|
|
{
|
|
#ifdef DEBUG_REGS
|
|
print_round_regs();
|
|
#endif
|
|
cout << "Online time at evaluation start: " << online_timer.elapsed()
|
|
<< endl;
|
|
GC::BreakType next = GC::TIME_BREAK;
|
|
while (next == GC::TIME_BREAK)
|
|
{
|
|
load_garbled_circuit();
|
|
next = second_phase(program, processor, machine);
|
|
}
|
|
cout << "Online time at evaluation stop: " << online_timer.elapsed()
|
|
<< endl;
|
|
if (next == GC::TIME_BREAK)
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
cout << "another round of garbling" << endl;
|
|
#endif
|
|
}
|
|
if (next != GC::DONE_BREAK)
|
|
{
|
|
#ifdef DEBUG_STEPS
|
|
cout << "another round of evaluation" << endl;
|
|
#endif
|
|
start_online_round();
|
|
}
|
|
else
|
|
{
|
|
Timer timer;
|
|
timer.start();
|
|
MC->Check(*P);
|
|
cout << "Final check took " << timer.elapsed() << endl;
|
|
done();
|
|
}
|
|
}
|
|
|
|
void ProgramParty::receive_keys(Register& reg)
|
|
{
|
|
reg.init(_N);
|
|
for (int i = 0; i < 2; i++)
|
|
reg.keys[i][_id-1] = *(keys_for_prf++);
|
|
#ifdef DEBUG
|
|
cout << "receive keys " << reg.get_id() << "(" << ® << ") " << dec << reg.keys[0].size() << endl;
|
|
reg.keys.print(reg.get_id(), _id);
|
|
cout << "delta " << reg.get_id() << " " << (reg.keys[0][_id-1] ^ reg.keys[1][_id-1]) << endl;
|
|
#endif
|
|
}
|
|
|
|
void ProgramParty::receive_all_keys(Register& reg, bool external)
|
|
{
|
|
reg.init(get_n_parties());
|
|
for (int i = 0; i < get_n_parties(); i++)
|
|
reg.keys[external][i] = *(keys_for_prf++);
|
|
}
|
|
|
|
void ProgramParty::receive_spdz_wires(ReceivedMsg& msg)
|
|
{
|
|
int op;
|
|
msg.unserialize(op);
|
|
spdz_wires[op].push_back({});
|
|
size_t l = msg.left();
|
|
spdz_wires[op].back().append((octet*)msg.consume(l), l);
|
|
spdz_storage += l;
|
|
#ifdef DEBUG_SPDZ_WIRES
|
|
cout << "receive " << dec << spdz_wires[op].back().get_length() << "/"
|
|
<< msg.size() << " bytes for type " << op << endl;
|
|
#endif
|
|
if (op == SPDZ_MAC)
|
|
{
|
|
gf2n spdz_mac_key;
|
|
spdz_mac_key.unpack(spdz_wires[op].back());
|
|
if (!MC)
|
|
{
|
|
MC = new Passing_MAC_Check<gf2n>(spdz_mac_key, N, 0);
|
|
cout << "MAC key: " << hex << spdz_mac_key << endl;
|
|
mac_key = spdz_mac_key;
|
|
}
|
|
}
|
|
}
|
|
|
|
void ProgramParty::get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire)
|
|
{
|
|
while (true)
|
|
{
|
|
if (spdz_wires[op].empty())
|
|
throw runtime_error("no SPDZ wires available");
|
|
if (spdz_wires[op].front().done())
|
|
spdz_wires[op].pop_front();
|
|
else
|
|
break;
|
|
}
|
|
spdz_wire.unpack(spdz_wires[op].front(), get_n_parties());
|
|
spdz_counters[op]++;
|
|
#ifdef DEBUG_SPDZ_WIRE
|
|
cout << "get SPDZ wire of type " << op << ", " << spdz_wires[op].front().left() << " bytes left" << endl;
|
|
cout << "mask share for " << get_id() << ": " << spdz_wire.mask << endl;
|
|
#endif
|
|
}
|
|
|
|
void ProgramParty::store_wire(const Register& reg)
|
|
{
|
|
wires.serialize(reg.key(get_id(), 0));
|
|
#ifndef FREE_XOR
|
|
wires.serialize(reg.key(get_id(), 1));
|
|
#endif
|
|
#ifdef DEBUG
|
|
cout << "storing wire" << endl;
|
|
reg.print();
|
|
#endif
|
|
}
|
|
|
|
void ProgramParty::load_wire(Register& reg)
|
|
{
|
|
wires.unserialize(reg.key(get_id(), 0));
|
|
#ifdef FREE_XOR
|
|
reg.key(get_id(), 1) = reg.key(get_id(), 0) ^ get_delta();
|
|
#else
|
|
wires.unserialize(reg.key(get_id(), 1));
|
|
#endif
|
|
#ifdef DEBUG
|
|
cout << "loading wire" << endl;
|
|
reg.print();
|
|
#endif
|
|
}
|