Maintenance.

This commit is contained in:
Marcel Keller
2023-05-09 14:49:52 +10:00
parent c62ab2ca1e
commit 6cc3fccef0
135 changed files with 1658 additions and 1062 deletions

View File

@@ -126,7 +126,7 @@ void BaseMachine::time()
void BaseMachine::start(int n)
{
cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " (" << timer[n].mb_sent() << " MB)"
<< " (" << timer[n] << ")"
<< " after " << timer[n].idle() << endl;
timer[n].start(total_comm());
}
@@ -135,7 +135,7 @@ void BaseMachine::stop(int n)
{
timer[n].stop(total_comm());
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
<< timer[n].mb_sent() << " MB)" << endl;
<< timer[n] << ")" << endl;
}
void BaseMachine::print_timers()
@@ -150,7 +150,7 @@ void BaseMachine::print_timers()
timer.erase(0);
for (auto it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
<< it->second.mb_sent() << " MB)" << endl;
<< it->second << ")" << endl;
}
string BaseMachine::memory_filename(const string& type_short, int my_number)
@@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
}
/**
 * Print this party's communication summary to stderr, followed by the
 * global summary produced by print_global_comm().
 *
 * The round figure is the sum of the `rounds` field over all entries of
 * `comm_stats`; the data figure is `comm_stats.sent` divided by 1e6.
 *
 * @param P          player whose number is shown in the output and which
 *                   is passed on to print_global_comm()
 * @param comm_stats communication statistics to report
 */
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
{
    size_t total_rounds = 0;
    for (const auto& entry : comm_stats)
        total_rounds += entry.second.rounds;

    // Decide on the optional remarks up front; neither check has side effects.
    const bool multi_threaded = nthreads > 1;
    const bool hint_verbose = not OnlineOptions::singleton.verbose;

    cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~"
            << total_rounds << " rounds (party " << P.my_num() << " only";
    if (multi_threaded)
        cerr << "; rounds counted double due to multi-threading";
    if (hint_verbose)
        cerr << "; use '-v' for more details";
    cerr << ")" << endl;

    print_global_comm(P, comm_stats);
}

View File

@@ -67,6 +67,7 @@ public:
void print_timers();
virtual void reqbl(int) {}
virtual void active(int) {}
static OTTripleSetup fresh_ot_setup(Player& P);
@@ -74,6 +75,7 @@ public:
void set_thread_comm(const NamedCommStats& stats);
void print_global_comm(Player& P, const NamedCommStats& stats);
void print_comm(Player& P, const NamedCommStats& stats);
};
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)

41
Processor/Conv2dTuple.h Normal file
View File

@@ -0,0 +1,41 @@
/*
* Conv2dTuple.h
*
*/
#ifndef PROCESSOR_CONV2DTUPLE_H_
#define PROCESSOR_CONV2DTUPLE_H_
#include <vector>
using namespace std;
// Forward declaration: run_matrix() below only names SubProcessor<T> in its
// signature, so declaring it here keeps this header usable regardless of
// the order in which the processor headers are included.
template<class T> class SubProcessor;

/**
 * Parameters of a single 2D convolution, decoded from a slice of 15
 * consecutive integers of an instruction's argument vector.
 *
 * Layout of the slice (relative to `start`), as read by the constructor:
 * [0..2] = r0, r1, r2; [3..4] = output size; [5..6] = input size;
 * [7..8] = weight size; [9..10] = strides; [11] = input channels;
 * [12..13] = padding; [14] = batch size.
 */
class Conv2dTuple
{
public:
    // Spatial dimensions (height, width) of output, input, and filter.
    int output_h, output_w;
    int inputs_h, inputs_w;
    int weights_h, weights_w;

    // Strides as given in the arguments. The constructor turns a negative
    // stride into a positive filter stride (see filter_stride_* below).
    int stride_h, stride_w;

    int n_channels_in;
    int padding_h;
    int padding_w;
    int batch_size;

    // r0 is the base index of the outputs (used in post()), r1 the base
    // index of the inputs (used in pre()).
    size_t r0;
    size_t r1;
    // NOTE(review): presumably the base index of the weights - confirm
    // against the body of pre().
    int r2;

    // Sized batch_size x output_h x output_w by the constructor.
    std::vector<std::vector<std::vector<int>>> lengths;

