// Mirror of https://github.com/data61/MP-SPDZ.git
// (synced 2026-01-09 13:37:58 -05:00; 400 lines, 11 KiB, C++)
#include "Networking/Player.h"
|
|
#include "OT/OTExtension.h"
|
|
#include "OT/OTExtensionWithMatrix.h"
|
|
#include "Tools/Exceptions.h"
|
|
#include "Tools/time-func.h"
|
|
|
|
#include <stdlib.h>
|
|
#include <sstream>
|
|
#include <fstream>
|
|
#include <iostream>
|
|
|
|
#include <sys/time.h>
|
|
|
|
#include "OutputCheck.h"
|
|
#include "OTMachine.h"
|
|
|
|
//#define BASE_OT_DEBUG
|
|
|
|
class OT_thread_info
|
|
{
|
|
public:
|
|
|
|
int thread_num;
|
|
bool stop;
|
|
int other_player_num;
|
|
OTExtensionWithMatrix* ot_ext;
|
|
int nOTs, nbase;
|
|
BitVector receiverInput;
|
|
int nloops;
|
|
};
|
|
|
|
void* run_otext_thread(void* ptr)
|
|
{
|
|
OT_thread_info *tinfo = (OT_thread_info*) ptr;
|
|
|
|
//int num = tinfo->thread_num;
|
|
//int other_player_num = tinfo->other_player_num;
|
|
printf("\tI am in thread %d\n", tinfo->thread_num);
|
|
tinfo->ot_ext->transfer(tinfo->nOTs, tinfo->receiverInput, tinfo->nloops);
|
|
return NULL;
|
|
}
|
|
|
|
/**
 * Parses the command line, connects to the peer and runs the base OTs.
 *
 * Steps performed:
 *  - registers and parses the options below; prints usage and exits(0) if -p is missing;
 *  - validates the player number (0/1) and the thread/loop/subloop/base counts (> 0),
 *    exiting(1) with a message otherwise;
 *  - creates one Names per comma-separated host name, using port
 *    portnum_base + 1000 * index for the index-th name;
 *  - creates the main player P on the first Names;
 *  - runs nbase base OTs (BaseOT with -r, FakeOT otherwise) in the inverted role;
 *  - exchanges an int with the peer to synchronize;
 *  - copies the base-OT receiver choice bits into baseReceiverInput (resized to nbase).
 *
 * @param argc  argument count as passed to main
 * @param argv  argument vector as passed to main
 */
