Maintenance.

This commit is contained in:
Marcel Keller
2024-07-09 12:17:25 +10:00
parent b0dc2b36f8
commit 78fe3d8bad
234 changed files with 4273 additions and 1367 deletions

View File

@@ -70,6 +70,29 @@ int BaseMachine::bucket_size(size_t usage)
return res;
}
int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols)
{
unsigned res = min(100, OnlineOptions::singleton.batch_size);
if (has_program())
res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols));
return res;
}
int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
{
if (has_program())
{
auto res = s().progs[0].get_offline_data_used().matmuls[
{n_rows, n_inner, n_cols}];
if (res)
return res;
else
return -1;
}
else
return -1;
}
BaseMachine::BaseMachine() : nthreads(0)
{
if (sodium_init() == -1)
@@ -287,7 +310,7 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << P.my_num() << " only";
if (nthreads > 1)
if (multithread)
cerr << "; rounds counted double due to multi-threading";
if (not OnlineOptions::singleton.verbose)
cerr << "; use '-v' for more details";

View File

@@ -44,6 +44,7 @@ public:
string progname;
int nthreads;
bool multithread;
ThreadQueues queues;
@@ -66,10 +67,14 @@ public:
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
template<class T>
static int input_batch_size(int player, int buffer_size = 0);
template<class T>
static int edabit_batch_size(int n_bits, int buffer_size = 0);
static int edabit_bucket_size(int n_bits);
static int triple_bucket_size(DataFieldType type);
static int bucket_size(size_t usage);
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
BaseMachine();
virtual ~BaseMachine() {}
@@ -105,6 +110,10 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
template<class T>
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
{
if (OnlineOptions::singleton.has_option("debug_batch_size"))
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
fallback);
int n_opts;
int n = 0;
int res = 0;
@@ -114,7 +123,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
else if (fallback > 0)
n_opts = fallback;
else
n_opts = OnlineOptions::singleton.batch_size;
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
if (buffer_size <= 0 and has_program())
{
@@ -132,7 +141,6 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
{
n = buffer_size;
buffer_size = 0;
n_opts = OnlineOptions::singleton.batch_size;
}
if (n > 0 and not (buffer_size > 0))
@@ -161,16 +169,33 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
else
res = n_opts;
#ifdef DEBUG_BATCH_SIZE
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n="
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
#endif
if (OnlineOptions::singleton.has_option("debug_batch_size"))
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
<< " buffer_size=" << buffer_size << endl;
assert(res > 0);
return res;
}
template<class T>
int BaseMachine::input_batch_size(int player, int buffer_size)
{
if (buffer_size)
return buffer_size;
if (has_program())
{
auto res =
s().progs[0].get_offline_data_used(
).inputs[player][T::clear::field_type()];
if (res > 0)
return res;
}
return OnlineOptions::singleton.batch_size;
}
template<class T>
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
{

View File

@@ -29,10 +29,12 @@ public:
Conv2dTuple(const vector<int>& args, int start);
array<int, 3> matrix_dimensions();
template<class T>
void pre(vector<T>& S, typename T::Protocol& protocol);
void pre(StackedVector<T>& S, typename T::Protocol& protocol);
template<class T>
void post(vector<T>& S, typename T::Protocol& protocol);
void post(StackedVector<T>& S, typename T::Protocol& protocol);
template<class T>
void run_matrix(SubProcessor<T>& processor);

View File

@@ -12,8 +12,10 @@
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include "PrepBase.h"
#include "PrepBuffer.h"
#include "EdabitBuffer.h"
#include "Tools/TimerWithComm.h"
#include "Tools/CheckVector.h"
#include <fstream>
#include <map>
@@ -104,8 +106,6 @@ class Preprocessing : public PrepBase
protected:
static const bool use_part = false;
DataPositions& usage;
bool do_count;
void count(Dtype dtype, int n = 1)
@@ -115,9 +115,9 @@ protected:
template<int>
void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
template<int>
void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
void get_edabits(bool, size_t, T*, StackedVector<typename T::bit_type>&,
const vector<int>&, true_type)
{ throw not_implemented(); }
@@ -126,6 +126,8 @@ protected:
T get_random_from_inputs(int nplayers);
public:
int buffer_size;
template<class U, class V>
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
SubProcessor<T>* proc);
@@ -135,7 +137,8 @@ public:
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
DataPositions& usage);
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
Preprocessing(DataPositions& usage) :
PrepBase(usage), do_count(true), buffer_size(0) {}
virtual ~Preprocessing() {}
virtual void set_protocol(typename T::Protocol&) {};
@@ -151,7 +154,7 @@ public:
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
virtual void get_input_no_count(T&, typename T::open_type&, int)
{ throw not_implemented() ; }
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
virtual void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
{ throw not_implemented(); }
void get(Dtype dtype, T* a);
@@ -159,7 +162,7 @@ public:
void get_two(Dtype dtype, T& a, T& b);
void get_one(Dtype dtype, T& a);
void get_input(T& a, typename T::open_type& x, int i);
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
/// Get fresh random multiplication triple
virtual array<T, 3> get_triple(int n_bits);
@@ -174,7 +177,7 @@ public:
virtual void get_dabit(T& a, typename T::bit_type& b);
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs)
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs)
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
virtual void get_edabit_no_count(bool, int, edabit<T>&)
{ throw runtime_error("no edaBits"); }
@@ -201,11 +204,11 @@ class Sub_Data_Files : public Preprocessing<T>
static int tuple_length(int dtype);
BufferOwner<T, T> buffers[N_DTYPE];
vector<BufferOwner<T, T>> input_buffers;
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
array<PrepBuffer<T>, N_DTYPE> buffers;
vector<PrepBuffer<T>> input_buffers;
PrepBuffer<InputTuple<T>, RefInputTuple<T>, T> my_input_buffers;
map<DataTag, PrepBuffer<T> > extended;
PrepBuffer<dabit<T>, dabit<T>, T> dabit_buffer;
map<int, EdabitBuffer<T>> edabit_buffers;
map<int, edabitvec<T>> my_edabits;
@@ -284,7 +287,7 @@ public:
}
void setup_extended(const DataTag& tag, int tuple_size = 0);
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_dabit_no_count(T& a, typename T::bit_type& b);
part_type& get_part();
@@ -397,7 +400,7 @@ inline void Preprocessing<T>::get_input(T& a, typename T::open_type& x, int i)
}
template<class T>
inline void Preprocessing<T>::get(vector<T>& S, DataTag tag,
inline void Preprocessing<T>::get(StackedVector<T>& S, DataTag tag,
const vector<int>& regs, int vector_size)
{
usage.count(T::clear::field_type(), tag, vector_size);

View File

@@ -143,14 +143,14 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
{
if (T::clear::allows(Dtype(dtype)))
{
buffers[dtype].setup(
buffers[dtype].setup(num_players,
PrepBase::get_filename(prep_data_dir, Dtype(dtype), type_short,
my_num, thread_num), tuple_length(dtype), type_string,
DataPositions::dtype_names[dtype]);
}
}
dabit_buffer.setup(
dabit_buffer.setup(num_players,
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
type_short, my_num, thread_num), dabit<T>::size(), type_string,
DataPositions::dtype_names[DATA_DABIT]);
@@ -161,10 +161,10 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
string filename = PrepBase::get_input_filename(prep_data_dir,
type_short, i, my_num, thread_num);
if (i == my_num)
my_input_buffers.setup(filename,
my_input_buffers.setup(num_players, filename,
InputTuple<T>::size(), type_string);
else
input_buffers[i].setup(filename,
input_buffers[i].setup(num_players, filename,
T::size(), type_string);
}
@@ -344,14 +344,14 @@ void Sub_Data_Files<T>::setup_extended(const DataTag& tag, int tuple_size)
{
stringstream ss;
ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num;
buffer.setup(ss.str(), tuple_length);
buffer.setup(num_players, ss.str(), tuple_length);
}
buffer.check_tuple_length(tuple_length);
}
template<class T>
void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
void Sub_Data_Files<T>::get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
{
setup_extended(tag, regs.size());
for (int j = 0; j < vector_size; j++)

View File

@@ -12,6 +12,7 @@ using namespace std;
#include "Math/BitVec.h"
#include "Data_Files.h"
#include "Protocols/Replicated.h"
#include "Protocols/ReplicatedPrep.h"
#include "Protocols/MAC_Check_Base.h"
#include "Processor/Input.h"
@@ -109,7 +110,7 @@ public:
};
template<class T>
class DummyLivePrep : public Preprocessing<T>
class DummyLivePrep : public BufferPrep<T>
{
public:
static const bool homomorphic = true;
@@ -133,16 +134,16 @@ public:
}
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
DummyLivePrep(DataPositions& usage, bool = true) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
Preprocessing<T>(usage)
BufferPrep<T>(usage)
{
}
@@ -165,7 +166,7 @@ public:
{
fail();
}
void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
{
fail();
}

View File

@@ -81,7 +81,6 @@ int ExternalClients::init_client_connection(const string& host, int portnum,
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
true);
if (party_num == 0)
{
octetStream specification;
specification.Receive(socket);

View File

@@ -15,7 +15,7 @@
#include <iomanip>
template<class cgf2n>
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
void Instruction::execute_clear_gf2n(StackedVector<cgf2n>& registers,
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
{
auto& C2 = registers;
@@ -30,7 +30,7 @@ void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
}
template<class cgf2n>
void Instruction::gbitdec(vector<cgf2n>& registers) const
void Instruction::gbitdec(StackedVector<cgf2n>& registers) const
{
for (int j = 0; j < size; j++)
{
@@ -44,7 +44,7 @@ void Instruction::gbitdec(vector<cgf2n>& registers) const
}
template<class cgf2n>
void Instruction::gbitcom(vector<cgf2n>& registers) const
void Instruction::gbitcom(StackedVector<cgf2n>& registers) const
{
for (int j = 0; j < size; j++)
{
@@ -124,7 +124,7 @@ ostream& operator<<(ostream& s, const Instruction& instr)
return s;
}
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_short>& registers,
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_long>& registers,
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;

View File

@@ -15,6 +15,7 @@ template<class sint, class sgf2n> class Machine;
template<class sint, class sgf2n> class Processor;
template<class T> class SubProcessor;
template<class T> class MemoryPart;
template<class T> class StackedVector;
class ArithmeticProcessor;
class SwitchableOutput;
@@ -73,6 +74,8 @@ enum
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
CMDLINEARG = 0xEB,
CALL_TAPE = 0xEC,
CALL_ARG = 0xED,
// Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -90,6 +93,7 @@ enum
PREFIXSUMS = 0x2D,
PICKS = 0x2E,
CONCATS = 0x2F,
ZIPS = 0x3F,
// Multiplication/division/other arithmetic
MULC = 0x30,
MULM = 0x31,
@@ -351,6 +355,7 @@ protected:
string str;
public:
BaseInstruction() : opcode(0), size(0), n(0) {}
virtual ~BaseInstruction() {};
int get_r(int i) const { return r[i]; }
@@ -391,13 +396,13 @@ public:
void execute(Processor<sint, sgf2n>& Proc) const;
template<class cgf2n>
void execute_clear_gf2n(vector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
void execute_clear_gf2n(StackedVector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
ArithmeticProcessor& Proc) const;
template<class cgf2n>
void gbitdec(vector<cgf2n>& registers) const;
void gbitdec(StackedVector<cgf2n>& registers) const;
template<class cgf2n>
void gbitcom(vector<cgf2n>& registers) const;
void gbitcom(StackedVector<cgf2n>& registers) const;
void execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const;

View File

@@ -92,6 +92,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case DIVINT:
case CONDPRINTPLAIN:
case INPUTMASKREG:
case ZIPS:
get_ints(r, s, 3);
break;
// instructions with 2 register operands
@@ -245,6 +246,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case CONDPRINTSTRB:
case RANDOMS:
case GENSECSHUFFLE:
case CALL_ARG:
r[0]=get_int(s);
n = get_int(s);
break;
@@ -330,6 +332,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
// read from file, input is opcode num_args,
// start_file_posn (read), end_file_posn(write) var1, var2, ...
case READFILESHARE:
case CALL_TAPE:
num_var_args = get_int(s) - 2;
r[0] = get_int(s);
r[1] = get_int(s);
@@ -588,6 +591,7 @@ int BaseInstruction::get_reg_type() const
case ACCEPTCLIENTCONNECTION:
case GENSECSHUFFLE:
case CMDLINEARG:
case CALL_TAPE:
return INT;
case PREP:
case GPREP:
@@ -596,7 +600,6 @@ int BaseInstruction::get_reg_type() const
case USE_EDABIT:
case USE_MATMUL:
case RUN_TAPE:
case CISC:
// those use r[] not for registers
return NONE;
case LDI:
@@ -639,6 +642,8 @@ int BaseInstruction::get_reg_type() const
case PRIVATEOUTPUT:
case FIXINPUT:
return CINT;
case CALL_ARG:
return n;
default:
if (is_gf2n_instruction())
{
@@ -706,6 +711,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
}
else
return 0;
case CALL_TAPE:
{
int res = 0;
for (auto it = start.begin(); it < start.end(); it += 5)
if (it[1] == reg_type)
res = max(res, (*it ? it[3] : it[4]) + it[2]);
return res;
}
default:
if (get_reg_type() != reg_type)
return 0;
@@ -713,6 +726,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
switch (opcode)
{
case CISC:
{
int res = 0;
for (auto it = start.begin(); it < start.end(); it += *it)
{
assert(it + *it <= start.end());
res = max(res, it[1] + it[2]);
}
return res;
}
case MULS:
skip = 4;
offset = 1;
size_offset = -1;
break;
case DOTPRODS:
{
int res = 0;
@@ -737,7 +765,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
return res;
}
case MATMULSM:
return r[0] + start[0] * start[2];
{
int res = 0;
for (auto it = start.begin(); it < start.end(); it += 12)
{
res = max(res, *it + *(it + 3) * *(it + 5));
}
return res;
}
case CONV2DS:
{
unsigned res = 0;
@@ -956,6 +991,17 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
}
return;
}
case ZIPS:
{
auto& S = Proc.Procp.get_S();
auto dest = S.begin() + r[0];
for (int i = 0; i < get_size(); i++)
{
*dest++ = S[r[1] + i];
*dest++ = S[r[2] + i];
}
return;
}
case DIVC:
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2]));
break;
@@ -1232,6 +1278,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case JOIN_TAPE:
Proc.machine.join_tape(r[0]);
break;
case CALL_TAPE:
Proc.call_tape(r[0], Proc.read_Ci(r[1]), start);
break;
case CRASH:
if (Proc.read_Ci(r[0]))
throw crash_requested();
@@ -1277,7 +1326,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
// get client connection at port number n + my_num())
int client_handle = Proc.external_clients.get_client_connection(
Proc.read_Ci(r[1]));
if (Proc.P.my_num() == 0)
{
octetStream os;
os.store(int(sint::open_type::type_char()));
@@ -1421,7 +1469,8 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
#endif
#ifdef OUTPUT_INSTRUCTIONS
cerr << instruction << endl;
if (OnlineOptions::singleton.has_option("output_instructions"))
cerr << instruction << endl;
#endif
Proc.PC++;
@@ -1461,7 +1510,7 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c
for (int i = 0; i < size; i++)
{
if (p == 0 or (*p == 0 and s == 0))
out << v[i];
out.signed_output(v[i]);
else if (s == 0)
out << bigint::get_float(v[i], p[i], {}, {});
else

View File

@@ -235,6 +235,7 @@ size_t Machine<sint, sgf2n>::load_program(const string& threadname,
M2.minimum_size(SGF2N, CGF2N, progs[i], threadname);
Mp.minimum_size(SINT, CINT, progs[i], threadname);
Mi.minimum_size(NONE, INT, progs[i], threadname);
bit_memories.reset(progs[i]);
return progs.back().size();
}
@@ -340,14 +341,15 @@ void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
auto subdim = it->first;
subdim[1] = min(subdim[1] - j, max_inner);
subdim[2] = min(subdim[2] - k, max_cols);
auto& source =
dynamic_cast<Hemi<sint>&>(source_proc.protocol).get_matrix_prep(
auto& source_proto = dynamic_cast<Hemi<sint>&>(source_proc.protocol);
auto& source = source_proto.get_matrix_prep(
subdim, source_proc);
auto& dest =
dynamic_cast<Hemi<sint>&>(tinfo[thread_number].processor->Procp.protocol).get_matrix_prep(
subdim, tinfo[thread_number].processor->Procp);
for (int i = 0; i < it->second; i++)
dest.push_triple(source.get_triple_no_count(-1));
if (not source_proto.use_plain_matmul(subdim, source_proc))
for (int i = 0; i < it->second; i++)
dest.push_triple(source.get_triple_no_count(-1));
}
}
}
@@ -434,7 +436,13 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
auto comm_stats = total_comm();
if (OnlineOptions::singleton.verbose)
queues.print_breakdown();
{
NamedStats total;
for (auto queue : queues)
total += queue->stats;
total.print();
queues.print_breakdown();
}
for (auto& queue : queues)
delete queue;
@@ -464,7 +472,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
finish_timer.start();
// actual usage
bool multithread = nthreads > 1;
multithread = nthreads > 1;
auto res = stop_threads();
DataPositions& pos = res.first;
@@ -518,14 +526,15 @@ void Machine<sint, sgf2n>::run(const string& progname)
cerr << "Full broadcast" << endl;
#endif
#ifdef CHOP_MEMORY
// Reduce memory size to speed up
unsigned max_size = 1 << 20;
if (M2.size_s() > max_size)
M2.resize_s(max_size);
if (Mp.size_s() > max_size)
Mp.resize_s(max_size);
#endif
if (not OnlineOptions::singleton.has_option("output_full_memory"))
{
// Reduce memory size to speed up
unsigned max_size = 1 << 20;
if (M2.size_s() > max_size)
M2.resize_s(max_size);
if (Mp.size_s() > max_size)
Mp.resize_s(max_size);
}
// Write out the memory to use next time
ofstream outf(memory_filename(), ios::out | ios::binary);

View File

@@ -44,10 +44,10 @@ public:
virtual const T& at(size_t i) const = 0;
template<class U>
void indirect_read(const Instruction& inst, vector<T>& regs,
void indirect_read(const Instruction& inst, StackedVector<T>& regs,
const U& indices);
template<class U>
void indirect_write(const Instruction& inst, vector<T>& regs,
void indirect_write(const Instruction& inst, StackedVector<T>& regs,
const U& indices);
void minimum_size(size_t size);

View File

@@ -6,21 +6,21 @@
template<class T>
template<class U>
void MemoryPart<T>::indirect_read(const Instruction& inst,
vector<T>& regs, const U& indices)
StackedVector<T>& regs, const U& indices)
{
size_t n = inst.get_size();
auto dest = regs.begin() + inst.get_r(0);
auto start = indices.begin() + inst.get_r(1);
#ifdef CHECK_SIZE
#ifndef NO_CHECK_SIZE
assert(start + n <= indices.end());
assert(dest + n <= regs.end());
#endif
long size = this->size();
size_t size = this->size();
const T* data = this->data();
for (auto it = start; it < start + n; it++)
{
#ifndef NO_CHECK_SIZE
if (*it >= size)
if (size_t(it->get()) >= size)
throw overflow(T::type_string() + " memory read", it->get(), size);
#endif
*dest++ = data[it->get()];
@@ -30,21 +30,21 @@ void MemoryPart<T>::indirect_read(const Instruction& inst,
template<class T>
template<class U>
void MemoryPart<T>::indirect_write(const Instruction& inst,
vector<T>& regs, const U& indices)
StackedVector<T>& regs, const U& indices)
{
size_t n = inst.get_size();
auto source = regs.begin() + inst.get_r(0);
auto start = indices.begin() + inst.get_r(1);
#ifdef CHECK_SIZE
#ifndef NO_CHECK_SIZE
assert(start + n <= indices.end());
assert(source + n <= regs.end());
#endif
long size = this->size();
size_t size = this->size();
T* data = this->data();
for (auto it = start; it < start + n; it++)
{
#ifndef NO_CHECK_SIZE
if (*it >= size)
if (size_t(it->get()) >= size)
throw overflow(T::type_string() + " memory write", it->get(), size);
#endif
data[it->get()] = *source++;

View File

@@ -33,6 +33,7 @@ class thread_info
const char* name);
void Sub_Main_Func();
void Main_Func_With_Purge();
};
#endif

View File

@@ -340,6 +340,13 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
cerr << endl;
#endif
if (num == 0 and OnlineOptions::singleton.verbose
and machine.queues.size() > 1)
{
cerr << "Main thread communication:" << endl;
P.total_comm().print();
}
// wind down thread by thread
machine.stats += Proc.stats;
queues->timers["wait"] = wait_timer + queues->wait_timer;
@@ -347,11 +354,43 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
NamedStats stats;
stats["integer multiplications"] = Proc.Procp.protocol.counter;
stats["integer multiplication rounds"] = Proc.Procp.protocol.rounds;
stats["probabilistic truncations"] = Proc.Procp.protocol.trunc_pr_counter;
stats["probabilistic truncation rounds"] = Proc.Procp.protocol.trunc_rounds;
stats["ANDs"] = Proc.share_thread.protocol->bit_counter;
stats["AND rounds"] = Proc.share_thread.protocol->rounds;
stats["integer openings"] = MCp->values_opened;
stats["integer inputs"] = Proc.Procp.input.values_input;
for (auto x : Proc.Procp.shuffler.stats)
stats["shuffles of length " + to_string(x.first)] = x.second;
try
{
auto proc = dynamic_cast<RingPrep<sint>&>(Proc.DataF.DataFp).bit_part_proc;
if (proc)
stats["ANDs in preprocessing"] = proc->protocol.bit_counter;
}
catch (...)
{
}
try
{
auto protocol = dynamic_cast<BitPrep<sint>&>(Proc.DataF.DataFp).protocol;
if (protocol)
stats["integer multiplications in preprocessing"] = protocol->counter;
}
catch (...)
{
}
// prevent faulty usage message
Proc.DataF.set_usage(actual_usage);
delete processor;
queues->finished(actual_usage, P.total_comm());
queues->finished(actual_usage, P.total_comm(), stats);
delete MC2;
delete MCp;
@@ -367,6 +406,27 @@ template<class sint, class sgf2n>
void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
{
auto& ti = *(thread_info<sint, sgf2n>*)(ptr);
if (OnlineOptions::singleton.has_option("throw_exceptions"))
ti.Main_Func_With_Purge();
else
{
try
{
ti.Main_Func_With_Purge();
}
catch (exception& e)
{
cerr << "Fatal error: " << e.what() << endl;
exit(1);
}
}
return 0;
}
template<class sint, class sgf2n>
void thread_info<sint, sgf2n>::Main_Func_With_Purge()
{
auto& ti = *this;
#ifdef INSECURE
ti.Sub_Main_Func();
#else
@@ -383,12 +443,10 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
}
catch (...)
{
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num);
purge_preprocessing(machine->get_N(), thread_num);
throw;
}
#endif
return 0;
}

View File

@@ -36,6 +36,8 @@ public:
template<class T, class U>
int run();
template<class T, class U>
int run_with_error();
Player* new_player(const string& id_base);

View File

@@ -173,6 +173,25 @@ Player* OnlineMachine::new_player(const string& id_base)
template<class T, class U>
int OnlineMachine::run()
{
if (online_opts.has_option("throw_exception"))
return run_with_error<T, U>();
else
{
try
{
return run_with_error<T, U>();
}
catch (exception& e)
{
cerr << "Fatal error: " << e.what() << endl;
exit(1);
}
}
}
template<class T, class U>
int OnlineMachine::run_with_error()
{
#ifndef INSECURE
try

View File

@@ -112,6 +112,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-B", // Flag token.
"--bucket-size" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
-1, // Number of args expected.
',', // Delimiter if expecting multiple args.
"Further options", // Help description.
"-o", // Flag token.
"--options" // Flag token.
);
if (security)
opt.add(
@@ -138,6 +147,12 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
verbose = opt.isSet("--verbose");
#endif
opt.get("--options")->getStrings(options);
#ifdef THROW_EXCEPTIONS
options.push_back("throw_exceptions");
#endif
if (security)
{
opt.get("-S")->getInt(security_parameter);

View File

@@ -37,6 +37,7 @@ public:
bool receive_threads;
std::string disk_memory;
vector<long> args;
vector<string> options;
OnlineOptions();
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
@@ -67,6 +68,11 @@ public:
lgp = numBits(prime);
return get_prep_sub_dir<T>(PREP_DIR, nplayers, lgp);
}
bool has_option(const string& option)
{
return find(options.begin(), options.end(), option) != options.end();
}
};
#endif /* PROCESSOR_ONLINEOPTIONS_H_ */

View File

@@ -40,6 +40,11 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir,
+ to_string(my_num) + get_suffix(thread_num);
}
PrepBase::PrepBase(DataPositions& usage) :
usage(usage)
{
}
void PrepBase::print_left(const char* name, size_t n, const string& type_string,
size_t used, bool large)
{
@@ -72,7 +77,8 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
cerr << " edaBits of size " << n_bits << " left" << endl;
}
if (n * n_batch > used / 10)
if (n * n_batch > used / 10
and n * n_batch > size_t(usage.files[DATA_INT][DATA_DABIT]) / 10)
{
cerr << "Significant amount of unused edaBits of size " << n_bits
<< ". ";

View File

@@ -12,8 +12,13 @@ using namespace std;
#include "Math/field_types.h"
#include "Tools/TimerWithComm.h"
class DataPositions;
class PrepBase
{
protected:
DataPositions& usage;
public:
static string get_suffix(int thread_num);
@@ -25,12 +30,15 @@ public:
static string get_edabit_filename(const string& prep_data_dir, int n_bits,
int my_num, int thread_num = 0);
static void print_left(const char* name, size_t n,
const string& type_string, size_t used, bool large = false);
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used, bool malicious);
TimerWithComm prep_timer;
PrepBase(DataPositions& usage);
void print_left(const char* name, size_t n,
const string& type_string, size_t used, bool large = false);
void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used, bool malicious);
};
#endif /* PROCESSOR_PREPBASE_H_ */

44
Processor/PrepBuffer.h Normal file
View File

@@ -0,0 +1,44 @@
/*
* PrepBuffer.h
*
*/
#ifndef PROCESSOR_PREPBUFFER_H_
#define PROCESSOR_PREPBUFFER_H_
#include "Tools/Buffer.h"
template<class T, class U = T, class V = T>
class PrepBuffer : public BufferOwner<T, U, V>
{
int num_players;
string fake_opts;
public:
PrepBuffer() :
num_players(0)
{
}
void setup(int num_players, const string& filename, int tuple_length,
const string& type_string = "", const char* data_type = "")
{
this->num_players = num_players;
fake_opts = V::template proto_fake_opts<typename V::open_type>();
BufferOwner<T, U, V>::setup(filename, tuple_length, type_string, data_type);
}
void input(U& a)
{
try
{
BufferOwner<T, U, V>::input(a);
}
catch (exception& e)
{
throw prep_setup_error(e.what(), num_players, fake_opts);
}
}
};
#endif /* PROCESSOR_PREPBUFFER_H_ */

View File

@@ -21,18 +21,22 @@
#include "GC/Processor.h"
#include "GC/ShareThread.h"
#include "Protocols/SecureShuffle.h"
#include "Tools/NamedStats.h"
class Program;
// synchronize in asymmetric protocols
template<class T>
void sync(vector<Integer>& x, Player& P);
template <class T>
class SubProcessor
{
CheckVector<typename T::clear> C;
CheckVector<T> S;
StackedVector<typename T::clear> C;
StackedVector<T> S;
DataPositions bit_usage;
typename T::Protocol::Shuffler shuffler;
NamedStats stats;
void resize(size_t size) { C.resize(size); S.resize(size); }
@@ -62,6 +66,8 @@ public:
typename BT::LivePrep bit_prep;
vector<typename BT::LivePrep*> personal_bit_preps;
typename T::Protocol::Shuffler shuffler;
SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC,
Preprocessing<T>& DataF, Player& P);
SubProcessor(typename T::MAC_Check& MC, Preprocessing<T>& DataF, Player& P,
@@ -76,8 +82,8 @@ public:
void muls(const vector<int>& reg);
void mulrs(const vector<int>& reg);
void dotprods(const vector<int>& reg, int size);
void matmuls(const vector<T>& source, const Instruction& instruction);
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction);
void matmuls(const StackedVector<T>& source, const Instruction& instruction);
void matmulsm(const MemoryPart<T>& source, const vector<int>& args);
void matmulsm_finalize_batch(vector<int>::const_iterator startMatmul, int startI, int startJ,
vector<int>::const_iterator endMatmul,
@@ -96,12 +102,12 @@ public:
void send_personal(const vector<int>& args);
void private_output(const vector<int>& args);
CheckVector<T>& get_S()
StackedVector<T>& get_S()
{
return S;
}
CheckVector<typename T::clear>& get_C()
StackedVector<typename T::clear>& get_C()
{
return C;
}
@@ -116,13 +122,17 @@ public:
return C[i];
}
void inverse_permutation(const Instruction &instruction, int handle);
void inverse_permutation(const Instruction &instruction, int handle);
void push_stack();
void push_args(const vector<int>& args);
void pop_stack(const vector<int>& results);
};
class ArithmeticProcessor : public ProcessorBase
{
protected:
CheckVector<Integer> Ci;
StackedVector<Integer> Ci;
ofstream public_output;
ofstream binary_output;
@@ -174,7 +184,7 @@ public:
{ return Ci[i]; }
void write_Ci(size_t i, const long& x)
{ Ci[i]=x; }
CheckVector<Integer>& get_Ci()
StackedVector<Integer>& get_Ci()
{ return Ci; }
virtual ofstream& get_public_output()
@@ -292,6 +302,8 @@ class Processor : public ArithmeticProcessor
ofstream& get_public_output();
ofstream& get_binary_output();
void call_tape(int tape_number, int arg, const vector<int>& results);
private:
template<class T> friend class SPDZ;

View File

@@ -25,9 +25,8 @@ SubProcessor<T>::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check&
template <class T>
SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
Preprocessing<T>& DataF, Player& P, ArithmeticProcessor* Proc) :
shuffler(*this),
Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC),
bit_prep(bit_usage)
bit_prep(bit_usage), shuffler(*this)
{
DataF.set_proc(this);
protocol.init(DataF, MC);
@@ -113,7 +112,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
secure_prng.ReSeed();
shared_prng.SeedGlobally(P, false);
setup_redirection(P.my_num(), thread_num, opts, out);
setup_redirection(P.my_num(), thread_num, opts, out, sint::real_shares(P));
Procb.out = out;
}
@@ -158,6 +157,7 @@ void Processor<sint, sgf2n>::reset(const Program& program,int arg)
Procp.get_S().resize(program.num_reg(SINT));
Procp.get_C().resize(program.num_reg(CINT));
Ci.resize(program.num_reg(INT));
this->arg = arg;
Procb.reset(program);
}
@@ -209,17 +209,6 @@ void Processor<sint, sgf2n>::edabit(const Instruction& instruction, bool strict)
&Procp.get_S_ref(instruction.get_r(0)), Procb.S, regs);
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::convcbitvec(const Instruction& instruction)
{
for (size_t i = 0; i < instruction.get_n(); i++)
{
int i1 = i / GC::Clear::N_BITS;
int i2 = i % GC::Clear::N_BITS;
Ci[instruction.get_r(0) + i] = Procb.C[instruction.get_r(1) + i1].get_bit(i2);
}
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::convcintvec(const Instruction& instruction)
{
@@ -314,10 +303,9 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
}
}
#ifdef VERBOSE_COMM
cerr << "send " << socket_stream.get_length() << " to client " << socket_id
<< endl;
#endif
if (OnlineOptions::singleton.has_option("verbose_comm"))
fprintf(stderr, "Send %zu bytes to client %d\n", socket_stream.get_length(),
socket_id);
try {
TimeScope _(client_stats.add(socket_stream.get_length()));
@@ -362,7 +350,10 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id,
for (int j = 0; j < size; j++)
for (int i = 0; i < m; i++)
get_Cp_ref(registers[i] + j) =
socket_stream.get<typename sint::open_type>();
socket_stream.get<typename sint::share_type::open_type>();
if (socket_stream.left())
throw runtime_error("unexpected data");
}
// Receive vector of field element shares over private channel
@@ -562,7 +553,7 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
}
template<class T>
void SubProcessor<T>::matmuls(const vector<T>& source,
void SubProcessor<T>::matmuls(const StackedVector<T>& source,
const Instruction& instruction)
{
protocol.init_dotprod();
@@ -604,12 +595,10 @@ void SubProcessor<T>::matmuls(const vector<T>& source,
template<class T>
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
const Instruction& instruction)
const vector<int>& start)
{
assert(Proc);
auto& start = instruction.get_start();
auto batchStartMatrix = start.begin();
int batchStartI = 0;
int batchStartJ = 0;
@@ -816,7 +805,7 @@ Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
}
template<class T>
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
void Conv2dTuple::pre(StackedVector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
@@ -857,7 +846,7 @@ void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
}
template<class T>
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
void Conv2dTuple::post(StackedVector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
@@ -1049,17 +1038,85 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
template<class sint, class sgf2n>
long Processor<sint, sgf2n>::sync(long x) const
{
vector<Integer> tmp = {x};
::sync<sint>(tmp, P);
return tmp[0].get();
}
template<class sint>
void sync(vector<Integer>& x, Player& P)
{
if (not sint::symmetric)
{
octetStream os;
// send number to dealer
if (P.my_num() == 0)
P.send_long(P.num_players() - 1, x);
{
os.store(x);
P.send_to(P.num_players() - 1, os);
}
if (not sint::real_shares(P))
return P.receive_long(0);
{
P.receive_player(0, os);
os.get(x);
}
}
}
return x;
template<class T>
void SubProcessor<T>::push_stack()
{
S.push_stack();
C.push_stack();
}
template<class T>
void SubProcessor<T>::push_args(const vector<int>& args)
{
auto char2 = T::clear::characteristic_two;
S.push_args(args, char2 ? SGF2N : SINT);
C.push_args(args, char2 ? CGF2N : CINT);
}
template<class T>
void SubProcessor<T>::pop_stack(const vector<int>& results)
{
auto char2 = T::clear::characteristic_two;
S.pop_stack(results, char2 ? SGF2N : SINT);
C.pop_stack(results, char2 ? CGF2N : CINT);
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::call_tape(int tape_number, int arg,
const vector<int>& args)
{
PC_stack.push_back(PC);
arg_stack.push_back(this->arg);
Procp.push_stack();
Proc2.push_stack();
Procb.push_stack();
Ci.push_stack();
auto& tape = machine.progs.at(tape_number);
reset(tape, arg);
Procp.push_args(args);
Proc2.push_args(args);
Procb.push_args(args);
Ci.push_args(args, INT);
tape.execute(*this);
Procp.pop_stack(args);
Proc2.pop_stack(args);
Procb.pop_stack(args);
Ci.pop_stack(args, INT);
PC = PC_stack.back();
PC_stack.pop_back();
this->arg = arg_stack.back();
arg_stack.pop_back();
}
#endif

View File

@@ -27,11 +27,12 @@ void ProcessorBase::open_input_file(int my_num, int thread_num,
}
void ProcessorBase::setup_redirection(int my_num, int thread_num,
OnlineOptions& opts, SwitchableOutput& out)
OnlineOptions& opts, SwitchableOutput& out, bool real)
{
// only output on party 0 if not interactive
bool always_stdout = opts.cmd_private_output_file == ".";
bool output = my_num == 0 or opts.interactive or always_stdout;
output &= real;
out.activate(output);
if (not (opts.cmd_private_output_file.empty() or always_stdout))

View File

@@ -28,6 +28,8 @@ class ProcessorBase
protected:
// Optional argument to tape
Integer arg;
vector<Integer> arg_stack;
vector<int> PC_stack;
string get_parameterized_filename(int my_num, int thread_num,
const string& prefix);
@@ -61,7 +63,7 @@ public:
T get_input(istream& is, const string& input_filename, const int* params);
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts,
SwitchableOutput& out);
SwitchableOutput& out, bool real = true);
};
#endif /* PROCESSOR_PROCESSORBASE_H_ */