    // 1 unless the corresponding stride argument was negative.
    int filter_stride_h = 1;
    int filter_stride_w = 1;

    /**
     * Decode the 15-integer slice of `args` beginning at index `start`.
     */
    Conv2dTuple(const std::vector<int>& args, int start);

    /// First half of the computation, run before protocol.exchange().
    template<class T>
    void pre(std::vector<T>& S, typename T::Protocol& protocol);

    /// Second half of the computation, run after protocol.exchange().
    template<class T>
    void post(std::vector<T>& S, typename T::Protocol& protocol);

    template<class T>
    void run_matrix(SubProcessor<T>& processor);
};
#endif /* PROCESSOR_CONV2DTUPLE_H_ */

View File

@@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const
for (auto it = edabits.begin(); it != edabits.end(); it++)
{
auto x = other.edabits.find(it->first);
if (x == other.edabits.end() or it->second > x->second)
if ((x == other.edabits.end() or it->second > x->second)
and it->second > 0)
return true;
}

View File

@@ -12,6 +12,8 @@
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include "PrepBase.h"
#include "EdabitBuffer.h"
#include "Tools/TimerWithComm.h"
#include <fstream>
#include <map>
@@ -102,9 +104,6 @@ protected:
DataPositions& usage;
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
map<pair<bool, int>, edabitvec<T>> my_edabits;
bool do_count;
void count(Dtype dtype, int n = 1)
@@ -120,6 +119,8 @@ protected:
const vector<int>&, true_type)
{ throw not_implemented(); }
void fill(edabitvec<T>& res, bool strict, int n_bits);
T get_random_from_inputs(int nplayers);
public:
@@ -173,12 +174,11 @@ public:
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs)
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
template<int>
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
template<int>
virtual void get_edabit_no_count(bool, int, edabit<T>&)
{ throw runtime_error("no edaBits"); }
/// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits);
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
virtual edabitvec<T> get_edabitvec(bool, int)
{ throw runtime_error("no edabitvec"); }
virtual void push_triples(const vector<array<T, 3>>&)
{ throw runtime_error("no pushing"); }
@@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing<T>
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
map<int, ifstream*> edabit_buffers;
map<int, EdabitBuffer<T>> edabit_buffers;
map<int, edabitvec<T>> my_edabits;
int my_num,num_players;
@@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing<T>
part_type* part;
void buffer_edabits_with_queues(bool strict, int n_bits)
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
template<int>
void buffer_edabits_with_queues(bool strict, int n_bits, false_type);
template<int>
void buffer_edabits_with_queues(bool, int, true_type)
{ throw not_implemented(); }
EdabitBuffer<T>& get_edabit_buffer(int n_bits);
/// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits);
void get_edabit_no_count(bool strict, int n_bits, edabit<T>& eb);
public:
static string get_filename(const Names& N, Dtype type, int thread_num = -1);
@@ -317,6 +316,8 @@ class Data_Files
void reset_usage() { usage.reset(); skipped.reset(); }
void set_usage(const DataPositions& pos) { usage = pos; }
TimerWithComm total_time();
};
template<class T> inline

View File