OTMachine::OTMachine(int argc, const char** argv)
{
    //      default      req nargs delim help                                                flags
    opt.add("", 1, 1, 0, "This player's number, 0/1 (required).", "-p", "--player");
    opt.add("5000", 0, 1, 0, "Base port number (default: 5000).", "-pn", "--portnum");
    opt.add("localhost", 0, 1, 0,
            "Host name(s) that player 0 is running on (default: localhost). Split with commas.",
            "-h", "--hostname");
    opt.add("1024", 0, 1, 0, "Number of extended OTs to run (default: 1024).", "-n", "--nOTs");
    opt.add("128", 0, 1, 0, "Number of base OTs to run (default: 128).", "-b", "--nbase");
    opt.add("s", 0, 1, 0,
            "Mode for OT. a (asymmetric) or s (symmetric, i.e. play both sender/receiver) (default: s).",
            "-m", "--mode");
    opt.add("1", 0, 1, 0, "Number of threads (default: 1).", "-x", "--nthreads");
    opt.add("1", 0, 1, 0, "Number of loops (default: 1).", "-l", "--nloops");
    opt.add("1", 0, 1, 0, "Number of subloops (default: 1).", "-s", "--nsubloops");
    opt.add("", 0, 0, 0, "Run in passive security mode.", "-pas", "--passive");
    opt.add("", 0, 0, 0, "Write results to files.", "-o", "--output");
    opt.add("", 0, 0, 0, "Real base OT.", "-r", "--real");

    opt.parse(argc, argv);

    string hostname, ot_mode, usage;

    // The player number is mandatory; show usage and stop before reading anything.
    if (!opt.isSet("-p"))
    {
        opt.getUsage(usage);
        cout << usage;
        exit(0);
    }

    opt.get("-p")->getInt(my_num);
    opt.get("-pn")->getInt(portnum_base);
    opt.get("-h")->getString(hostname);
    opt.get("-n")->getLong(nOTs);
    opt.get("-m")->getString(ot_mode);
    opt.get("--nthreads")->getInt(nthreads);
    opt.get("--nloops")->getInt(nloops);
    opt.get("--nsubloops")->getInt(nsubloops);
    opt.get("--nbase")->getInt(nbase);
    passive = opt.isSet("-pas");

    // my_num indexes a two-element name vector below, so anything else is out of bounds.
    if (my_num != 0 && my_num != 1)
    {
        cerr << "Invalid player number: " << my_num << " (must be 0 or 1)" << endl;
        exit(1);
    }

    // These are divisors in run(); zero or negative values would divide by zero.
    if (nthreads < 1 || nloops < 1 || nsubloops < 1 || nbase < 1)
    {
        cerr << "Number of threads, loops, subloops and base OTs must all be positive" << endl;
        exit(1);
    }

    cout << "Player 0 host name = " << hostname << endl;
    cout << "Creating " << nOTs << " extended OTs in " << nthreads << " threads\n";
    cout << "Running in mode " << ot_mode << endl;

    if (passive)
        cout << "Running with PASSIVE security only\n";

    if (nbase < 128)
        cout << "WARNING: only using " << nbase << " seed OTs, using these for OT extensions is insecure.\n";

    if (ot_mode == "s")
        ot_role = BOTH;
    else if (ot_mode == "a")
        ot_role = (my_num == 0) ? SENDER : RECEIVER;
    else
    {
        cerr << "Invalid OT mode argument: " << ot_mode << endl;
        exit(1);
    }

    // Several names for multiplexing: one Names (and port range) per comma-separated host.
    string::size_type pos = 0;
    while (pos < hostname.length())
    {
        string::size_type new_pos = hostname.find(',', pos);
        if (new_pos == string::npos)
            new_pos = hostname.length();
        string name = hostname.substr(pos, new_pos - pos);
        pos = new_pos + 1;

        vector<string> names(2);
        names[my_num] = "localhost";
        names[1 - my_num] = name;
        N.push_back(new Names(my_num, portnum_base + 1000 * static_cast<int>(N.size()), names));
    }

    // N[0] is dereferenced immediately below and in run().
    if (N.empty())
    {
        cerr << "No host name given" << endl;
        exit(1);
    }

    P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "machine");

    timeval baseOTstart, baseOTend;
    gettimeofday(&baseOTstart, nullptr);
    // swap role for base OTs
    if (opt.isSet("-r"))
        bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
    else
        bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
    cout << "real mode " << opt.isSet("-r") << endl;
    BaseOT& bot = *bot_;
    bot.exec_base();
    gettimeofday(&baseOTend, nullptr);
    double basetime = timeval_diff(&baseOTstart, &baseOTend);
    cout << "\t\tBaseTime (" << role_to_str(ot_role) << "): " << basetime/1000000 << endl << flush;

    // Receiver send something to force synchronization
    // (since Sender finishes baseOTs before Receiver)
    int a = 3;
    vector<octetStream> os(2);
    os[0].store(a);
    P->send_receive_player(os);
    os[1].get(a);
    cout << a << endl;

#ifdef BASE_OT_DEBUG
    cout << "Verifying base OTs (debugging)\n";
    // check base OTs
    bot.check();
    // check after extending with PRG a few times
    for (int i = 0; i < 8; i++)
    {
        bot.extend_length();
        bot.check();
    }
#endif

    // convert baseOT selection bits to BitVector
    // (not already BitVector due to legacy PVW code)
    baseReceiverInput = bot.receiver_inputs;
    baseReceiverInput.resize(nbase);
}
|
|
|
|
/**
 * Releases the base-OT object, the main player and all Names objects.
 *
 * Destruction runs in reverse order of construction: bot_ was created on top
 * of P, and P was created from *N[0]. Releasing dependents first means no
 * object outlives something it was built from, e.g. if the player retains a
 * reference to its Names (NOTE(review): not visible from here).
 */
