mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Maintenance.
This commit is contained in:
@@ -70,6 +70,29 @@ int BaseMachine::bucket_size(size_t usage)
|
||||
return res;
|
||||
}
|
||||
|
||||
int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols)
|
||||
{
|
||||
unsigned res = min(100, OnlineOptions::singleton.batch_size);
|
||||
if (has_program())
|
||||
res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols));
|
||||
return res;
|
||||
}
|
||||
|
||||
int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
|
||||
{
|
||||
if (has_program())
|
||||
{
|
||||
auto res = s().progs[0].get_offline_data_used().matmuls[
|
||||
{n_rows, n_inner, n_cols}];
|
||||
if (res)
|
||||
return res;
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
|
||||
BaseMachine::BaseMachine() : nthreads(0)
|
||||
{
|
||||
if (sodium_init() == -1)
|
||||
@@ -287,7 +310,7 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << P.my_num() << " only";
|
||||
if (nthreads > 1)
|
||||
if (multithread)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
if (not OnlineOptions::singleton.verbose)
|
||||
cerr << "; use '-v' for more details";
|
||||
|
||||
@@ -44,6 +44,7 @@ public:
|
||||
|
||||
string progname;
|
||||
int nthreads;
|
||||
bool multithread;
|
||||
|
||||
ThreadQueues queues;
|
||||
|
||||
@@ -66,10 +67,14 @@ public:
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
template<class T>
|
||||
static int input_batch_size(int player, int buffer_size = 0);
|
||||
template<class T>
|
||||
static int edabit_batch_size(int n_bits, int buffer_size = 0);
|
||||
static int edabit_bucket_size(int n_bits);
|
||||
static int triple_bucket_size(DataFieldType type);
|
||||
static int bucket_size(size_t usage);
|
||||
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
|
||||
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
@@ -105,6 +110,10 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
template<class T>
|
||||
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
|
||||
fallback);
|
||||
|
||||
int n_opts;
|
||||
int n = 0;
|
||||
int res = 0;
|
||||
@@ -114,7 +123,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
else if (fallback > 0)
|
||||
n_opts = fallback;
|
||||
else
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
|
||||
|
||||
if (buffer_size <= 0 and has_program())
|
||||
{
|
||||
@@ -132,7 +141,6 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
{
|
||||
n = buffer_size;
|
||||
buffer_size = 0;
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
}
|
||||
|
||||
if (n > 0 and not (buffer_size > 0))
|
||||
@@ -161,16 +169,33 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
else
|
||||
res = n_opts;
|
||||
|
||||
#ifdef DEBUG_BATCH_SIZE
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n="
|
||||
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
|
||||
<< " buffer_size=" << buffer_size << endl;
|
||||
|
||||
assert(res > 0);
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::input_batch_size(int player, int buffer_size)
|
||||
{
|
||||
if (buffer_size)
|
||||
return buffer_size;
|
||||
|
||||
if (has_program())
|
||||
{
|
||||
auto res =
|
||||
s().progs[0].get_offline_data_used(
|
||||
).inputs[player][T::clear::field_type()];
|
||||
if (res > 0)
|
||||
return res;
|
||||
}
|
||||
|
||||
return OnlineOptions::singleton.batch_size;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
|
||||
{
|
||||
|
||||
@@ -29,10 +29,12 @@ public:
|
||||
|
||||
Conv2dTuple(const vector<int>& args, int start);
|
||||
|
||||
array<int, 3> matrix_dimensions();
|
||||
|
||||
template<class T>
|
||||
void pre(vector<T>& S, typename T::Protocol& protocol);
|
||||
void pre(StackedVector<T>& S, typename T::Protocol& protocol);
|
||||
template<class T>
|
||||
void post(vector<T>& S, typename T::Protocol& protocol);
|
||||
void post(StackedVector<T>& S, typename T::Protocol& protocol);
|
||||
|
||||
template<class T>
|
||||
void run_matrix(SubProcessor<T>& processor);
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
#include "Networking/Player.h"
|
||||
#include "Protocols/edabit.h"
|
||||
#include "PrepBase.h"
|
||||
#include "PrepBuffer.h"
|
||||
#include "EdabitBuffer.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
@@ -104,8 +106,6 @@ class Preprocessing : public PrepBase
|
||||
protected:
|
||||
static const bool use_part = false;
|
||||
|
||||
DataPositions& usage;
|
||||
|
||||
bool do_count;
|
||||
|
||||
void count(Dtype dtype, int n = 1)
|
||||
@@ -115,9 +115,9 @@ protected:
|
||||
|
||||
template<int>
|
||||
void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
|
||||
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
|
||||
template<int>
|
||||
void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
|
||||
void get_edabits(bool, size_t, T*, StackedVector<typename T::bit_type>&,
|
||||
const vector<int>&, true_type)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
@@ -126,6 +126,8 @@ protected:
|
||||
T get_random_from_inputs(int nplayers);
|
||||
|
||||
public:
|
||||
int buffer_size;
|
||||
|
||||
template<class U, class V>
|
||||
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
|
||||
SubProcessor<T>* proc);
|
||||
@@ -135,7 +137,8 @@ public:
|
||||
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
|
||||
DataPositions& usage);
|
||||
|
||||
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
|
||||
Preprocessing(DataPositions& usage) :
|
||||
PrepBase(usage), do_count(true), buffer_size(0) {}
|
||||
virtual ~Preprocessing() {}
|
||||
|
||||
virtual void set_protocol(typename T::Protocol&) {};
|
||||
@@ -151,7 +154,7 @@ public:
|
||||
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
|
||||
virtual void get_input_no_count(T&, typename T::open_type&, int)
|
||||
{ throw not_implemented() ; }
|
||||
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
|
||||
virtual void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void get(Dtype dtype, T* a);
|
||||
@@ -159,7 +162,7 @@ public:
|
||||
void get_two(Dtype dtype, T& a, T& b);
|
||||
void get_one(Dtype dtype, T& a);
|
||||
void get_input(T& a, typename T::open_type& x, int i);
|
||||
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
|
||||
/// Get fresh random multiplication triple
|
||||
virtual array<T, 3> get_triple(int n_bits);
|
||||
@@ -174,7 +177,7 @@ public:
|
||||
virtual void get_dabit(T& a, typename T::bit_type& b);
|
||||
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
|
||||
virtual void get_edabit_no_count(bool, int, edabit<T>&)
|
||||
{ throw runtime_error("no edaBits"); }
|
||||
@@ -201,11 +204,11 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
|
||||
static int tuple_length(int dtype);
|
||||
|
||||
BufferOwner<T, T> buffers[N_DTYPE];
|
||||
vector<BufferOwner<T, T>> input_buffers;
|
||||
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
|
||||
map<DataTag, BufferOwner<T, T> > extended;
|
||||
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
|
||||
array<PrepBuffer<T>, N_DTYPE> buffers;
|
||||
vector<PrepBuffer<T>> input_buffers;
|
||||
PrepBuffer<InputTuple<T>, RefInputTuple<T>, T> my_input_buffers;
|
||||
map<DataTag, PrepBuffer<T> > extended;
|
||||
PrepBuffer<dabit<T>, dabit<T>, T> dabit_buffer;
|
||||
map<int, EdabitBuffer<T>> edabit_buffers;
|
||||
map<int, edabitvec<T>> my_edabits;
|
||||
|
||||
@@ -284,7 +287,7 @@ public:
|
||||
}
|
||||
|
||||
void setup_extended(const DataTag& tag, int tuple_size = 0);
|
||||
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b);
|
||||
|
||||
part_type& get_part();
|
||||
@@ -397,7 +400,7 @@ inline void Preprocessing<T>::get_input(T& a, typename T::open_type& x, int i)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void Preprocessing<T>::get(vector<T>& S, DataTag tag,
|
||||
inline void Preprocessing<T>::get(StackedVector<T>& S, DataTag tag,
|
||||
const vector<int>& regs, int vector_size)
|
||||
{
|
||||
usage.count(T::clear::field_type(), tag, vector_size);
|
||||
|
||||
@@ -143,14 +143,14 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
{
|
||||
if (T::clear::allows(Dtype(dtype)))
|
||||
{
|
||||
buffers[dtype].setup(
|
||||
buffers[dtype].setup(num_players,
|
||||
PrepBase::get_filename(prep_data_dir, Dtype(dtype), type_short,
|
||||
my_num, thread_num), tuple_length(dtype), type_string,
|
||||
DataPositions::dtype_names[dtype]);
|
||||
}
|
||||
}
|
||||
|
||||
dabit_buffer.setup(
|
||||
dabit_buffer.setup(num_players,
|
||||
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
|
||||
type_short, my_num, thread_num), dabit<T>::size(), type_string,
|
||||
DataPositions::dtype_names[DATA_DABIT]);
|
||||
@@ -161,10 +161,10 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
string filename = PrepBase::get_input_filename(prep_data_dir,
|
||||
type_short, i, my_num, thread_num);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(filename,
|
||||
my_input_buffers.setup(num_players, filename,
|
||||
InputTuple<T>::size(), type_string);
|
||||
else
|
||||
input_buffers[i].setup(filename,
|
||||
input_buffers[i].setup(num_players, filename,
|
||||
T::size(), type_string);
|
||||
}
|
||||
|
||||
@@ -344,14 +344,14 @@ void Sub_Data_Files<T>::setup_extended(const DataTag& tag, int tuple_size)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num;
|
||||
buffer.setup(ss.str(), tuple_length);
|
||||
buffer.setup(num_players, ss.str(), tuple_length);
|
||||
}
|
||||
|
||||
buffer.check_tuple_length(tuple_length);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
|
||||
void Sub_Data_Files<T>::get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
|
||||
{
|
||||
setup_extended(tag, regs.size());
|
||||
for (int j = 0; j < vector_size; j++)
|
||||
|
||||
@@ -12,6 +12,7 @@ using namespace std;
|
||||
#include "Math/BitVec.h"
|
||||
#include "Data_Files.h"
|
||||
#include "Protocols/Replicated.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "Protocols/MAC_Check_Base.h"
|
||||
#include "Processor/Input.h"
|
||||
|
||||
@@ -109,7 +110,7 @@ public:
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class DummyLivePrep : public Preprocessing<T>
|
||||
class DummyLivePrep : public BufferPrep<T>
|
||||
{
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
@@ -133,16 +134,16 @@ public:
|
||||
}
|
||||
|
||||
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
DummyLivePrep(DataPositions& usage, bool = true) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -165,7 +166,7 @@ public:
|
||||
{
|
||||
fail();
|
||||
}
|
||||
void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
|
||||
void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
|
||||
{
|
||||
fail();
|
||||
}
|
||||
|
||||
@@ -81,7 +81,6 @@ int ExternalClients::init_client_connection(const string& host, int portnum,
|
||||
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
|
||||
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
|
||||
true);
|
||||
if (party_num == 0)
|
||||
{
|
||||
octetStream specification;
|
||||
specification.Receive(socket);
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include <iomanip>
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
void Instruction::execute_clear_gf2n(StackedVector<cgf2n>& registers,
|
||||
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
{
|
||||
auto& C2 = registers;
|
||||
@@ -30,7 +30,7 @@ void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
}
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::gbitdec(vector<cgf2n>& registers) const
|
||||
void Instruction::gbitdec(StackedVector<cgf2n>& registers) const
|
||||
{
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
@@ -44,7 +44,7 @@ void Instruction::gbitdec(vector<cgf2n>& registers) const
|
||||
}
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::gbitcom(vector<cgf2n>& registers) const
|
||||
void Instruction::gbitcom(StackedVector<cgf2n>& registers) const
|
||||
{
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
@@ -124,7 +124,7 @@ ostream& operator<<(ostream& s, const Instruction& instr)
|
||||
return s;
|
||||
}
|
||||
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
|
||||
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_short>& registers,
|
||||
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
|
||||
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_long>& registers,
|
||||
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
|
||||
@@ -15,6 +15,7 @@ template<class sint, class sgf2n> class Machine;
|
||||
template<class sint, class sgf2n> class Processor;
|
||||
template<class T> class SubProcessor;
|
||||
template<class T> class MemoryPart;
|
||||
template<class T> class StackedVector;
|
||||
class ArithmeticProcessor;
|
||||
class SwitchableOutput;
|
||||
|
||||
@@ -73,6 +74,8 @@ enum
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
CALL_TAPE = 0xEC,
|
||||
CALL_ARG = 0xED,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -90,6 +93,7 @@ enum
|
||||
PREFIXSUMS = 0x2D,
|
||||
PICKS = 0x2E,
|
||||
CONCATS = 0x2F,
|
||||
ZIPS = 0x3F,
|
||||
// Multiplication/division/other arithmetic
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -351,6 +355,7 @@ protected:
|
||||
string str;
|
||||
|
||||
public:
|
||||
BaseInstruction() : opcode(0), size(0), n(0) {}
|
||||
virtual ~BaseInstruction() {};
|
||||
|
||||
int get_r(int i) const { return r[i]; }
|
||||
@@ -391,13 +396,13 @@ public:
|
||||
void execute(Processor<sint, sgf2n>& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
void execute_clear_gf2n(vector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
|
||||
void execute_clear_gf2n(StackedVector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
|
||||
ArithmeticProcessor& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
void gbitdec(vector<cgf2n>& registers) const;
|
||||
void gbitdec(StackedVector<cgf2n>& registers) const;
|
||||
template<class cgf2n>
|
||||
void gbitcom(vector<cgf2n>& registers) const;
|
||||
void gbitcom(StackedVector<cgf2n>& registers) const;
|
||||
|
||||
void execute_regint(ArithmeticProcessor& Proc, MemoryPart<Integer>& Mi) const;
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case DIVINT:
|
||||
case CONDPRINTPLAIN:
|
||||
case INPUTMASKREG:
|
||||
case ZIPS:
|
||||
get_ints(r, s, 3);
|
||||
break;
|
||||
// instructions with 2 register operands
|
||||
@@ -245,6 +246,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case CONDPRINTSTRB:
|
||||
case RANDOMS:
|
||||
case GENSECSHUFFLE:
|
||||
case CALL_ARG:
|
||||
r[0]=get_int(s);
|
||||
n = get_int(s);
|
||||
break;
|
||||
@@ -330,6 +332,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
// read from file, input is opcode num_args,
|
||||
// start_file_posn (read), end_file_posn(write) var1, var2, ...
|
||||
case READFILESHARE:
|
||||
case CALL_TAPE:
|
||||
num_var_args = get_int(s) - 2;
|
||||
r[0] = get_int(s);
|
||||
r[1] = get_int(s);
|
||||
@@ -588,6 +591,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
case GENSECSHUFFLE:
|
||||
case CMDLINEARG:
|
||||
case CALL_TAPE:
|
||||
return INT;
|
||||
case PREP:
|
||||
case GPREP:
|
||||
@@ -596,7 +600,6 @@ int BaseInstruction::get_reg_type() const
|
||||
case USE_EDABIT:
|
||||
case USE_MATMUL:
|
||||
case RUN_TAPE:
|
||||
case CISC:
|
||||
// those use r[] not for registers
|
||||
return NONE;
|
||||
case LDI:
|
||||
@@ -639,6 +642,8 @@ int BaseInstruction::get_reg_type() const
|
||||
case PRIVATEOUTPUT:
|
||||
case FIXINPUT:
|
||||
return CINT;
|
||||
case CALL_ARG:
|
||||
return n;
|
||||
default:
|
||||
if (is_gf2n_instruction())
|
||||
{
|
||||
@@ -706,6 +711,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
case CALL_TAPE:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 5)
|
||||
if (it[1] == reg_type)
|
||||
res = max(res, (*it ? it[3] : it[4]) + it[2]);
|
||||
return res;
|
||||
}
|
||||
default:
|
||||
if (get_reg_type() != reg_type)
|
||||
return 0;
|
||||
@@ -713,6 +726,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
|
||||
switch (opcode)
|
||||
{
|
||||
case CISC:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += *it)
|
||||
{
|
||||
assert(it + *it <= start.end());
|
||||
res = max(res, it[1] + it[2]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case MULS:
|
||||
skip = 4;
|
||||
offset = 1;
|
||||
size_offset = -1;
|
||||
break;
|
||||
case DOTPRODS:
|
||||
{
|
||||
int res = 0;
|
||||
@@ -737,7 +765,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return res;
|
||||
}
|
||||
case MATMULSM:
|
||||
return r[0] + start[0] * start[2];
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 12)
|
||||
{
|
||||
res = max(res, *it + *(it + 3) * *(it + 5));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case CONV2DS:
|
||||
{
|
||||
unsigned res = 0;
|
||||
@@ -956,6 +991,17 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
}
|
||||
return;
|
||||
}
|
||||
case ZIPS:
|
||||
{
|
||||
auto& S = Proc.Procp.get_S();
|
||||
auto dest = S.begin() + r[0];
|
||||
for (int i = 0; i < get_size(); i++)
|
||||
{
|
||||
*dest++ = S[r[1] + i];
|
||||
*dest++ = S[r[2] + i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
case DIVC:
|
||||
Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2]));
|
||||
break;
|
||||
@@ -1232,6 +1278,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case JOIN_TAPE:
|
||||
Proc.machine.join_tape(r[0]);
|
||||
break;
|
||||
case CALL_TAPE:
|
||||
Proc.call_tape(r[0], Proc.read_Ci(r[1]), start);
|
||||
break;
|
||||
case CRASH:
|
||||
if (Proc.read_Ci(r[0]))
|
||||
throw crash_requested();
|
||||
@@ -1277,7 +1326,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
// get client connection at port number n + my_num())
|
||||
int client_handle = Proc.external_clients.get_client_connection(
|
||||
Proc.read_Ci(r[1]));
|
||||
if (Proc.P.my_num() == 0)
|
||||
{
|
||||
octetStream os;
|
||||
os.store(int(sint::open_type::type_char()));
|
||||
@@ -1421,7 +1469,8 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
|
||||
#endif
|
||||
|
||||
#ifdef OUTPUT_INSTRUCTIONS
|
||||
cerr << instruction << endl;
|
||||
if (OnlineOptions::singleton.has_option("output_instructions"))
|
||||
cerr << instruction << endl;
|
||||
#endif
|
||||
|
||||
Proc.PC++;
|
||||
@@ -1461,7 +1510,7 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
if (p == 0 or (*p == 0 and s == 0))
|
||||
out << v[i];
|
||||
out.signed_output(v[i]);
|
||||
else if (s == 0)
|
||||
out << bigint::get_float(v[i], p[i], {}, {});
|
||||
else
|
||||
|
||||
@@ -235,6 +235,7 @@ size_t Machine<sint, sgf2n>::load_program(const string& threadname,
|
||||
M2.minimum_size(SGF2N, CGF2N, progs[i], threadname);
|
||||
Mp.minimum_size(SINT, CINT, progs[i], threadname);
|
||||
Mi.minimum_size(NONE, INT, progs[i], threadname);
|
||||
bit_memories.reset(progs[i]);
|
||||
return progs.back().size();
|
||||
}
|
||||
|
||||
@@ -340,14 +341,15 @@ void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
|
||||
auto subdim = it->first;
|
||||
subdim[1] = min(subdim[1] - j, max_inner);
|
||||
subdim[2] = min(subdim[2] - k, max_cols);
|
||||
auto& source =
|
||||
dynamic_cast<Hemi<sint>&>(source_proc.protocol).get_matrix_prep(
|
||||
auto& source_proto = dynamic_cast<Hemi<sint>&>(source_proc.protocol);
|
||||
auto& source = source_proto.get_matrix_prep(
|
||||
subdim, source_proc);
|
||||
auto& dest =
|
||||
dynamic_cast<Hemi<sint>&>(tinfo[thread_number].processor->Procp.protocol).get_matrix_prep(
|
||||
subdim, tinfo[thread_number].processor->Procp);
|
||||
for (int i = 0; i < it->second; i++)
|
||||
dest.push_triple(source.get_triple_no_count(-1));
|
||||
if (not source_proto.use_plain_matmul(subdim, source_proc))
|
||||
for (int i = 0; i < it->second; i++)
|
||||
dest.push_triple(source.get_triple_no_count(-1));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -434,7 +436,13 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
auto comm_stats = total_comm();
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
queues.print_breakdown();
|
||||
{
|
||||
NamedStats total;
|
||||
for (auto queue : queues)
|
||||
total += queue->stats;
|
||||
total.print();
|
||||
queues.print_breakdown();
|
||||
}
|
||||
|
||||
for (auto& queue : queues)
|
||||
delete queue;
|
||||
@@ -464,7 +472,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
finish_timer.start();
|
||||
|
||||
// actual usage
|
||||
bool multithread = nthreads > 1;
|
||||
multithread = nthreads > 1;
|
||||
auto res = stop_threads();
|
||||
DataPositions& pos = res.first;
|
||||
|
||||
@@ -518,14 +526,15 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
cerr << "Full broadcast" << endl;
|
||||
#endif
|
||||
|
||||
#ifdef CHOP_MEMORY
|
||||
// Reduce memory size to speed up
|
||||
unsigned max_size = 1 << 20;
|
||||
if (M2.size_s() > max_size)
|
||||
M2.resize_s(max_size);
|
||||
if (Mp.size_s() > max_size)
|
||||
Mp.resize_s(max_size);
|
||||
#endif
|
||||
if (not OnlineOptions::singleton.has_option("output_full_memory"))
|
||||
{
|
||||
// Reduce memory size to speed up
|
||||
unsigned max_size = 1 << 20;
|
||||
if (M2.size_s() > max_size)
|
||||
M2.resize_s(max_size);
|
||||
if (Mp.size_s() > max_size)
|
||||
Mp.resize_s(max_size);
|
||||
}
|
||||
|
||||
// Write out the memory to use next time
|
||||
ofstream outf(memory_filename(), ios::out | ios::binary);
|
||||
|
||||
@@ -44,10 +44,10 @@ public:
|
||||
virtual const T& at(size_t i) const = 0;
|
||||
|
||||
template<class U>
|
||||
void indirect_read(const Instruction& inst, vector<T>& regs,
|
||||
void indirect_read(const Instruction& inst, StackedVector<T>& regs,
|
||||
const U& indices);
|
||||
template<class U>
|
||||
void indirect_write(const Instruction& inst, vector<T>& regs,
|
||||
void indirect_write(const Instruction& inst, StackedVector<T>& regs,
|
||||
const U& indices);
|
||||
|
||||
void minimum_size(size_t size);
|
||||
|
||||
@@ -6,21 +6,21 @@
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_read(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
StackedVector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto dest = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
#ifndef NO_CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(dest + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
size_t size = this->size();
|
||||
const T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
if (size_t(it->get()) >= size)
|
||||
throw overflow(T::type_string() + " memory read", it->get(), size);
|
||||
#endif
|
||||
*dest++ = data[it->get()];
|
||||
@@ -30,21 +30,21 @@ void MemoryPart<T>::indirect_read(const Instruction& inst,
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_write(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
StackedVector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto source = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
#ifndef NO_CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(source + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
size_t size = this->size();
|
||||
T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
if (size_t(it->get()) >= size)
|
||||
throw overflow(T::type_string() + " memory write", it->get(), size);
|
||||
#endif
|
||||
data[it->get()] = *source++;
|
||||
|
||||
@@ -33,6 +33,7 @@ class thread_info
|
||||
const char* name);
|
||||
|
||||
void Sub_Main_Func();
|
||||
void Main_Func_With_Purge();
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -340,6 +340,13 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
cerr << endl;
|
||||
#endif
|
||||
|
||||
if (num == 0 and OnlineOptions::singleton.verbose
|
||||
and machine.queues.size() > 1)
|
||||
{
|
||||
cerr << "Main thread communication:" << endl;
|
||||
P.total_comm().print();
|
||||
}
|
||||
|
||||
// wind down thread by thread
|
||||
machine.stats += Proc.stats;
|
||||
queues->timers["wait"] = wait_timer + queues->wait_timer;
|
||||
@@ -347,11 +354,43 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
|
||||
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
|
||||
|
||||
NamedStats stats;
|
||||
stats["integer multiplications"] = Proc.Procp.protocol.counter;
|
||||
stats["integer multiplication rounds"] = Proc.Procp.protocol.rounds;
|
||||
stats["probabilistic truncations"] = Proc.Procp.protocol.trunc_pr_counter;
|
||||
stats["probabilistic truncation rounds"] = Proc.Procp.protocol.trunc_rounds;
|
||||
stats["ANDs"] = Proc.share_thread.protocol->bit_counter;
|
||||
stats["AND rounds"] = Proc.share_thread.protocol->rounds;
|
||||
stats["integer openings"] = MCp->values_opened;
|
||||
stats["integer inputs"] = Proc.Procp.input.values_input;
|
||||
for (auto x : Proc.Procp.shuffler.stats)
|
||||
stats["shuffles of length " + to_string(x.first)] = x.second;
|
||||
|
||||
try
|
||||
{
|
||||
auto proc = dynamic_cast<RingPrep<sint>&>(Proc.DataF.DataFp).bit_part_proc;
|
||||
if (proc)
|
||||
stats["ANDs in preprocessing"] = proc->protocol.bit_counter;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto protocol = dynamic_cast<BitPrep<sint>&>(Proc.DataF.DataFp).protocol;
|
||||
if (protocol)
|
||||
stats["integer multiplications in preprocessing"] = protocol->counter;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
}
|
||||
|
||||
// prevent faulty usage message
|
||||
Proc.DataF.set_usage(actual_usage);
|
||||
delete processor;
|
||||
|
||||
queues->finished(actual_usage, P.total_comm());
|
||||
queues->finished(actual_usage, P.total_comm(), stats);
|
||||
|
||||
delete MC2;
|
||||
delete MCp;
|
||||
@@ -367,6 +406,27 @@ template<class sint, class sgf2n>
|
||||
void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
|
||||
{
|
||||
auto& ti = *(thread_info<sint, sgf2n>*)(ptr);
|
||||
if (OnlineOptions::singleton.has_option("throw_exceptions"))
|
||||
ti.Main_Func_With_Purge();
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
ti.Main_Func_With_Purge();
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
cerr << "Fatal error: " << e.what() << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void thread_info<sint, sgf2n>::Main_Func_With_Purge()
|
||||
{
|
||||
auto& ti = *this;
|
||||
#ifdef INSECURE
|
||||
ti.Sub_Main_Func();
|
||||
#else
|
||||
@@ -383,12 +443,10 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
|
||||
ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num);
|
||||
purge_preprocessing(machine->get_N(), thread_num);
|
||||
throw;
|
||||
}
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ public:
|
||||
|
||||
template<class T, class U>
|
||||
int run();
|
||||
template<class T, class U>
|
||||
int run_with_error();
|
||||
|
||||
Player* new_player(const string& id_base);
|
||||
|
||||
|
||||
@@ -173,6 +173,25 @@ Player* OnlineMachine::new_player(const string& id_base)
|
||||
|
||||
template<class T, class U>
|
||||
int OnlineMachine::run()
|
||||
{
|
||||
if (online_opts.has_option("throw_exception"))
|
||||
return run_with_error<T, U>();
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
return run_with_error<T, U>();
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
cerr << "Fatal error: " << e.what() << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
int OnlineMachine::run_with_error()
|
||||
{
|
||||
#ifndef INSECURE
|
||||
try
|
||||
|
||||
@@ -112,6 +112,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-B", // Flag token.
|
||||
"--bucket-size" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
-1, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Further options", // Help description.
|
||||
"-o", // Flag token.
|
||||
"--options" // Flag token.
|
||||
);
|
||||
|
||||
if (security)
|
||||
opt.add(
|
||||
@@ -138,6 +147,12 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
verbose = opt.isSet("--verbose");
|
||||
#endif
|
||||
|
||||
opt.get("--options")->getStrings(options);
|
||||
|
||||
#ifdef THROW_EXCEPTIONS
|
||||
options.push_back("throw_exceptions");
|
||||
#endif
|
||||
|
||||
if (security)
|
||||
{
|
||||
opt.get("-S")->getInt(security_parameter);
|
||||
|
||||
@@ -37,6 +37,7 @@ public:
|
||||
bool receive_threads;
|
||||
std::string disk_memory;
|
||||
vector<long> args;
|
||||
vector<string> options;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
@@ -67,6 +68,11 @@ public:
|
||||
lgp = numBits(prime);
|
||||
return get_prep_sub_dir<T>(PREP_DIR, nplayers, lgp);
|
||||
}
|
||||
|
||||
bool has_option(const string& option)
|
||||
{
|
||||
return find(options.begin(), options.end(), option) != options.end();
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_ONLINEOPTIONS_H_ */
|
||||
|
||||
@@ -40,6 +40,11 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir,
|
||||
+ to_string(my_num) + get_suffix(thread_num);
|
||||
}
|
||||
|
||||
PrepBase::PrepBase(DataPositions& usage) :
|
||||
usage(usage)
|
||||
{
|
||||
}
|
||||
|
||||
void PrepBase::print_left(const char* name, size_t n, const string& type_string,
|
||||
size_t used, bool large)
|
||||
{
|
||||
@@ -72,7 +77,8 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
cerr << " edaBits of size " << n_bits << " left" << endl;
|
||||
}
|
||||
|
||||
if (n * n_batch > used / 10)
|
||||
if (n * n_batch > used / 10
|
||||
and n * n_batch > size_t(usage.files[DATA_INT][DATA_DABIT]) / 10)
|
||||
{
|
||||
cerr << "Significant amount of unused edaBits of size " << n_bits
|
||||
<< ". ";
|
||||
|
||||
@@ -12,8 +12,13 @@ using namespace std;
|
||||
#include "Math/field_types.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
|
||||
class DataPositions;
|
||||
|
||||
class PrepBase
|
||||
{
|
||||
protected:
|
||||
DataPositions& usage;
|
||||
|
||||
public:
|
||||
static string get_suffix(int thread_num);
|
||||
|
||||
@@ -25,12 +30,15 @@ public:
|
||||
static string get_edabit_filename(const string& prep_data_dir, int n_bits,
|
||||
int my_num, int thread_num = 0);
|
||||
|
||||
static void print_left(const char* name, size_t n,
|
||||
const string& type_string, size_t used, bool large = false);
|
||||
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used, bool malicious);
|
||||
|
||||
TimerWithComm prep_timer;
|
||||
|
||||
PrepBase(DataPositions& usage);
|
||||
|
||||
void print_left(const char* name, size_t n,
|
||||
const string& type_string, size_t used, bool large = false);
|
||||
|
||||
void print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used, bool malicious);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PREPBASE_H_ */
|
||||
|
||||
44
Processor/PrepBuffer.h
Normal file
44
Processor/PrepBuffer.h
Normal file
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
* PrepBuffer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_PREPBUFFER_H_
|
||||
#define PROCESSOR_PREPBUFFER_H_
|
||||
|
||||
#include "Tools/Buffer.h"
|
||||
|
||||
template<class T, class U = T, class V = T>
|
||||
class PrepBuffer : public BufferOwner<T, U, V>
|
||||
{
|
||||
int num_players;
|
||||
string fake_opts;
|
||||
|
||||
public:
|
||||
PrepBuffer() :
|
||||
num_players(0)
|
||||
{
|
||||
}
|
||||
|
||||
void setup(int num_players, const string& filename, int tuple_length,
|
||||
const string& type_string = "", const char* data_type = "")
|
||||
{
|
||||
this->num_players = num_players;
|
||||
fake_opts = V::template proto_fake_opts<typename V::open_type>();
|
||||
BufferOwner<T, U, V>::setup(filename, tuple_length, type_string, data_type);
|
||||
}
|
||||
|
||||
void input(U& a)
|
||||
{
|
||||
try
|
||||
{
|
||||
BufferOwner<T, U, V>::input(a);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
throw prep_setup_error(e.what(), num_players, fake_opts);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PREPBUFFER_H_ */
|
||||
@@ -21,18 +21,22 @@
|
||||
#include "GC/Processor.h"
|
||||
#include "GC/ShareThread.h"
|
||||
#include "Protocols/SecureShuffle.h"
|
||||
#include "Tools/NamedStats.h"
|
||||
|
||||
class Program;
|
||||
|
||||
// synchronize in asymmetric protocols
|
||||
template<class T>
|
||||
void sync(vector<Integer>& x, Player& P);
|
||||
|
||||
template <class T>
|
||||
class SubProcessor
|
||||
{
|
||||
CheckVector<typename T::clear> C;
|
||||
CheckVector<T> S;
|
||||
StackedVector<typename T::clear> C;
|
||||
StackedVector<T> S;
|
||||
|
||||
DataPositions bit_usage;
|
||||
|
||||
typename T::Protocol::Shuffler shuffler;
|
||||
NamedStats stats;
|
||||
|
||||
void resize(size_t size) { C.resize(size); S.resize(size); }
|
||||
|
||||
@@ -62,6 +66,8 @@ public:
|
||||
typename BT::LivePrep bit_prep;
|
||||
vector<typename BT::LivePrep*> personal_bit_preps;
|
||||
|
||||
typename T::Protocol::Shuffler shuffler;
|
||||
|
||||
SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC,
|
||||
Preprocessing<T>& DataF, Player& P);
|
||||
SubProcessor(typename T::MAC_Check& MC, Preprocessing<T>& DataF, Player& P,
|
||||
@@ -76,8 +82,8 @@ public:
|
||||
void muls(const vector<int>& reg);
|
||||
void mulrs(const vector<int>& reg);
|
||||
void dotprods(const vector<int>& reg, int size);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction);
|
||||
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction);
|
||||
void matmuls(const StackedVector<T>& source, const Instruction& instruction);
|
||||
void matmulsm(const MemoryPart<T>& source, const vector<int>& args);
|
||||
|
||||
void matmulsm_finalize_batch(vector<int>::const_iterator startMatmul, int startI, int startJ,
|
||||
vector<int>::const_iterator endMatmul,
|
||||
@@ -96,12 +102,12 @@ public:
|
||||
void send_personal(const vector<int>& args);
|
||||
void private_output(const vector<int>& args);
|
||||
|
||||
CheckVector<T>& get_S()
|
||||
StackedVector<T>& get_S()
|
||||
{
|
||||
return S;
|
||||
}
|
||||
|
||||
CheckVector<typename T::clear>& get_C()
|
||||
StackedVector<typename T::clear>& get_C()
|
||||
{
|
||||
return C;
|
||||
}
|
||||
@@ -116,13 +122,17 @@ public:
|
||||
return C[i];
|
||||
}
|
||||
|
||||
void inverse_permutation(const Instruction &instruction, int handle);
|
||||
void inverse_permutation(const Instruction &instruction, int handle);
|
||||
|
||||
void push_stack();
|
||||
void push_args(const vector<int>& args);
|
||||
void pop_stack(const vector<int>& results);
|
||||
};
|
||||
|
||||
class ArithmeticProcessor : public ProcessorBase
|
||||
{
|
||||
protected:
|
||||
CheckVector<Integer> Ci;
|
||||
StackedVector<Integer> Ci;
|
||||
|
||||
ofstream public_output;
|
||||
ofstream binary_output;
|
||||
@@ -174,7 +184,7 @@ public:
|
||||
{ return Ci[i]; }
|
||||
void write_Ci(size_t i, const long& x)
|
||||
{ Ci[i]=x; }
|
||||
CheckVector<Integer>& get_Ci()
|
||||
StackedVector<Integer>& get_Ci()
|
||||
{ return Ci; }
|
||||
|
||||
virtual ofstream& get_public_output()
|
||||
@@ -292,6 +302,8 @@ class Processor : public ArithmeticProcessor
|
||||
ofstream& get_public_output();
|
||||
ofstream& get_binary_output();
|
||||
|
||||
void call_tape(int tape_number, int arg, const vector<int>& results);
|
||||
|
||||
private:
|
||||
|
||||
template<class T> friend class SPDZ;
|
||||
|
||||
@@ -25,9 +25,8 @@ SubProcessor<T>::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check&
|
||||
template <class T>
|
||||
SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
|
||||
Preprocessing<T>& DataF, Player& P, ArithmeticProcessor* Proc) :
|
||||
shuffler(*this),
|
||||
Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC),
|
||||
bit_prep(bit_usage)
|
||||
bit_prep(bit_usage), shuffler(*this)
|
||||
{
|
||||
DataF.set_proc(this);
|
||||
protocol.init(DataF, MC);
|
||||
@@ -113,7 +112,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
secure_prng.ReSeed();
|
||||
shared_prng.SeedGlobally(P, false);
|
||||
|
||||
setup_redirection(P.my_num(), thread_num, opts, out);
|
||||
setup_redirection(P.my_num(), thread_num, opts, out, sint::real_shares(P));
|
||||
Procb.out = out;
|
||||
}
|
||||
|
||||
@@ -158,6 +157,7 @@ void Processor<sint, sgf2n>::reset(const Program& program,int arg)
|
||||
Procp.get_S().resize(program.num_reg(SINT));
|
||||
Procp.get_C().resize(program.num_reg(CINT));
|
||||
Ci.resize(program.num_reg(INT));
|
||||
|
||||
this->arg = arg;
|
||||
Procb.reset(program);
|
||||
}
|
||||
@@ -209,17 +209,6 @@ void Processor<sint, sgf2n>::edabit(const Instruction& instruction, bool strict)
|
||||
&Procp.get_S_ref(instruction.get_r(0)), Procb.S, regs);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::convcbitvec(const Instruction& instruction)
|
||||
{
|
||||
for (size_t i = 0; i < instruction.get_n(); i++)
|
||||
{
|
||||
int i1 = i / GC::Clear::N_BITS;
|
||||
int i2 = i % GC::Clear::N_BITS;
|
||||
Ci[instruction.get_r(0) + i] = Procb.C[instruction.get_r(1) + i1].get_bit(i2);
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::convcintvec(const Instruction& instruction)
|
||||
{
|
||||
@@ -314,10 +303,9 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "send " << socket_stream.get_length() << " to client " << socket_id
|
||||
<< endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("verbose_comm"))
|
||||
fprintf(stderr, "Send %zu bytes to client %d\n", socket_stream.get_length(),
|
||||
socket_id);
|
||||
|
||||
try {
|
||||
TimeScope _(client_stats.add(socket_stream.get_length()));
|
||||
@@ -362,7 +350,10 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id,
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
get_Cp_ref(registers[i] + j) =
|
||||
socket_stream.get<typename sint::open_type>();
|
||||
socket_stream.get<typename sint::share_type::open_type>();
|
||||
|
||||
if (socket_stream.left())
|
||||
throw runtime_error("unexpected data");
|
||||
}
|
||||
|
||||
// Receive vector of field element shares over private channel
|
||||
@@ -562,7 +553,7 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
void SubProcessor<T>::matmuls(const StackedVector<T>& source,
|
||||
const Instruction& instruction)
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
@@ -604,12 +595,10 @@ void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
const Instruction& instruction)
|
||||
const vector<int>& start)
|
||||
{
|
||||
assert(Proc);
|
||||
|
||||
auto& start = instruction.get_start();
|
||||
|
||||
auto batchStartMatrix = start.begin();
|
||||
int batchStartI = 0;
|
||||
int batchStartJ = 0;
|
||||
@@ -816,7 +805,7 @@ Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
|
||||
void Conv2dTuple::pre(StackedVector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
@@ -857,7 +846,7 @@ void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
|
||||
void Conv2dTuple::post(StackedVector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
@@ -1049,17 +1038,85 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
long Processor<sint, sgf2n>::sync(long x) const
|
||||
{
|
||||
vector<Integer> tmp = {x};
|
||||
::sync<sint>(tmp, P);
|
||||
return tmp[0].get();
|
||||
}
|
||||
|
||||
template<class sint>
|
||||
void sync(vector<Integer>& x, Player& P)
|
||||
{
|
||||
if (not sint::symmetric)
|
||||
{
|
||||
octetStream os;
|
||||
// send number to dealer
|
||||
if (P.my_num() == 0)
|
||||
P.send_long(P.num_players() - 1, x);
|
||||
{
|
||||
os.store(x);
|
||||
P.send_to(P.num_players() - 1, os);
|
||||
}
|
||||
if (not sint::real_shares(P))
|
||||
return P.receive_long(0);
|
||||
{
|
||||
P.receive_player(0, os);
|
||||
os.get(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return x;
|
||||
template<class T>
|
||||
void SubProcessor<T>::push_stack()
|
||||
{
|
||||
S.push_stack();
|
||||
C.push_stack();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::push_args(const vector<int>& args)
|
||||
{
|
||||
auto char2 = T::clear::characteristic_two;
|
||||
S.push_args(args, char2 ? SGF2N : SINT);
|
||||
C.push_args(args, char2 ? CGF2N : CINT);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::pop_stack(const vector<int>& results)
|
||||
{
|
||||
auto char2 = T::clear::characteristic_two;
|
||||
S.pop_stack(results, char2 ? SGF2N : SINT);
|
||||
C.pop_stack(results, char2 ? CGF2N : CINT);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::call_tape(int tape_number, int arg,
|
||||
const vector<int>& args)
|
||||
{
|
||||
PC_stack.push_back(PC);
|
||||
arg_stack.push_back(this->arg);
|
||||
Procp.push_stack();
|
||||
Proc2.push_stack();
|
||||
Procb.push_stack();
|
||||
Ci.push_stack();
|
||||
|
||||
auto& tape = machine.progs.at(tape_number);
|
||||
reset(tape, arg);
|
||||
|
||||
Procp.push_args(args);
|
||||
Proc2.push_args(args);
|
||||
Procb.push_args(args);
|
||||
Ci.push_args(args, INT);
|
||||
|
||||
tape.execute(*this);
|
||||
|
||||
Procp.pop_stack(args);
|
||||
Proc2.pop_stack(args);
|
||||
Procb.pop_stack(args);
|
||||
Ci.pop_stack(args, INT);
|
||||
|
||||
PC = PC_stack.back();
|
||||
PC_stack.pop_back();
|
||||
this->arg = arg_stack.back();
|
||||
arg_stack.pop_back();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -27,11 +27,12 @@ void ProcessorBase::open_input_file(int my_num, int thread_num,
|
||||
}
|
||||
|
||||
void ProcessorBase::setup_redirection(int my_num, int thread_num,
|
||||
OnlineOptions& opts, SwitchableOutput& out)
|
||||
OnlineOptions& opts, SwitchableOutput& out, bool real)
|
||||
{
|
||||
// only output on party 0 if not interactive
|
||||
bool always_stdout = opts.cmd_private_output_file == ".";
|
||||
bool output = my_num == 0 or opts.interactive or always_stdout;
|
||||
output &= real;
|
||||
out.activate(output);
|
||||
|
||||
if (not (opts.cmd_private_output_file.empty() or always_stdout))
|
||||
|
||||
@@ -28,6 +28,8 @@ class ProcessorBase
|
||||
protected:
|
||||
// Optional argument to tape
|
||||
Integer arg;
|
||||
vector<Integer> arg_stack;
|
||||
vector<int> PC_stack;
|
||||
|
||||
string get_parameterized_filename(int my_num, int thread_num,
|
||||
const string& prefix);
|
||||
@@ -61,7 +63,7 @@ public:
|
||||
T get_input(istream& is, const string& input_filename, const int* params);
|
||||
|
||||
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts,
|
||||
SwitchableOutput& out);
|
||||
SwitchableOutput& out, bool real = true);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PROCESSORBASE_H_ */
|
||||
|
||||
@@ -28,6 +28,24 @@ void Program::compute_constants()
|
||||
}
|
||||
|
||||
void Program::parse(string filename)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("throw_exceptions"))
|
||||
parse_with_error(filename);
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
parse_with_error(filename);
|
||||
}
|
||||
catch(exception& e)
|
||||
{
|
||||
cerr << "Error in bytecode: " << e.what() << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Program::parse_with_error(string filename)
|
||||
{
|
||||
ifstream pinp(filename);
|
||||
if (pinp.fail())
|
||||
|
||||
@@ -42,6 +42,7 @@ class Program
|
||||
|
||||
// Read in a program
|
||||
void parse(string filename);
|
||||
void parse_with_error(string filename);
|
||||
void parse(istream& s);
|
||||
|
||||
DataPositions get_offline_data_used() const { return offline_data_used; }
|
||||
|
||||
@@ -30,9 +30,14 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv)
|
||||
|
||||
int RingOptions::ring_size_from_opts_or_schedule(string progname)
|
||||
{
|
||||
if (R_is_set)
|
||||
return R;
|
||||
int r = BaseMachine::ring_size_from_schedule(progname);
|
||||
if (R_is_set)
|
||||
{
|
||||
if (r and r != R)
|
||||
cerr << "Different -R option in compilation and run-time: " << r
|
||||
<< " vs " << R << endl;
|
||||
return R;
|
||||
}
|
||||
if (r == 0)
|
||||
r = R;
|
||||
cerr << "Trying to run " << r << "-bit computation" << endl;
|
||||
|
||||
@@ -33,10 +33,12 @@ void ThreadQueue::finished(const ThreadJob& job)
|
||||
out.push(job);
|
||||
}
|
||||
|
||||
void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats)
|
||||
void ThreadQueue::finished(const ThreadJob& job,
|
||||
const NamedCommStats& new_comm_stats, const NamedStats& stats)
|
||||
{
|
||||
finished(job);
|
||||
set_comm_stats(new_comm_stats);
|
||||
this->stats = stats;
|
||||
}
|
||||
|
||||
void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define PROCESSOR_THREADQUEUE_H_
|
||||
|
||||
#include "ThreadJob.h"
|
||||
#include "Tools/NamedStats.h"
|
||||
|
||||
class ThreadQueue
|
||||
{
|
||||
@@ -20,6 +21,7 @@ public:
|
||||
|
||||
map<string, TimerWithComm> timers;
|
||||
Timer wait_timer;
|
||||
NamedStats stats;
|
||||
|
||||
ThreadQueue() :
|
||||
left(0)
|
||||
@@ -34,7 +36,8 @@ public:
|
||||
void schedule(const ThreadJob& job);
|
||||
ThreadJob next();
|
||||
void finished(const ThreadJob& job);
|
||||
void finished(const ThreadJob& job, const NamedCommStats& comm_stats);
|
||||
void finished(const ThreadJob& job, const NamedCommStats& comm_stats,
|
||||
const NamedStats& stats = {});
|
||||
ThreadJob result();
|
||||
|
||||
void set_comm_stats(const NamedCommStats& new_comm_stats);
|
||||
|
||||
Reference in New Issue
Block a user