@@ -108,7 +108,21 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
#ifdef DEBUG_FILES
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
#endif
T::clear::check_setup(prep_data_dir);
try
{
T::clear::check_setup(prep_data_dir);
}
catch (...)
{
cerr << "Something is wrong with the preprocessing data on disk." << endl;
cerr
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
<< num_players
<< T::clear::fake_opts() << "'?" << endl;
throw;
}
string type_short = T::type_short();
string type_string = T::type_string();
@@ -135,7 +149,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
type_short, i, my_num, thread_num);
if (i == my_num)
my_input_buffers.setup(filename,
T::size() + T::clear::size(), type_string);
InputTuple<T>::size(), type_string);
else
input_buffers[i].setup(filename,
T::size(), type_string);
@@ -179,10 +193,6 @@ Data_Files<sint, sgf2n>::~Data_Files()
template<class T>
Sub_Data_Files<T>::~Sub_Data_Files()
{
for (auto& x: edabit_buffers)
{
delete x.second;
}
if (part != 0)
delete part;
}
@@ -229,6 +239,26 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
extended[it->first].seekg(it->second);
}
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
if (field_type == DATA_INT)
{
for (auto& x : pos.edabits)
{
// open files
get_edabit_buffer(x.first.second);
}
int block_size = edabitvec<T>::MAX_SIZE;
for (auto& x : edabit_buffers)
{
int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}];
x.second.seekg(n / block_size);
edabit<T> eb;
for (int i = 0; i < n % block_size; i++)
get_edabit_no_count(false, x.first, eb);
}
}
}
template<class sint, class sgf2n>
@@ -262,6 +292,8 @@ void Sub_Data_Files<T>::prune()
dabit_buffer.prune();
if (part != 0)
part->prune();
for (auto& x : edabit_buffers)
x.second.prune();
}
template<class sint, class sgf2n>
@@ -285,6 +317,8 @@ void Sub_Data_Files<T>::purge()
dabit_buffer.purge();
if (part != 0)
part->purge();
for (auto& x : edabit_buffers)
x.second.prune();
}
template<class T>
@@ -322,34 +356,43 @@ void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
}
template<class T>
template<int>
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
false_type)
EdabitBuffer<T>& Sub_Data_Files<T>::get_edabit_buffer(int n_bits)
{
if (edabit_buffers.empty())
insecure("reading edaBits from files");
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
{
string filename = PrepBase::get_edabit_filename(prep_data_dir,
n_bits, my_num, thread_num);
ifstream* f = new ifstream(filename);
if (f->fail())
throw runtime_error("cannot open " + filename);
check_file_signature<T>(*f, filename);
edabit_buffers[n_bits] = f;
edabit_buffers[n_bits] = n_bits;
edabit_buffers[n_bits].setup(filename,
T::size() * edabitvec<T>::MAX_SIZE
+ n_bits * T::bit_type::part_type::size());
}
auto& buffer = *edabit_buffers[n_bits];
if (buffer.peek() == EOF)
return edabit_buffers[n_bits];
}
template<class T>
edabitvec<T> Sub_Data_Files<T>::get_edabitvec(bool strict, int n_bits)
{
if (my_edabits[n_bits].empty())
return get_edabit_buffer(n_bits).read();
else
{
buffer.seekg(0);
check_file_signature<T>(buffer, "");
auto res = my_edabits[n_bits];
my_edabits[n_bits] = {};
this->fill(res, strict, n_bits);
return res;
}
}
template<class T>
void Preprocessing<T>::fill(edabitvec<T>& res, bool strict, int n_bits)
{
edabit<T> eb;
while (res.size() < res.MAX_SIZE)
{
get_edabit_no_count(strict, n_bits, eb);
res.push_back(eb);
}
edabitvec<T> eb;
eb.input(n_bits, buffer);
this->edabits[{strict, n_bits}].push_back(eb);
if (buffer.fail())
throw runtime_error("error reading edaBits");
}
template<class T>
@@ -362,4 +405,10 @@ typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
return *part;
}
/**
 * Combined preprocessing timer of this object.
 *
 * @return the sum of the prep_timer members of DataFp, DataF2, and DataFb
 */
template<class sint, class sgf2n>
TimerWithComm Data_Files<sint, sgf2n>::total_time()
{
    auto& fp_time = DataFp.prep_timer;
    auto& f2_time = DataF2.prep_timer;
    auto& fb_time = DataFb.prep_timer;
    return fp_time + f2_time + fb_time;
}
#endif

50
Processor/EdabitBuffer.h Normal file
View File

@@ -0,0 +1,50 @@
/*
* EdabitBuffer.h
*
*/
#ifndef PROCESSOR_EDABITBUFFER_H_
#define PROCESSOR_EDABITBUFFER_H_
#include "Tools/Buffer.h"
/**
 * Reads edaBit chunks (edabitvec<T>) of a fixed bit length from the file
 * managed by the BufferOwner base class.
 */
template<class T>
class EdabitBuffer : public BufferOwner<T, T>
{
    // Bit length handed to edabitvec<T>::input() on every read.
    int n_bits;

    int element_length()
    {
        return -1;
    }

public:
    // Intentionally not explicit: an int can be converted to a buffer.
    EdabitBuffer(int n_bits = 0) :
            n_bits(n_bits)
    {
    }