OTMachine::~OTMachine()
{
    delete bot_;
    delete P;
    for (Names* names : N)
        delete names;
}
|
|
|
|
|
|
void OTMachine::run()
|
|
{
|
|
// divide nOTs between threads and loops
|
|
nOTs = DIV_CEIL(nOTs, nthreads * nloops);
|
|
// round up to multiple of base OTs and subloops
|
|
// discount for discarded OTs
|
|
nOTs = DIV_CEIL(nOTs + 2 * 128, nbase * nsubloops) * nbase * nsubloops - 2 * 128;
|
|
cout << "Running " << nOTs << " OT extensions per thread and loop\n" << flush;
|
|
|
|
// PRG for generating inputs etc
|
|
PRNG G;
|
|
G.ReSeed();
|
|
BitVector receiverInput(nOTs);
|
|
receiverInput.randomize(G);
|
|
BaseOT& bot = *bot_;
|
|
|
|
cout << "Initialize OT Extension\n";
|
|
vector<OT_thread_info> tinfos(nthreads);
|
|
vector<pthread_t> threads(nthreads);
|
|
timeval OTextstart, OTextend;
|
|
gettimeofday(&OTextstart, NULL);
|
|
|
|
// copy base inputs/outputs for each thread
|
|
vector<BitVector> base_receiver_input_copy(nthreads);
|
|
vector<vector< array<BitVector, 2> > > base_sender_inputs_copy(nthreads, vector<array<BitVector, 2> >(nbase));
|
|
vector< vector<BitVector> > base_receiver_outputs_copy(nthreads, vector<BitVector>(nbase));
|
|
vector<TwoPartyPlayer*> players(nthreads);
|
|
|
|
for (int i = 0; i < nthreads; i++)
|
|
{
|
|
tinfos[i].receiverInput.assign(receiverInput);
|
|
|
|
base_receiver_input_copy[i].assign(baseReceiverInput);
|
|
for (int j = 0; j < nbase; j++)
|
|
{
|
|
base_sender_inputs_copy[i][j][0].assign(bot.sender_inputs[j][0]);
|
|
base_sender_inputs_copy[i][j][1].assign(bot.sender_inputs[j][1]);
|
|
base_receiver_outputs_copy[i][j].assign(bot.receiver_outputs[j]);
|
|
}
|
|
// now setup resources for each thread
|
|
// round robin with the names
|
|
players[i] = new RealTwoPartyPlayer(*N[i % N.size()], 1 - my_num,
|
|
"thread" + to_string(i));
|
|
tinfos[i].thread_num = i+1;
|
|
tinfos[i].other_player_num = 1 - my_num;
|
|
tinfos[i].nOTs = nOTs;
|
|
tinfos[i].ot_ext = new OTExtensionWithMatrix(
|
|
players[i],
|
|
ot_role,
|
|
passive,
|
|
nsubloops);
|
|
tinfos[i].ot_ext->init(base_receiver_input_copy[i],
|
|
base_sender_inputs_copy[i], base_receiver_outputs_copy[i]);
|
|
tinfos[i].nloops = nloops;
|
|
|
|
// create the thread
|
|
pthread_create(&threads[i], NULL, run_otext_thread, &tinfos[i]);
|
|
|
|
// extend base OTs with PRG for the next thread
|
|
bot.extend_length();
|
|
}
|
|
// wait for threads to finish
|
|
for (int i = 0; i < nthreads; i++)
|
|
{
|
|
pthread_join(threads[i],NULL);
|
|
cout << "thread " << i+1 << " finished\n" << flush;
|
|
}
|
|
|
|
map<string,long long>& times = tinfos[0].ot_ext->times;
|
|
for (map<string,long long>::iterator it = times.begin(); it != times.end(); it++)
|
|
{
|
|
long long sum = 0;
|
|
for (int i = 0; i < nthreads; i++)
|
|
sum += tinfos[i].ot_ext->times[it->first];
|
|
|
|
cout << it->first << " on average took time "
|
|
<< double(sum) / nthreads / 1e6 << endl;
|
|
}
|
|
|
|
gettimeofday(&OTextend, NULL);
|
|
double totaltime = timeval_diff(&OTextstart, &OTextend);
|
|
cout << "Time for OTExt threads (" << role_to_str(ot_role) << "): " << totaltime/1000000 << endl << flush;
|
|
|
|
if (opt.isSet("-o"))
|
|
{
|
|
BitVector receiver_output, sender_output;
|
|
char filename[1024];
|
|
snprintf(filename, 1024, RECEIVER_INPUT, my_num);
|
|
ofstream outf(filename);
|
|
receiverInput.output(outf, false);
|
|
outf.close();
|
|
snprintf(filename, 1024, RECEIVER_OUTPUT, my_num);
|
|
outf.open(filename);
|
|
for (unsigned int i = 0; i < nOTs; i++)
|
|
{
|
|
receiver_output.assign_bytes((char*) tinfos[0].ot_ext->get_receiver_output(i), sizeof(__m128i));
|
|
receiver_output.output(outf, false);
|
|
}
|
|
outf.close();
|
|
|
|
for (int i = 0; i < 2; i++)
|
|
{
|
|
snprintf(filename,1024, SENDER_OUTPUT, my_num, i);
|
|
outf.open(filename);
|
|
for (int j = 0; j < nOTs; j++)
|
|
{
|
|
sender_output.assign_bytes((char*) tinfos[0].ot_ext->get_sender_output(i, j), sizeof(__m128i));
|
|
sender_output.output(outf, false);
|
|
}
|
|
outf.close();
|
|
}
|
|
}
|
|
|
|
for (int i = 0; i < nthreads; i++)
|
|
{
|
|
delete players[i];
|
|
delete tinfos[i].ot_ext;
|
|
}
|
|
}
|