mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Maintenance.
This commit is contained in:
@@ -126,7 +126,7 @@ void BaseMachine::time()
|
||||
void BaseMachine::start(int n)
|
||||
{
|
||||
cout << "Starting timer " << n << " at " << timer[n].elapsed()
|
||||
<< " (" << timer[n].mb_sent() << " MB)"
|
||||
<< " (" << timer[n] << ")"
|
||||
<< " after " << timer[n].idle() << endl;
|
||||
timer[n].start(total_comm());
|
||||
}
|
||||
@@ -135,7 +135,7 @@ void BaseMachine::stop(int n)
|
||||
{
|
||||
timer[n].stop(total_comm());
|
||||
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
|
||||
<< timer[n].mb_sent() << " MB)" << endl;
|
||||
<< timer[n] << ")" << endl;
|
||||
}
|
||||
|
||||
void BaseMachine::print_timers()
|
||||
@@ -150,7 +150,7 @@ void BaseMachine::print_timers()
|
||||
timer.erase(0);
|
||||
for (auto it = timer.begin(); it != timer.end(); it++)
|
||||
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
|
||||
<< it->second.mb_sent() << " MB)" << endl;
|
||||
<< it->second << ")" << endl;
|
||||
}
|
||||
|
||||
string BaseMachine::memory_filename(const string& type_short, int my_number)
|
||||
@@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
|
||||
global += os.get_int(8);
|
||||
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
|
||||
}
|
||||
|
||||
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
{
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << P.my_num() << " only";
|
||||
if (nthreads > 1)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
if (not OnlineOptions::singleton.verbose)
|
||||
cerr << "; use '-v' for more details";
|
||||
cerr << ")" << endl;
|
||||
|
||||
print_global_comm(P, comm_stats);
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ public:
|
||||
void print_timers();
|
||||
|
||||
virtual void reqbl(int) {}
|
||||
virtual void active(int) {}
|
||||
|
||||
static OTTripleSetup fresh_ot_setup(Player& P);
|
||||
|
||||
@@ -74,6 +75,7 @@ public:
|
||||
void set_thread_comm(const NamedCommStats& stats);
|
||||
|
||||
void print_global_comm(Player& P, const NamedCommStats& stats);
|
||||
void print_comm(Player& P, const NamedCommStats& stats);
|
||||
};
|
||||
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
|
||||
41
Processor/Conv2dTuple.h
Normal file
41
Processor/Conv2dTuple.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Conv2dTuple.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_CONV2DTUPLE_H_
|
||||
#define PROCESSOR_CONV2DTUPLE_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
class Conv2dTuple
|
||||
{
|
||||
public:
|
||||
int output_h, output_w;
|
||||
int inputs_h, inputs_w;
|
||||
int weights_h, weights_w;
|
||||
int stride_h, stride_w;
|
||||
int n_channels_in;
|
||||
int padding_h;
|
||||
int padding_w;
|
||||
int batch_size;
|
||||
size_t r0;
|
||||
size_t r1;
|
||||
int r2;
|
||||
vector<vector<vector<int>>> lengths;
|
||||
int filter_stride_h = 1;
|
||||
int filter_stride_w = 1;
|
||||
|
||||
Conv2dTuple(const vector<int>& args, int start);
|
||||
|
||||
template<class T>
|
||||
void pre(vector<T>& S, typename T::Protocol& protocol);
|
||||
template<class T>
|
||||
void post(vector<T>& S, typename T::Protocol& protocol);
|
||||
|
||||
template<class T>
|
||||
void run_matrix(SubProcessor<T>& processor);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_CONV2DTUPLE_H_ */
|
||||
@@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const
|
||||
for (auto it = edabits.begin(); it != edabits.end(); it++)
|
||||
{
|
||||
auto x = other.edabits.find(it->first);
|
||||
if (x == other.edabits.end() or it->second > x->second)
|
||||
if ((x == other.edabits.end() or it->second > x->second)
|
||||
and it->second > 0)
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "Networking/Player.h"
|
||||
#include "Protocols/edabit.h"
|
||||
#include "PrepBase.h"
|
||||
#include "EdabitBuffer.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
@@ -102,9 +104,6 @@ protected:
|
||||
|
||||
DataPositions& usage;
|
||||
|
||||
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
|
||||
map<pair<bool, int>, edabitvec<T>> my_edabits;
|
||||
|
||||
bool do_count;
|
||||
|
||||
void count(Dtype dtype, int n = 1)
|
||||
@@ -120,6 +119,8 @@ protected:
|
||||
const vector<int>&, true_type)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void fill(edabitvec<T>& res, bool strict, int n_bits);
|
||||
|
||||
T get_random_from_inputs(int nplayers);
|
||||
|
||||
public:
|
||||
@@ -173,12 +174,11 @@ public:
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
|
||||
template<int>
|
||||
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
|
||||
template<int>
|
||||
virtual void get_edabit_no_count(bool, int, edabit<T>&)
|
||||
{ throw runtime_error("no edaBits"); }
|
||||
/// Get fresh edaBit chunk
|
||||
edabitvec<T> get_edabitvec(bool strict, int n_bits);
|
||||
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
|
||||
virtual edabitvec<T> get_edabitvec(bool, int)
|
||||
{ throw runtime_error("no edabitvec"); }
|
||||
|
||||
virtual void push_triples(const vector<array<T, 3>>&)
|
||||
{ throw runtime_error("no pushing"); }
|
||||
@@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
|
||||
map<DataTag, BufferOwner<T, T> > extended;
|
||||
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
|
||||
map<int, ifstream*> edabit_buffers;
|
||||
map<int, EdabitBuffer<T>> edabit_buffers;
|
||||
map<int, edabitvec<T>> my_edabits;
|
||||
|
||||
int my_num,num_players;
|
||||
|
||||
@@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
|
||||
part_type* part;
|
||||
|
||||
void buffer_edabits_with_queues(bool strict, int n_bits)
|
||||
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
|
||||
template<int>
|
||||
void buffer_edabits_with_queues(bool strict, int n_bits, false_type);
|
||||
template<int>
|
||||
void buffer_edabits_with_queues(bool, int, true_type)
|
||||
{ throw not_implemented(); }
|
||||
EdabitBuffer<T>& get_edabit_buffer(int n_bits);
|
||||
|
||||
/// Get fresh edaBit chunk
|
||||
edabitvec<T> get_edabitvec(bool strict, int n_bits);
|
||||
void get_edabit_no_count(bool strict, int n_bits, edabit<T>& eb);
|
||||
|
||||
public:
|
||||
static string get_filename(const Names& N, Dtype type, int thread_num = -1);
|
||||
@@ -317,6 +316,8 @@ class Data_Files
|
||||
void reset_usage() { usage.reset(); skipped.reset(); }
|
||||
|
||||
void set_usage(const DataPositions& pos) { usage = pos; }
|
||||
|
||||
TimerWithComm total_time();
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
|
||||
@@ -108,7 +108,21 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
#ifdef DEBUG_FILES
|
||||
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
|
||||
#endif
|
||||
T::clear::check_setup(prep_data_dir);
|
||||
|
||||
try
|
||||
{
|
||||
T::clear::check_setup(prep_data_dir);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
cerr << "Something is wrong with the preprocessing data on disk." << endl;
|
||||
cerr
|
||||
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
|
||||
<< num_players
|
||||
<< T::clear::fake_opts() << "'?" << endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
string type_short = T::type_short();
|
||||
string type_string = T::type_string();
|
||||
|
||||
@@ -135,7 +149,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
type_short, i, my_num, thread_num);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(filename,
|
||||
T::size() + T::clear::size(), type_string);
|
||||
InputTuple<T>::size(), type_string);
|
||||
else
|
||||
input_buffers[i].setup(filename,
|
||||
T::size(), type_string);
|
||||
@@ -179,10 +193,6 @@ Data_Files<sint, sgf2n>::~Data_Files()
|
||||
template<class T>
|
||||
Sub_Data_Files<T>::~Sub_Data_Files()
|
||||
{
|
||||
for (auto& x: edabit_buffers)
|
||||
{
|
||||
delete x.second;
|
||||
}
|
||||
if (part != 0)
|
||||
delete part;
|
||||
}
|
||||
@@ -229,6 +239,26 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
|
||||
extended[it->first].seekg(it->second);
|
||||
}
|
||||
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
|
||||
|
||||
if (field_type == DATA_INT)
|
||||
{
|
||||
for (auto& x : pos.edabits)
|
||||
{
|
||||
// open files
|
||||
get_edabit_buffer(x.first.second);
|
||||
}
|
||||
|
||||
|
||||
int block_size = edabitvec<T>::MAX_SIZE;
|
||||
for (auto& x : edabit_buffers)
|
||||
{
|
||||
int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}];
|
||||
x.second.seekg(n / block_size);
|
||||
edabit<T> eb;
|
||||
for (int i = 0; i < n % block_size; i++)
|
||||
get_edabit_no_count(false, x.first, eb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -262,6 +292,8 @@ void Sub_Data_Files<T>::prune()
|
||||
dabit_buffer.prune();
|
||||
if (part != 0)
|
||||
part->prune();
|
||||
for (auto& x : edabit_buffers)
|
||||
x.second.prune();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -285,6 +317,8 @@ void Sub_Data_Files<T>::purge()
|
||||
dabit_buffer.purge();
|
||||
if (part != 0)
|
||||
part->purge();
|
||||
for (auto& x : edabit_buffers)
|
||||
x.second.prune();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -322,34 +356,43 @@ void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<int>
|
||||
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
|
||||
false_type)
|
||||
EdabitBuffer<T>& Sub_Data_Files<T>::get_edabit_buffer(int n_bits)
|
||||
{
|
||||
if (edabit_buffers.empty())
|
||||
insecure("reading edaBits from files");
|
||||
|
||||
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
|
||||
{
|
||||
string filename = PrepBase::get_edabit_filename(prep_data_dir,
|
||||
n_bits, my_num, thread_num);
|
||||
ifstream* f = new ifstream(filename);
|
||||
if (f->fail())
|
||||
throw runtime_error("cannot open " + filename);
|
||||
check_file_signature<T>(*f, filename);
|
||||
edabit_buffers[n_bits] = f;
|
||||
edabit_buffers[n_bits] = n_bits;
|
||||
edabit_buffers[n_bits].setup(filename,
|
||||
T::size() * edabitvec<T>::MAX_SIZE
|
||||
+ n_bits * T::bit_type::part_type::size());
|
||||
}
|
||||
auto& buffer = *edabit_buffers[n_bits];
|
||||
if (buffer.peek() == EOF)
|
||||
return edabit_buffers[n_bits];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
edabitvec<T> Sub_Data_Files<T>::get_edabitvec(bool strict, int n_bits)
|
||||
{
|
||||
if (my_edabits[n_bits].empty())
|
||||
return get_edabit_buffer(n_bits).read();
|
||||
else
|
||||
{
|
||||
buffer.seekg(0);
|
||||
check_file_signature<T>(buffer, "");
|
||||
auto res = my_edabits[n_bits];
|
||||
my_edabits[n_bits] = {};
|
||||
this->fill(res, strict, n_bits);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Preprocessing<T>::fill(edabitvec<T>& res, bool strict, int n_bits)
|
||||
{
|
||||
edabit<T> eb;
|
||||
while (res.size() < res.MAX_SIZE)
|
||||
{
|
||||
get_edabit_no_count(strict, n_bits, eb);
|
||||
res.push_back(eb);
|
||||
}
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
this->edabits[{strict, n_bits}].push_back(eb);
|
||||
if (buffer.fail())
|
||||
throw runtime_error("error reading edaBits");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -362,4 +405,10 @@ typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
|
||||
return *part;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
TimerWithComm Data_Files<sint, sgf2n>::total_time()
|
||||
{
|
||||
return DataFp.prep_timer + DataF2.prep_timer + DataFb.prep_timer;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
50
Processor/EdabitBuffer.h
Normal file
50
Processor/EdabitBuffer.h
Normal file
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* EdabitBuffer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_EDABITBUFFER_H_
|
||||
#define PROCESSOR_EDABITBUFFER_H_
|
||||
|
||||
#include "Tools/Buffer.h"
|
||||
|
||||
template<class T>
|
||||
class EdabitBuffer : public BufferOwner<T, T>
|
||||
{
|
||||
int n_bits;
|
||||
|
||||
int element_length()
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
public:
|
||||
EdabitBuffer(int n_bits = 0) :
|
||||
n_bits(n_bits)
|
||||
{
|
||||
}
|
||||
|
||||
edabitvec<T> read()
|
||||
{
|
||||
if (not BufferBase::file)
|
||||
{
|
||||
if (this->open()->fail())
|
||||
throw runtime_error("error opening " + this->filename);
|
||||
}
|
||||
|
||||
assert(BufferBase::file);
|
||||
auto& buffer = *BufferBase::file;
|
||||
if (buffer.peek() == EOF)
|
||||
{
|
||||
this->try_rewind();
|
||||
}
|
||||
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
if (buffer.fail())
|
||||
throw runtime_error("error reading edaBits");
|
||||
return eb;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_EDABITBUFFER_H_ */
|
||||
@@ -70,6 +70,7 @@ enum
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
|
||||
@@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRIVATEOUTPUT:
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
case CONV2DS:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
@@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
get_ints(r, s, 3);
|
||||
get_vector(9, start, s);
|
||||
break;
|
||||
case CONV2DS:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(12, start, s);
|
||||
break;
|
||||
|
||||
// read from file, input is opcode num_args,
|
||||
// start_file_posn (read), end_file_posn(write) var1, var2, ...
|
||||
@@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
break;
|
||||
case ACTIVE:
|
||||
n = get_int(s);
|
||||
BaseMachine::s().active(n);
|
||||
break;
|
||||
case XORM:
|
||||
case ANDM:
|
||||
case XORCB:
|
||||
@@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
case MATMULSM:
|
||||
return r[0] + start[0] * start[2];
|
||||
case CONV2DS:
|
||||
return r[0] + start[0] * start[1] * start[11];
|
||||
{
|
||||
unsigned res = 0;
|
||||
for (size_t i = 0; i < start.size(); i += 15)
|
||||
{
|
||||
unsigned tmp = start[i]
|
||||
+ start[i + 3] * start[i + 4] * start.at(i + 14);
|
||||
res = max(res, tmp);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case OPEN:
|
||||
skip = 2;
|
||||
break;
|
||||
@@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
break;
|
||||
case REQBL:
|
||||
case GREQBL:
|
||||
case ACTIVE:
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
|
||||
@@ -109,6 +109,7 @@ class Machine : public BaseMachine
|
||||
string prep_dir_prefix();
|
||||
|
||||
void reqbl(int n);
|
||||
void active(int n);
|
||||
|
||||
typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; }
|
||||
typename sint::mac_key_type get_sint_mac_key() { return alphapi; }
|
||||
|
||||
@@ -415,6 +415,9 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
|
||||
auto comm_stats = total_comm();
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
queues.print_breakdown();
|
||||
|
||||
for (auto& queue : queues)
|
||||
delete queue;
|
||||
|
||||
@@ -477,20 +480,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
print_timers();
|
||||
|
||||
if (sint::is_real)
|
||||
{
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << my_number;
|
||||
if (threads.size() > 1)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
cerr << "; use '-v' for more details";
|
||||
cerr << ")" << endl;
|
||||
|
||||
auto& P = *this->P;
|
||||
this->print_global_comm(P, comm_stats);
|
||||
}
|
||||
this->print_comm(*this->P, comm_stats);
|
||||
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
if (opening_sum < N.num_players() && !direct)
|
||||
@@ -521,23 +511,6 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
bit_memories.write_memory(N.my_num());
|
||||
|
||||
#ifdef OLD_USAGE
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
{
|
||||
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
cerr << " " << pos.files[field_type][dtype];
|
||||
cerr << endl;
|
||||
}
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t=";
|
||||
for (int i = 0; i < N.num_players(); i++)
|
||||
cerr << " " << pos.inputs[i][field_type];
|
||||
cerr << endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Actual cost of program:" << endl;
|
||||
@@ -586,6 +559,17 @@ void Machine<sint, sgf2n>::reqbl(int n)
|
||||
sint::clear::reqbl(n);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::active(int n)
|
||||
{
|
||||
|
||||
if (sint::malicious and n == 0)
|
||||
{
|
||||
cerr << "Program requires a semi-honest protocol" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::suggest_optimizations()
|
||||
{
|
||||
@@ -599,8 +583,8 @@ void Machine<sint, sgf2n>::suggest_optimizations()
|
||||
optimizations.append("\tprogram.use_edabit(True)\n");
|
||||
if (not optimizations.empty())
|
||||
cerr << "This program might benefit from some protocol options." << endl
|
||||
<< "Consider adding the following at the beginning of '" << progname
|
||||
<< ".mpc':" << endl << optimizations;
|
||||
<< "Consider adding the following at the beginning of your code:"
|
||||
<< endl << optimizations;
|
||||
#ifndef __clang__
|
||||
cerr << "This virtual machine was compiled with GCC. Recompile with "
|
||||
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;
|
||||
|
||||
@@ -172,7 +172,7 @@ void OfflineMachine<W>::generate()
|
||||
auto& opts = OnlineOptions::singleton;
|
||||
opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch;
|
||||
for (int i = 0; i < buffered_total(total, batch) / batch; i++)
|
||||
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
|
||||
preprocessing.get_edabitvec(true, n_bits).output(n_bits,
|
||||
out);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -44,6 +44,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
auto& queues = machine.queues[num];
|
||||
queues->next();
|
||||
ThreadQueue::thread_queue = queues;
|
||||
|
||||
#ifdef DEBUG_THREADS
|
||||
fprintf(stderr, "\tI am in thread %d\n",num);
|
||||
@@ -118,6 +119,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
DataPositions actual_usage(P.num_players());
|
||||
Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer;
|
||||
thread_timer.start();
|
||||
TimerWithComm timer, online_timer, online_prep_timer;
|
||||
timer.start();
|
||||
|
||||
while (flag)
|
||||
{ // Wait until I have a program to run
|
||||
@@ -262,6 +265,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
#ifdef DEBUG_THREADS
|
||||
printf("\tClient %d about to run %d\n",num,program);
|
||||
#endif
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
Proc.reset(progs[program], job.arg);
|
||||
|
||||
// Bits, Triples, Squares, and Inverses skipping
|
||||
@@ -290,6 +295,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
printf("\tSignalling I have finished with program %d"
|
||||
"in thread %d\n", program, num);
|
||||
#endif
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
wait_timer.start();
|
||||
queues->finished(job, P.total_comm());
|
||||
wait_timer.stop();
|
||||
@@ -297,7 +304,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
}
|
||||
|
||||
// final check
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
Proc.check();
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
|
||||
if (machine.opts.file_prep_per_thread)
|
||||
Proc.DataF.prune();
|
||||
@@ -330,6 +341,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
// wind down thread by thread
|
||||
machine.stats += Proc.stats;
|
||||
queues->timers["wait"] = wait_timer + queues->wait_timer;
|
||||
timer.stop(P.total_comm());
|
||||
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
|
||||
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
|
||||
|
||||
// prevent faulty usage message
|
||||
Proc.DataF.set_usage(actual_usage);
|
||||
delete processor;
|
||||
|
||||
@@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
cerr << " edaBits of size " << n_bits << " left" << endl;
|
||||
}
|
||||
|
||||
if (n > used / 10)
|
||||
if (n * n_batch > used / 10)
|
||||
cerr << "Significant amount of unused edaBits of size " << n_bits
|
||||
<< ". For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size "
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "Math/field_types.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
|
||||
class PrepBase
|
||||
{
|
||||
@@ -28,6 +29,8 @@ public:
|
||||
const string& type_string, size_t used);
|
||||
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used);
|
||||
|
||||
TimerWithComm prep_timer;
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PREPBASE_H_ */
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "Processor/Program.h"
|
||||
#include "GC/square64.h"
|
||||
#include "SpecificPrivateOutput.h"
|
||||
#include "Conv2dTuple.h"
|
||||
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
@@ -31,6 +32,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
|
||||
DataF.set_proc(this);
|
||||
protocol.init(DataF, MC);
|
||||
DataF.set_protocol(protocol);
|
||||
MC.set_prep(DataF);
|
||||
bit_usage.set_num_players(P.num_players());
|
||||
personal_bit_preps.resize(P.num_players());
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
@@ -40,6 +42,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
|
||||
template<class T>
|
||||
SubProcessor<T>::~SubProcessor()
|
||||
{
|
||||
DataF.set_proc(0);
|
||||
for (size_t i = 0; i < personal_bit_preps.size(); i++)
|
||||
{
|
||||
auto& x = personal_bit_preps[i];
|
||||
@@ -391,7 +394,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
|
||||
return;
|
||||
|
||||
string filename;
|
||||
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data";
|
||||
filename = binary_file_io.filename(P.my_num());
|
||||
|
||||
unsigned int size = data_registers.size();
|
||||
|
||||
@@ -652,21 +655,35 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
auto& args = instruction.get_start();
|
||||
int output_h = args[0], output_w = args[1];
|
||||
int inputs_h = args[2], inputs_w = args[3];
|
||||
int weights_h = args[4], weights_w = args[5];
|
||||
int stride_h = args[6], stride_w = args[7];
|
||||
int n_channels_in = args[8];
|
||||
int padding_h = args[9];
|
||||
int padding_w = args[10];
|
||||
int batch_size = args[11];
|
||||
size_t r0 = instruction.get_r(0);
|
||||
size_t r1 = instruction.get_r(1);
|
||||
int r2 = instruction.get_r(2);
|
||||
int lengths[batch_size][output_h][output_w];
|
||||
memset(lengths, 0, sizeof(lengths));
|
||||
int filter_stride_h = 1;
|
||||
int filter_stride_w = 1;
|
||||
vector<Conv2dTuple> tuples;
|
||||
for (size_t i = 0; i < args.size(); i += 15)
|
||||
tuples.push_back(Conv2dTuple(args, i));
|
||||
for (auto& tuple : tuples)
|
||||
tuple.pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (auto& tuple : tuples)
|
||||
tuple.post(S, protocol);
|
||||
}
|
||||
|
||||
inline
|
||||
Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
|
||||
{
|
||||
assert(arguments.size() >= start + 15ul);
|
||||
auto args = arguments.data() + start + 3;
|
||||
output_h = args[0], output_w = args[1];
|
||||
inputs_h = args[2], inputs_w = args[3];
|
||||
weights_h = args[4], weights_w = args[5];
|
||||
stride_h = args[6], stride_w = args[7];
|
||||
n_channels_in = args[8];
|
||||
padding_h = args[9];
|
||||
padding_w = args[10];
|
||||
batch_size = args[11];
|
||||
r0 = arguments[start];
|
||||
r1 = arguments[start + 1];
|
||||
r2 = arguments[start + 2];
|
||||
lengths.resize(batch_size, vector<vector<int>>(output_h, vector<int>(output_w)));
|
||||
filter_stride_h = 1;
|
||||
filter_stride_w = 1;
|
||||
if (stride_h < 0)
|
||||
{
|
||||
filter_stride_h = -stride_h;
|
||||
@@ -677,7 +694,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
filter_stride_w = -stride_w;
|
||||
stride_w = 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;
|
||||
@@ -714,9 +735,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protocol.exchange();
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
size_t base = r0 + i_batch * output_h * output_w;
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#include "ThreadQueue.h"
|
||||
|
||||
thread_local ThreadQueue* ThreadQueue::thread_queue = 0;
|
||||
|
||||
void ThreadQueue::schedule(const ThreadJob& job)
|
||||
{
|
||||
lock.lock();
|
||||
@@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job)
|
||||
cerr << this << ": " << left << " left" << endl;
|
||||
#endif
|
||||
lock.unlock();
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.start();
|
||||
in.push(job);
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.stop();
|
||||
}
|
||||
|
||||
ThreadJob ThreadQueue::next()
|
||||
@@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
|
||||
|
||||
ThreadJob ThreadQueue::result()
|
||||
{
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.start();
|
||||
auto res = out.pop();
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.stop();
|
||||
lock.lock();
|
||||
left--;
|
||||
#ifdef DEBUG_THREAD_QUEUE
|
||||
|
||||
@@ -16,6 +16,11 @@ class ThreadQueue
|
||||
NamedCommStats comm_stats;
|
||||
|
||||
public:
|
||||
static thread_local ThreadQueue* thread_queue;
|
||||
|
||||
map<string, TimerWithComm> timers;
|
||||
Timer wait_timer;
|
||||
|
||||
ThreadQueue() :
|
||||
left(0)
|
||||
{
|
||||
|
||||
@@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job)
|
||||
}
|
||||
available.clear();
|
||||
}
|
||||
|
||||
TimerWithComm ThreadQueues::sum(const string& phase)
|
||||
{
|
||||
TimerWithComm res;
|
||||
for (auto& x : *this)
|
||||
res += x->timers[phase];
|
||||
return res;
|
||||
}
|
||||
|
||||
void ThreadQueues::print_breakdown()
|
||||
{
|
||||
if (size() > 0)
|
||||
{
|
||||
if (size() == 1)
|
||||
{
|
||||
cerr << "Spent " << (*this)[0]->timers["online"].full()
|
||||
<< " on the online phase and "
|
||||
<< (*this)[0]->timers["prep"].full()
|
||||
<< " on the preprocessing/offline phase." << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << size() << " threads spent a total of " << sum("online").full()
|
||||
<< " on the online phase, " << sum("prep").full()
|
||||
<< " on the preprocessing/offline phase, and "
|
||||
<< sum("wait").full() << " idling." << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ public:
|
||||
int distribute_no_setup(ThreadJob job, int n_items, int base = 0,
|
||||
int granularity = 1, const vector<void*>* supplies = 0);
|
||||
void wrap_up(ThreadJob job);
|
||||
|
||||
TimerWithComm sum(const string& phase);
|
||||
|
||||
void print_breakdown();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_THREADQUEUES_H_ */
|
||||
|
||||
@@ -387,6 +387,7 @@
|
||||
X(GENSECSHUFFLE, throw not_implemented(),) \
|
||||
X(APPLYSHUFFLE, throw not_implemented(),) \
|
||||
X(DELSHUFFLE, throw not_implemented(),) \
|
||||
X(ACTIVE, throw not_implemented(),) \
|
||||
|
||||
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
|
||||
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS
|
||||
|
||||
Reference in New Issue
Block a user