    /**
     * Read the next chunk from the underlying file.
     *
     * Opens the file first if no stream is present yet. When the stream is
     * at end-of-file, try_rewind() is called before reading.
     *
     * @return the chunk that was read
     * @throws runtime_error if opening or reading fails
     */
    edabitvec<T> read()
    {
        // open() is only attempted when no stream exists (short-circuit).
        if (not BufferBase::file and this->open()->fail())
            throw runtime_error("error opening " + this->filename);

        assert(BufferBase::file);
        auto& stream = *BufferBase::file;

        if (stream.peek() == EOF)
            this->try_rewind();

        edabitvec<T> chunk;
        chunk.input(n_bits, stream);
        if (stream.fail())
            throw runtime_error("error reading edaBits");

        return chunk;
    }
};
#endif /* PROCESSOR_EDABITBUFFER_H_ */

View File

@@ -70,6 +70,7 @@ enum
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
// Addition
ADDC = 0x20,
ADDS = 0x21,

View File

@@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRIVATEOUTPUT:
case TRUNC_PR:
case RUN_TAPE:
case CONV2DS:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
@@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_ints(r, s, 3);
get_vector(9, start, s);
break;
case CONV2DS:
get_ints(r, s, 3);
get_vector(12, start, s);
break;
// read from file, input is opcode num_args,
// start_file_posn (read), end_file_posn(write) var1, var2, ...
@@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
throw Processor_Error(ss.str());
}
break;
case ACTIVE:
n = get_int(s);
BaseMachine::s().active(n);
break;
case XORM:
case ANDM:
case XORCB:
@@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case MATMULSM:
return r[0] + start[0] * start[2];
case CONV2DS:
return r[0] + start[0] * start[1] * start[11];
{
unsigned res = 0;
for (size_t i = 0; i < start.size(); i += 15)
{
unsigned tmp = start[i]
+ start[i + 3] * start[i + 4] * start.at(i + 14);
res = max(res, tmp);
}
return res;
}
case OPEN:
skip = 2;
break;
@@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case REQBL:
case GREQBL:
case ACTIVE:
case USE:
case USE_INP:
case USE_EDABIT:

View File

@@ -109,6 +109,7 @@ class Machine : public BaseMachine
string prep_dir_prefix();
void reqbl(int n);
void active(int n);
typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; }
typename sint::mac_key_type get_sint_mac_key() { return alphapi; }

View File

@@ -415,6 +415,9 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
auto comm_stats = total_comm();
if (OnlineOptions::singleton.verbose)
queues.print_breakdown();
for (auto& queue : queues)
delete queue;
@@ -477,20 +480,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
print_timers();
if (sint::is_real)
{
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << my_number;
if (threads.size() > 1)
cerr << "; rounds counted double due to multi-threading";
cerr << "; use '-v' for more details";
cerr << ")" << endl;
auto& P = *this->P;
this->print_global_comm(P, comm_stats);
}
this->print_comm(*this->P, comm_stats);
#ifdef VERBOSE_OPTIONS
if (opening_sum < N.num_players() && !direct)
@@ -521,23 +511,6 @@ void Machine<sint, sgf2n>::run(const string& progname)
bit_memories.write_memory(N.my_num());
#ifdef OLD_USAGE
for (int dtype = 0; dtype < N_DTYPE; dtype++)
{
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
cerr << " " << pos.files[field_type][dtype];
cerr << endl;
}
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t=";
for (int i = 0; i < N.num_players(); i++)
cerr << " " << pos.inputs[i][field_type];
cerr << endl;
}
#endif
if (opts.verbose)
{
cerr << "Actual cost of program:" << endl;
@@ -586,6 +559,17 @@ void Machine<sint, sgf2n>::reqbl(int n)
sint::clear::reqbl(n);
}
/**
 * Enforce the security model requested by the program.
 *
 * Terminates the process with exit code 1 when `n` is 0 while the share
 * type is flagged as malicious (sint::malicious); otherwise does nothing.
 *
 * @param n value requested by the program; 0 is treated as a request for
 *          a semi-honest protocol
 */
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::active(int n)
{
    const bool wants_semi_honest = (n == 0);
    if (not sint::malicious or not wants_semi_honest)
        return;

    cerr << "Program requires a semi-honest protocol" << endl;
    exit(1);
}
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::suggest_optimizations()
{
@@ -599,8 +583,8 @@ void Machine<sint, sgf2n>::suggest_optimizations()
optimizations.append("\tprogram.use_edabit(True)\n");
if (not optimizations.empty())
cerr << "This program might benefit from some protocol options." << endl
<< "Consider adding the following at the beginning of '" << progname
<< ".mpc':" << endl << optimizations;
<< "Consider adding the following at the beginning of your code:"
<< endl << optimizations;
#ifndef __clang__
cerr << "This virtual machine was compiled with GCC. Recompile with "
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;

View File

@@ -172,7 +172,7 @@ void OfflineMachine<W>::generate()
auto& opts = OnlineOptions::singleton;
opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch;
for (int i = 0; i < buffered_total(total, batch) / batch; i++)
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
preprocessing.get_edabitvec(true, n_bits).output(n_bits,
out);
}
else

