Files
MP-SPDZ/Machines/OTMachine.cpp
Marcel Keller 78fe3d8bad Maintenance.
2024-07-09 12:19:52 +10:00

400 lines
11 KiB
C++

#include "Networking/Player.h"
#include "OT/OTExtension.h"
#include "OT/OTExtensionWithMatrix.h"
#include "Tools/Exceptions.h"
#include "Tools/time-func.h"
#include <stdlib.h>
#include <sstream>
#include <fstream>
#include <iostream>
#include <sys/time.h>
#include "OutputCheck.h"
#include "OTMachine.h"
//#define BASE_OT_DEBUG
class OT_thread_info
{
public:
int thread_num;
bool stop;
int other_player_num;
OTExtensionWithMatrix* ot_ext;
int nOTs, nbase;
BitVector receiverInput;
int nloops;
};
void* run_otext_thread(void* ptr)
{
OT_thread_info *tinfo = (OT_thread_info*) ptr;
//int num = tinfo->thread_num;
//int other_player_num = tinfo->other_player_num;
printf("\tI am in thread %d\n", tinfo->thread_num);
tinfo->ot_ext->transfer(tinfo->nOTs, tinfo->receiverInput, tinfo->nloops);
return NULL;
}
OTMachine::OTMachine(int argc, const char** argv)
{
opt.add(
"", // Default.
1, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number, 0/1 (required).", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.add(
"5000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Base port number (default: 5000).", // Help description.
"-pn", // Flag token.
"--portnum" // Flag token.
);
opt.add(
"localhost", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Host name(s) that player 0 is running on (default: localhost). Split with commas.", // Help description.
"-h", // Flag token.
"--hostname" // Flag token.
);
opt.add(
"1024",
0,
1,
0,
"Number of extended OTs to run (default: 1024).",
"-n",
"--nOTs"
);
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of base OTs to run (default: 128).", // Help description.
"-b", // Flag token.
"--nbase" // Flag token.
);
opt.add(
"s",
0,
1,
0,
"Mode for OT. a (asymmetric) or s (symmetric, i.e. play both sender/receiver) (default: s).",
"-m",
"--mode"
);
opt.add(
"1",
0,
1,
0,
"Number of threads (default: 1).",
"-x",
"--nthreads"
);
opt.add(
"1",
0,
1,
0,
"Number of loops (default: 1).",
"-l",
"--nloops"
);
opt.add(
"1",
0,
1,
0,
"Number of subloops (default: 1).",
"-s",
"--nsubloops"
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Run in passive security mode.", // Help description.
"-pas", // Flag token.
"--passive" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Write results to files.", // Help description.
"-o", // Flag token.
"--output" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Real base OT.", // Help description.
"-r", // Flag token.
"--real" // Flag token.
);
opt.parse(argc, argv);
string hostname, ot_mode, usage;
passive = false;
opt.get("-p")->getInt(my_num);
opt.get("-pn")->getInt(portnum_base);
opt.get("-h")->getString(hostname);
opt.get("-n")->getLong(nOTs);
opt.get("-m")->getString(ot_mode);
opt.get("--nthreads")->getInt(nthreads);
opt.get("--nloops")->getInt(nloops);
opt.get("--nsubloops")->getInt(nsubloops);
opt.get("--nbase")->getInt(nbase);
if (opt.isSet("-pas"))
passive = true;
if (!opt.isSet("-p"))
{
opt.getUsage(usage);
cout << usage;
exit(0);
}
cout << "Player 0 host name = " << hostname << endl;
cout << "Creating " << nOTs << " extended OTs in " << nthreads << " threads\n";
cout << "Running in mode " << ot_mode << endl;
if (passive)
cout << "Running with PASSIVE security only\n";
if (nbase < 128)
cout << "WARNING: only using " << nbase << " seed OTs, using these for OT extensions is insecure.\n";
if (ot_mode.compare("s") == 0)
ot_role = BOTH;
else if (ot_mode.compare("a") == 0)
{
if (my_num == 0)
ot_role = SENDER;
else
ot_role = RECEIVER;
}
else
{
cerr << "Invalid OT mode argument: " << ot_mode << endl;
exit(1);
}
// Several names for multiplexing
unsigned int pos = 0;
while (pos < hostname.length())
{
string::size_type new_pos = hostname.find(',', pos);
if (new_pos == string::npos)
new_pos = hostname.length();
int len = new_pos - pos;
string name = hostname.substr(pos, len);
pos = new_pos + 1;
vector<string> names(2);
names[my_num] = "localhost";
names[1-my_num] = name;
N.push_back(new Names(my_num, portnum_base + 1000 * N.size(), names));
}
P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "machine");
timeval baseOTstart, baseOTend;
gettimeofday(&baseOTstart, NULL);
// swap role for base OTs
if (opt.isSet("-r"))
bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
else
bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
cout << "real mode " << opt.isSet("-r") << endl;
BaseOT& bot = *bot_;
bot.exec_base();
gettimeofday(&baseOTend, NULL);
double basetime = timeval_diff(&baseOTstart, &baseOTend);
cout << "\t\tBaseTime (" << role_to_str(ot_role) << "): " << basetime/1000000 << endl << flush;
// Receiver send something to force synchronization
// (since Sender finishes baseOTs before Receiver)
int a = 3;
vector<octetStream> os(2);
os[0].store(a);
P->send_receive_player(os);
os[1].get(a);
cout << a << endl;
#ifdef BASE_OT_DEBUG
// check base OTs
bot.check();
// check after extending with PRG a few times
for (int i = 0; i < 8; i++)
{
bot.extend_length();
bot.check();
}
cout << "Verifying base OTs (debugging)\n";
#endif
// convert baseOT selection bits to BitVector
// (not already BitVector due to legacy PVW code)
baseReceiverInput = bot.receiver_inputs;
baseReceiverInput.resize(nbase);
}
OTMachine::~OTMachine()
{
for (auto names : N)
delete names;
delete bot_;
delete P;
}
void OTMachine::run()
{
// divide nOTs between threads and loops
nOTs = DIV_CEIL(nOTs, nthreads * nloops);
// round up to multiple of base OTs and subloops
// discount for discarded OTs
nOTs = DIV_CEIL(nOTs + 2 * 128, nbase * nsubloops) * nbase * nsubloops - 2 * 128;
cout << "Running " << nOTs << " OT extensions per thread and loop\n" << flush;
// PRG for generating inputs etc
PRNG G;
G.ReSeed();
BitVector receiverInput(nOTs);
receiverInput.randomize(G);
BaseOT& bot = *bot_;
cout << "Initialize OT Extension\n";
vector<OT_thread_info> tinfos(nthreads);
vector<pthread_t> threads(nthreads);
timeval OTextstart, OTextend;
gettimeofday(&OTextstart, NULL);
// copy base inputs/outputs for each thread
vector<BitVector> base_receiver_input_copy(nthreads);
vector<vector< array<BitVector, 2> > > base_sender_inputs_copy(nthreads, vector<array<BitVector, 2> >(nbase));
vector< vector<BitVector> > base_receiver_outputs_copy(nthreads, vector<BitVector>(nbase));
vector<TwoPartyPlayer*> players(nthreads);
for (int i = 0; i < nthreads; i++)
{
tinfos[i].receiverInput.assign(receiverInput);
base_receiver_input_copy[i].assign(baseReceiverInput);
for (int j = 0; j < nbase; j++)
{
base_sender_inputs_copy[i][j][0].assign(bot.sender_inputs[j][0]);
base_sender_inputs_copy[i][j][1].assign(bot.sender_inputs[j][1]);
base_receiver_outputs_copy[i][j].assign(bot.receiver_outputs[j]);
}
// now setup resources for each thread
// round robin with the names
players[i] = new RealTwoPartyPlayer(*N[i % N.size()], 1 - my_num,
"thread" + to_string(i));
tinfos[i].thread_num = i+1;
tinfos[i].other_player_num = 1 - my_num;
tinfos[i].nOTs = nOTs;
tinfos[i].ot_ext = new OTExtensionWithMatrix(
players[i],
ot_role,
passive,
nsubloops);
tinfos[i].ot_ext->init(base_receiver_input_copy[i],
base_sender_inputs_copy[i], base_receiver_outputs_copy[i]);
tinfos[i].nloops = nloops;
// create the thread
pthread_create(&threads[i], NULL, run_otext_thread, &tinfos[i]);
// extend base OTs with PRG for the next thread
bot.extend_length();
}
// wait for threads to finish
for (int i = 0; i < nthreads; i++)
{
pthread_join(threads[i],NULL);
cout << "thread " << i+1 << " finished\n" << flush;
}
map<string,long long>& times = tinfos[0].ot_ext->times;
for (map<string,long long>::iterator it = times.begin(); it != times.end(); it++)
{
long long sum = 0;
for (int i = 0; i < nthreads; i++)
sum += tinfos[i].ot_ext->times[it->first];
cout << it->first << " on average took time "
<< double(sum) / nthreads / 1e6 << endl;
}
gettimeofday(&OTextend, NULL);
double totaltime = timeval_diff(&OTextstart, &OTextend);
cout << "Time for OTExt threads (" << role_to_str(ot_role) << "): " << totaltime/1000000 << endl << flush;
if (opt.isSet("-o"))
{
BitVector receiver_output, sender_output;
char filename[1024];
snprintf(filename, 1024, RECEIVER_INPUT, my_num);
ofstream outf(filename);
receiverInput.output(outf, false);
outf.close();
snprintf(filename, 1024, RECEIVER_OUTPUT, my_num);
outf.open(filename);
for (unsigned int i = 0; i < nOTs; i++)
{
receiver_output.assign_bytes((char*) tinfos[0].ot_ext->get_receiver_output(i), sizeof(__m128i));
receiver_output.output(outf, false);
}
outf.close();
for (int i = 0; i < 2; i++)
{
snprintf(filename,1024, SENDER_OUTPUT, my_num, i);
outf.open(filename);
for (int j = 0; j < nOTs; j++)
{
sender_output.assign_bytes((char*) tinfos[0].ot_ext->get_sender_output(i, j), sizeof(__m128i));
sender_output.output(outf, false);
}
outf.close();
}
}
for (int i = 0; i < nthreads; i++)
{
delete players[i];
delete tinfos[i].ot_ext;
}
}