View File

@@ -28,6 +28,24 @@ void Program::compute_constants()
}
void Program::parse(string filename)
{
if (OnlineOptions::singleton.has_option("throw_exceptions"))
parse_with_error(filename);
else
{
try
{
parse_with_error(filename);
}
catch(exception& e)
{
cerr << "Error in bytecode: " << e.what() << endl;
exit(1);
}
}
}
void Program::parse_with_error(string filename)
{
ifstream pinp(filename);
if (pinp.fail())

View File

@@ -42,6 +42,7 @@ class Program
// Read in a program
void parse(string filename);
void parse_with_error(string filename);
void parse(istream& s);
DataPositions get_offline_data_used() const { return offline_data_used; }

View File

@@ -30,9 +30,14 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
int RingOptions::ring_size_from_opts_or_schedule(string progname)
{
if (R_is_set)
return R;
int r = BaseMachine::ring_size_from_schedule(progname);
if (R_is_set)
{
if (r and r != R)
cerr << "Different -R option in compilation and run-time: " << r
<< " vs " << R << endl;
return R;
}
if (r == 0)
r = R;
cerr << "Trying to run " << r << "-bit computation" << endl;

View File

@@ -33,10 +33,12 @@ void ThreadQueue::finished(const ThreadJob& job)
out.push(job);
}
void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats)
void ThreadQueue::finished(const ThreadJob& job,
const NamedCommStats& new_comm_stats, const NamedStats& stats)
{
finished(job);
set_comm_stats(new_comm_stats);
this->stats = stats;
}
void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)

View File

@@ -7,6 +7,7 @@
#define PROCESSOR_THREADQUEUE_H_
#include "ThreadJob.h"
#include "Tools/NamedStats.h"
class ThreadQueue
{
@@ -20,6 +21,7 @@ public:
map<string, TimerWithComm> timers;
Timer wait_timer;
NamedStats stats;
ThreadQueue() :
left(0)
@@ -34,7 +36,8 @@ public:
void schedule(const ThreadJob& job);
ThreadJob next();
void finished(const ThreadJob& job);
void finished(const ThreadJob& job, const NamedCommStats& comm_stats);
void finished(const ThreadJob& job, const NamedCommStats& comm_stats,
const NamedStats& stats = {});
ThreadJob result();
void set_comm_stats(const NamedCommStats& new_comm_stats);