View File

@@ -44,6 +44,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
auto& queues = machine.queues[num];
queues->next();
ThreadQueue::thread_queue = queues;
#ifdef DEBUG_THREADS
fprintf(stderr, "\tI am in thread %d\n",num);
@@ -118,6 +119,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
DataPositions actual_usage(P.num_players());
Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer;
thread_timer.start();
TimerWithComm timer, online_timer, online_prep_timer;
timer.start();
while (flag)
{ // Wait until I have a program to run
@@ -262,6 +265,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#ifdef DEBUG_THREADS
printf("\tClient %d about to run %d\n",num,program);
#endif
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.reset(progs[program], job.arg);
// Bits, Triples, Squares, and Inverses skipping
@@ -290,6 +295,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
printf("\tSignalling I have finished with program %d"
"in thread %d\n", program, num);
#endif
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
wait_timer.start();
queues->finished(job, P.total_comm());
wait_timer.stop();
@@ -297,7 +304,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
}
// final check
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.check();
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
if (machine.opts.file_prep_per_thread)
Proc.DataF.prune();
@@ -330,6 +341,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// wind down thread by thread
machine.stats += Proc.stats;
queues->timers["wait"] = wait_timer + queues->wait_timer;
timer.stop(P.total_comm());
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
// prevent faulty usage message
Proc.DataF.set_usage(actual_usage);
delete processor;

View File

@@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
cerr << " edaBits of size " << n_bits << " left" << endl;
}
if (n > used / 10)
if (n * n_batch > used / 10)
cerr << "Significant amount of unused edaBits of size " << n_bits
<< ". For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size "

View File

@@ -10,6 +10,7 @@
using namespace std;
#include "Math/field_types.h"
#include "Tools/TimerWithComm.h"
class PrepBase
{
@@ -28,6 +29,8 @@ public:
const string& type_string, size_t used);
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used);
TimerWithComm prep_timer;
};
#endif /* PROCESSOR_PREPBASE_H_ */

View File

@@ -5,6 +5,7 @@
#include "Processor/Program.h"
#include "GC/square64.h"
#include "SpecificPrivateOutput.h"
#include "Conv2dTuple.h"
#include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp"
@@ -31,6 +32,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
DataF.set_proc(this);
protocol.init(DataF, MC);
DataF.set_protocol(protocol);
MC.set_prep(DataF);
bit_usage.set_num_players(P.num_players());
personal_bit_preps.resize(P.num_players());
for (int i = 0; i < P.num_players(); i++)
@@ -40,6 +42,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
template<class T>
SubProcessor<T>::~SubProcessor()
{
DataF.set_proc(0);
for (size_t i = 0; i < personal_bit_preps.size(); i++)
{
auto& x = personal_bit_preps[i];
@@ -391,7 +394,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
return;
string filename;
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data";
filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size();
@@ -652,21 +655,35 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
{
protocol.init_dotprod();
auto& args = instruction.get_start();
int output_h = args[0], output_w = args[1];
int inputs_h = args[2], inputs_w = args[3];
int weights_h = args[4], weights_w = args[5];
int stride_h = args[6], stride_w = args[7];
int n_channels_in = args[8];
int padding_h = args[9];
int padding_w = args[10];
int batch_size = args[11];
size_t r0 = instruction.get_r(0);
size_t r1 = instruction.get_r(1);
int r2 = instruction.get_r(2);
int lengths[batch_size][output_h][output_w];
memset(lengths, 0, sizeof(lengths));
int filter_stride_h = 1;
int filter_stride_w = 1;
vector<Conv2dTuple> tuples;
for (size_t i = 0; i < args.size(); i += 15)
tuples.push_back(Conv2dTuple(args, i));
for (auto& tuple : tuples)
tuple.pre(S, protocol);
protocol.exchange();
for (auto& tuple : tuples)
tuple.post(S, protocol);
}
inline
Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
{
assert(arguments.size() >= start + 15ul);
auto args = arguments.data() + start + 3;
output_h = args[0], output_w = args[1];
inputs_h = args[2], inputs_w = args[3];
weights_h = args[4], weights_w = args[5];
stride_h = args[6], stride_w = args[7];
n_channels_in = args[8];
padding_h = args[9];
padding_w = args[10];
batch_size = args[11];
r0 = arguments[start];
r1 = arguments[start + 1];
r2 = arguments[start + 2];
lengths.resize(batch_size, vector<vector<int>>(output_h, vector<int>(output_w)));
filter_stride_h = 1;
filter_stride_w = 1;
if (stride_h < 0)
{
filter_stride_h = -stride_h;
@@ -677,7 +694,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
filter_stride_w = -stride_w;
stride_w = 1;
}
}
template<class T>
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;
@@ -714,9 +735,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
protocol.next_dotprod();
}
}
}
protocol.exchange();
template<class T>
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
size_t base = r0 + i_batch * output_h * output_w;

View File

@@ -6,6 +6,8 @@
#include "ThreadQueue.h"
thread_local ThreadQueue* ThreadQueue::thread_queue = 0;
void ThreadQueue::schedule(const ThreadJob& job)
{
lock.lock();
@@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job)
cerr << this << ": " << left << " left" << endl;
#endif
lock.unlock();
if (thread_queue)
thread_queue->wait_timer.start();
in.push(job);
if (thread_queue)
thread_queue->wait_timer.stop();
}
ThreadJob ThreadQueue::next()
@@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
ThreadJob ThreadQueue::result()
{
if (thread_queue)
thread_queue->wait_timer.start();
auto res = out.pop();
if (thread_queue)
thread_queue->wait_timer.stop();
lock.lock();
left--;
#ifdef DEBUG_THREAD_QUEUE

View File

@@ -16,6 +16,11 @@ class ThreadQueue
NamedCommStats comm_stats;
public:
static thread_local ThreadQueue* thread_queue;
map<string, TimerWithComm> timers;
Timer wait_timer;
ThreadQueue() :
left(0)
{

View File

@@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job)
}
available.clear();
}
/**
 * Add up the timer stored under `phase` across all thread queues.
 *
 * @param phase key into each queue's `timers` map
 * @return accumulated timer, default-constructed if there are no queues
 */
TimerWithComm ThreadQueues::sum(const string& phase)
{
    TimerWithComm total;
    for (size_t i = 0; i < size(); i++)
        total += (*this)[i]->timers[phase];
    return total;
}
/**
 * Print to stderr how the time was split between phases.
 *
 * With one queue, that queue's "online" and "prep" timers are reported.
 * With several, the per-phase sums over all queues are reported together
 * with the summed "wait" timer. Nothing is printed without queues.
 */
void ThreadQueues::print_breakdown()
{
    const auto n_threads = size();
    if (n_threads == 0)
        return;

    if (n_threads == 1)
    {
        auto& only = *(*this)[0];
        cerr << "Spent " << only.timers["online"].full()
                << " on the online phase and "
                << only.timers["prep"].full()
                << " on the preprocessing/offline phase." << endl;
        return;
    }

    cerr << n_threads << " threads spent a total of " << sum("online").full()
            << " on the online phase, " << sum("prep").full()
            << " on the preprocessing/offline phase, and "
            << sum("wait").full() << " idling." << endl;
}

View File

@@ -24,6 +24,10 @@ public:
int distribute_no_setup(ThreadJob job, int n_items, int base = 0,
int granularity = 1, const vector<void*>* supplies = 0);
void wrap_up(ThreadJob job);
TimerWithComm sum(const string& phase);
void print_breakdown();
};
#endif /* PROCESSOR_THREADQUEUES_H_ */

View File

@@ -387,6 +387,7 @@
X(GENSECSHUFFLE, throw not_implemented(),) \
X(APPLYSHUFFLE, throw not_implemented(),) \
X(DELSHUFFLE, throw not_implemented(),) \
X(ACTIVE, throw not_implemented(),) \
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS