mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Protocols with function-dependent preprocessing.
This commit is contained in:
@@ -18,6 +18,7 @@ using namespace std;
|
||||
BaseMachine* BaseMachine::singleton = 0;
|
||||
thread_local int BaseMachine::thread_num;
|
||||
thread_local OnDemandOTTripleSetup BaseMachine::ot_setup;
|
||||
thread_local const Program* BaseMachine::program = 0;
|
||||
|
||||
void print_usage(ostream& o, const char* name, size_t capacity)
|
||||
{
|
||||
@@ -38,11 +39,19 @@ bool BaseMachine::has_program()
|
||||
return has_singleton() and not s().progs.empty();
|
||||
}
|
||||
|
||||
DataPositions BaseMachine::get_offline_data_used()
|
||||
{
|
||||
if (program)
|
||||
return program->get_offline_data_used();
|
||||
else
|
||||
return s().progs[0].get_offline_data_used();
|
||||
}
|
||||
|
||||
int BaseMachine::edabit_bucket_size(int n_bits)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
usage = get_offline_data_used().total_edabits(n_bits);
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
@@ -50,7 +59,7 @@ int BaseMachine::triple_bucket_size(DataFieldType type)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().files[type][DATA_TRIPLE];
|
||||
usage = get_offline_data_used().files[type][DATA_TRIPLE];
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
@@ -83,7 +92,7 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
|
||||
{
|
||||
if (has_program())
|
||||
{
|
||||
auto res = s().progs[0].get_offline_data_used().matmuls[
|
||||
auto res = get_offline_data_used().matmuls[
|
||||
{n_rows, n_inner, n_cols}];
|
||||
if (res)
|
||||
return res;
|
||||
@@ -95,7 +104,7 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
|
||||
}
|
||||
|
||||
BaseMachine::BaseMachine() :
|
||||
nthreads(0), multithread(false)
|
||||
nthreads(0), multithread(false), nan_warning(0)
|
||||
{
|
||||
if (sodium_init() == -1)
|
||||
throw runtime_error("couldn't initialize libsodium");
|
||||
@@ -172,6 +181,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
getline(inpf, domain);
|
||||
getline(inpf, relevant_opts);
|
||||
getline(inpf, security);
|
||||
getline(inpf, gf2n);
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
@@ -266,6 +276,15 @@ int BaseMachine::prime_length_from_schedule(string progname)
|
||||
return 0;
|
||||
}
|
||||
|
||||
int BaseMachine::gf2n_length_from_schedule(string progname)
|
||||
{
|
||||
string domain = get_basics(progname).gf2n;
|
||||
if (domain.substr(0, 4).compare("lg2:") == 0)
|
||||
return stoi(domain.substr(4));
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
bigint BaseMachine::prime_from_schedule(string progname)
|
||||
{
|
||||
string domain = get_domain(progname);
|
||||
@@ -286,10 +305,7 @@ int BaseMachine::security_from_schedule(string progname)
|
||||
|
||||
NamedCommStats BaseMachine::total_comm()
|
||||
{
|
||||
NamedCommStats res;
|
||||
for (auto& queue : queues)
|
||||
res += queue->get_comm_stats();
|
||||
return res;
|
||||
return queues.total_comm();
|
||||
}
|
||||
|
||||
void BaseMachine::set_thread_comm(const NamedCommStats& stats)
|
||||
|
||||
@@ -22,23 +22,30 @@ void print_usage(ostream& o, const char* name, size_t capacity);
|
||||
|
||||
class BaseMachine
|
||||
{
|
||||
friend class Program;
|
||||
|
||||
protected:
|
||||
static BaseMachine* singleton;
|
||||
|
||||
static thread_local OnDemandOTTripleSetup ot_setup;
|
||||
|
||||
static thread_local const Program* program;
|
||||
|
||||
std::map<int,TimerWithComm> timer;
|
||||
|
||||
string compiler;
|
||||
string domain;
|
||||
string relevant_opts;
|
||||
string security;
|
||||
string gf2n;
|
||||
|
||||
virtual size_t load_program(const string& threadname,
|
||||
const string& filename);
|
||||
|
||||
static BaseMachine get_basics(string progname);
|
||||
|
||||
static DataPositions get_offline_data_used();
|
||||
|
||||
public:
|
||||
static thread_local int thread_num;
|
||||
|
||||
@@ -52,6 +59,8 @@ public:
|
||||
|
||||
vector<Program> progs;
|
||||
|
||||
bool nan_warning;
|
||||
|
||||
static BaseMachine& s();
|
||||
static bool has_singleton() { return singleton != 0; }
|
||||
static bool has_program();
|
||||
@@ -61,6 +70,7 @@ public:
|
||||
static string get_domain(string progname);
|
||||
static int ring_size_from_schedule(string progname);
|
||||
static int prime_length_from_schedule(string progname);
|
||||
static int gf2n_length_from_schedule(string progname);
|
||||
static bigint prime_from_schedule(string progname);
|
||||
static int security_from_schedule(string progname);
|
||||
|
||||
@@ -127,7 +137,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
|
||||
if (buffer_size <= 0 and has_program())
|
||||
{
|
||||
auto files = s().progs[0].get_offline_data_used().files;
|
||||
auto files = get_offline_data_used().files;
|
||||
auto usage = files[T::clear::field_type()];
|
||||
|
||||
if (type == DATA_DABIT and T::LivePrep::bits_from_dabits())
|
||||
@@ -187,7 +197,7 @@ int BaseMachine::input_batch_size(int player, int buffer_size)
|
||||
if (has_program())
|
||||
{
|
||||
auto res =
|
||||
s().progs[0].get_offline_data_used(
|
||||
get_offline_data_used(
|
||||
).inputs[player][T::clear::field_type()];
|
||||
if (res > 0)
|
||||
return res;
|
||||
@@ -210,7 +220,7 @@ int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
|
||||
|
||||
if (has_program())
|
||||
{
|
||||
n = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
n = get_offline_data_used().total_edabits(n_bits);
|
||||
}
|
||||
|
||||
if (n > 0 and not (buffer_size > 0))
|
||||
|
||||
@@ -113,7 +113,10 @@ protected:
|
||||
void count(Dtype dtype, int n = 1)
|
||||
{ usage.files[T::clear::field_type()][dtype] += do_count * n; }
|
||||
void count_input(int player)
|
||||
{ usage.inputs[player][T::clear::field_type()] += do_count; }
|
||||
{
|
||||
usage.inputs.resize(max(size_t(player + 1), usage.inputs.size()));
|
||||
usage.inputs[player][T::clear::field_type()] += do_count;
|
||||
}
|
||||
|
||||
template<int>
|
||||
void get_edabits(bool strict, size_t size, T* a,
|
||||
@@ -130,6 +133,19 @@ protected:
|
||||
public:
|
||||
int buffer_size;
|
||||
|
||||
/// Key-independent setup if necessary (cryptosystem parameters)
|
||||
static void basic_setup(Player&) {}
|
||||
/// Generate keys if necessary
|
||||
static void setup(Player&, typename T::mac_key_type) {}
|
||||
/// Free memory of global cryptosystem parameters
|
||||
static void teardown() {}
|
||||
|
||||
static void edabit_sacrifice_buckets(vector<edabit<T>>&, size_t, bool, int,
|
||||
SubProcessor<T>&, int, int, const void* = 0)
|
||||
{
|
||||
throw runtime_error("sacrifice not available");
|
||||
}
|
||||
|
||||
template<class U, class V>
|
||||
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
|
||||
SubProcessor<T>* proc);
|
||||
|
||||
@@ -112,6 +112,10 @@ void Sub_Data_Files<T>::check_setup(const Names& N)
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::check_setup(int num_players, const string& prep_dir)
|
||||
{
|
||||
if (T::function_dependent)
|
||||
throw runtime_error("preprocessing from file not implemented "
|
||||
"for function-dependent preprocessing");
|
||||
|
||||
try
|
||||
{
|
||||
T::clear::check_setup(prep_dir);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "HonestMajorityMachine.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Protocols/MascotPrep.h"
|
||||
|
||||
#include "OnlineOptions.hpp"
|
||||
|
||||
@@ -39,7 +40,7 @@ public:
|
||||
ez::ezOptionParser& opt, bool live_prep_default = true)
|
||||
{
|
||||
OnlineOptions& online_opts = OnlineOptions::singleton;
|
||||
online_opts = {opt, argc, argv, T<gfp0>(), live_prep_default};
|
||||
online_opts = {opt, argc, argv, T<gfp0>(), live_prep_default, W()};
|
||||
|
||||
FieldMachine<T, V, X, W>(argc, argv, opt, online_opts);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,9 @@ FieldMachine<T, V, W, X>::FieldMachine(int argc, const char** argv,
|
||||
#undef X
|
||||
default:
|
||||
cerr << "Not compiled for " << online_opts.prime_length() << "-bit primes" << endl;
|
||||
cerr << "Compile with -DGFP_MOD_SZ=" << n_limbs << endl;
|
||||
cerr << "Put 'MOD = -DGFP_MOD_SZ=" << n_limbs
|
||||
<< "' in CONFIG.mine and run " << "'make " << argv[0] << "'"
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
* Argument with binary secret shares (always in array).
|
||||
*
|
||||
* @param n_bits number of bits
|
||||
* @param values shares (vector of vectors of bit_type)
|
||||
* @param values shares (vector of vectors of bit_type of length ceil(n_bits/64))
|
||||
*/
|
||||
template<class T>
|
||||
FunctionArgument(size_t n_bits, vector<vector<T>>& values) :
|
||||
|
||||
@@ -33,9 +33,10 @@ protected:
|
||||
Timer timer;
|
||||
|
||||
// Send my inputs (not generally available)
|
||||
virtual void send_mine() { throw not_implemented(); }
|
||||
virtual void send_mine() { throw runtime_error("implement send_mine()"); }
|
||||
// Get share for next input of mine (not generally available)
|
||||
virtual T finalize_mine() { throw not_implemented(); }
|
||||
virtual T finalize_mine()
|
||||
{ throw runtime_error("implement finalize_mine()"); }
|
||||
// Store share for next input from ``player`` from buffer ``o``
|
||||
// in ``target`` (not generally available)
|
||||
virtual void finalize_other(int, T&, octetStream&, int = -1)
|
||||
@@ -60,6 +61,8 @@ public:
|
||||
InputBase(SubProcessor<T>* proc);
|
||||
virtual ~InputBase();
|
||||
|
||||
bool virtual is_me(int player, int = -1);
|
||||
|
||||
/// Initialize input round for ``player``
|
||||
virtual void reset(int player) = 0;
|
||||
/// Initialize input round for all players
|
||||
|
||||
@@ -28,6 +28,8 @@ template<class T>
|
||||
InputBase<T>::InputBase(SubProcessor<T>* proc) :
|
||||
InputBase(proc ? proc->Proc : 0)
|
||||
{
|
||||
if (proc)
|
||||
P = &proc->P;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -38,7 +40,7 @@ Input<T>::Input(SubProcessor<T>& proc) :
|
||||
|
||||
template<class T>
|
||||
Input<T>::Input(SubProcessor<T>& proc, MAC_Check& mc) :
|
||||
InputBase<T>(proc.Proc), proc(&proc), MC(mc), prep(proc.DataF), P(proc.P),
|
||||
InputBase<T>(&proc), proc(&proc), MC(mc), prep(proc.DataF), P(proc.P),
|
||||
shares(proc.P.num_players())
|
||||
{
|
||||
}
|
||||
@@ -66,6 +68,13 @@ InputBase<T>::~InputBase()
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
bool InputBase<T>::is_me(int player, int)
|
||||
{
|
||||
assert(P);
|
||||
return player == P->my_num();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::reset(int player)
|
||||
{
|
||||
@@ -150,14 +159,15 @@ void InputBase<T>::raw_input(SubProcessor<T>& proc, const vector<int>& args,
|
||||
{
|
||||
int player = *it++;
|
||||
it++;
|
||||
if (player == P.my_num())
|
||||
if (is_me(player))
|
||||
{
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
clear t;
|
||||
try
|
||||
{
|
||||
this->buffer.input(t);
|
||||
if (T::real_shares(P))
|
||||
this->buffer.input(t);
|
||||
}
|
||||
catch (not_enough_to_buffer& e)
|
||||
{
|
||||
@@ -222,12 +232,14 @@ void InputBase<T>::prepare(SubProcessor<T>& Proc, int player, const int* params,
|
||||
{
|
||||
auto& input = Proc.input;
|
||||
assert(Proc.Proc != 0);
|
||||
if (player == Proc.P.my_num())
|
||||
if (input.is_me(player))
|
||||
{
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
U tuple = Proc.Proc->template get_input<U>(Proc.Proc->use_stdin(),
|
||||
params);
|
||||
U tuple;
|
||||
if (T::real_shares(Proc.P))
|
||||
tuple = Proc.Proc->template get_input<U>(
|
||||
Proc.Proc->use_stdin(), params);
|
||||
for (auto x : tuple.items)
|
||||
input.add_mine(x);
|
||||
}
|
||||
|
||||
@@ -106,11 +106,17 @@ string BaseInstruction::get_name() const
|
||||
COMBI_INSTRUCTIONS
|
||||
default:
|
||||
stringstream ss;
|
||||
ss << hex << get_opcode();
|
||||
ss << showbase << hex << get_opcode();
|
||||
return ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
void BaseInstruction::bytecode_assert(bool condition) const
|
||||
{
|
||||
if (not condition)
|
||||
throw runtime_error("bytecode assertion violated");
|
||||
}
|
||||
|
||||
ostream& operator<<(ostream& s, const Instruction& instr)
|
||||
{
|
||||
s << instr.get_name();
|
||||
|
||||
@@ -137,6 +137,7 @@ enum
|
||||
SEDABIT = 0x5A,
|
||||
RANDOMS = 0x5B,
|
||||
RANDOMFULLS = 0x5D,
|
||||
UNSPLIT = 0x5E,
|
||||
// Input
|
||||
INPUT = 0x60,
|
||||
INPUTFIX = 0xF0,
|
||||
@@ -269,6 +270,8 @@ enum
|
||||
GMULS = 0x1A6,
|
||||
GMULRS = 0x1A7,
|
||||
GDOTPRODS = 0x1A8,
|
||||
GMATMULS = 0x1AA,
|
||||
GMATMULSM = 0x1AB,
|
||||
GSECSHUFFLE = 0x1FA,
|
||||
// Data access
|
||||
GTRIPLE = 0x150,
|
||||
@@ -356,6 +359,8 @@ protected:
|
||||
vector<int> start; // Values for a start/stop open
|
||||
string str;
|
||||
|
||||
void bytecode_assert(bool condition) const;
|
||||
|
||||
public:
|
||||
BaseInstruction() : opcode(0), size(0), n(0) {}
|
||||
virtual ~BaseInstruction() {};
|
||||
|
||||
@@ -37,6 +37,9 @@ void BaseInstruction::parse(istream& s, int inst_pos)
|
||||
size = code >> 10;
|
||||
opcode = 0x3FF & code;
|
||||
|
||||
if (s.fail())
|
||||
throw bytecode_error("cannot read opcode");
|
||||
|
||||
if (size==0)
|
||||
size=1;
|
||||
|
||||
@@ -141,6 +144,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRINTREGPLAIN:
|
||||
case PRINTREGPLAINB:
|
||||
case PRINTREGPLAINS:
|
||||
case PRINTREGPLAINSB:
|
||||
case LDTN:
|
||||
case LDARG:
|
||||
case STARG:
|
||||
@@ -187,7 +191,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case SHLCBI:
|
||||
case SHRCBI:
|
||||
case NOTC:
|
||||
case CONVMODP:
|
||||
case GADDCI:
|
||||
case GADDSI:
|
||||
case GSUBCI:
|
||||
@@ -212,6 +215,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
n = get_int(s);
|
||||
break;
|
||||
case PICKS:
|
||||
case CONVMODP:
|
||||
get_ints(r, s, 3);
|
||||
n = get_int(s);
|
||||
break;
|
||||
@@ -320,8 +324,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case RUN_TAPE:
|
||||
case CONV2DS:
|
||||
case MATMULS:
|
||||
case GMATMULS:
|
||||
case APPLYSHUFFLE:
|
||||
case MATMULSM:
|
||||
case GMATMULSM:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
@@ -400,6 +406,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case WRITEFILESHARE:
|
||||
case GWRITEFILESHARE:
|
||||
case CONCATS:
|
||||
case UNSPLIT:
|
||||
num_var_args = get_int(s) - 1;
|
||||
r[0] = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
@@ -729,7 +736,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += *it)
|
||||
{
|
||||
assert(it + *it <= start.end());
|
||||
bytecode_assert(it + *it <= start.end());
|
||||
res = max(res, it[1] + it[2]);
|
||||
}
|
||||
return res;
|
||||
@@ -746,7 +753,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
auto it = start.begin();
|
||||
while (it != start.end())
|
||||
{
|
||||
assert(it < start.end());
|
||||
bytecode_assert(it < start.end());
|
||||
int n = *it;
|
||||
res = max(res, *++it + size);
|
||||
it += n - 1;
|
||||
@@ -754,6 +761,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return res;
|
||||
}
|
||||
case MATMULS:
|
||||
case GMATMULS:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 6)
|
||||
@@ -764,6 +772,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return res;
|
||||
}
|
||||
case MATMULSM:
|
||||
case GMATMULSM:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 12)
|
||||
@@ -866,7 +875,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
int n = *it - n_prefix;
|
||||
size = max((long long) size, DIV_CEIL(*(it + 1), 64));
|
||||
it += n_prefix;
|
||||
assert(it + n <= start.end());
|
||||
bytecode_assert(it + n <= start.end());
|
||||
for (int i = 0; i < n; i++)
|
||||
res = max(res, *it++ + size);
|
||||
}
|
||||
@@ -958,7 +967,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
"integer registers only have 64 bits");
|
||||
values.push_back(tmp);
|
||||
}
|
||||
sync<sint>(values, Proc.P);
|
||||
if (r[2])
|
||||
Procp.protocol.sync(values, Proc.P);
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Ci(r[0] + i, values[i].get());
|
||||
return;
|
||||
@@ -987,8 +997,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
for (auto j = start.begin(); j < start.end(); j += 2)
|
||||
{
|
||||
auto source = S.begin() + *(j + 1);
|
||||
assert(dest + *j <= S.end());
|
||||
assert(source + *j <= S.end());
|
||||
bytecode_assert(dest + *j <= S.end());
|
||||
bytecode_assert(source + *j <= S.end());
|
||||
for (int k = 0; k < *j; k++)
|
||||
*dest++ = *source++;
|
||||
}
|
||||
@@ -1165,14 +1175,21 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case MATMULS:
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
|
||||
return;
|
||||
case GMATMULS:
|
||||
Proc.Proc2.matmuls(Proc.Proc2.get_S(), *this);
|
||||
return;
|
||||
case MATMULSM:
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this);
|
||||
return;
|
||||
case GMATMULSM:
|
||||
Proc.Proc2.protocol.matmulsm(Proc.Proc2, Proc.machine.M2.MS, *this);
|
||||
return;
|
||||
case CONV2DS:
|
||||
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
|
||||
return;
|
||||
case TRUNC_PR:
|
||||
Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp);
|
||||
Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp,
|
||||
sint::clear::characteristic_two);
|
||||
return;
|
||||
case SECSHUFFLE:
|
||||
Proc.Procp.secure_shuffle(*this);
|
||||
@@ -1424,12 +1441,14 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Procp.protocol.cisc(Procp, *this);
|
||||
return;
|
||||
default:
|
||||
printf("Case of opcode=0x%x not implemented yet\n",opcode);
|
||||
throw invalid_opcode(opcode);
|
||||
break;
|
||||
#define X(NAME, CODE) case NAME:
|
||||
COMBI_INSTRUCTIONS
|
||||
#undef X
|
||||
#define X(NAME, CODE) case NAME: throw no_dynamic_memory();
|
||||
DYNAMIC_INSTRUCTIONS
|
||||
#undef X
|
||||
#define X(NAME, PRE, CODE) case NAME:
|
||||
ARITHMETIC_INSTRUCTIONS
|
||||
#undef X
|
||||
@@ -1482,6 +1501,8 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
|
||||
auto& processor = Proc.Procb;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
BaseMachine::program = this;
|
||||
|
||||
while (Proc.PC<size)
|
||||
{
|
||||
Proc.last_PC = Proc.PC;
|
||||
@@ -1535,6 +1556,15 @@ void Program::execute_with_errors(Processor<sint, sgf2n>& Proc) const
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Program::mulm_check() const
|
||||
{
|
||||
if (T::function_dependent and not OnlineOptions::singleton.has_option("allow_mulm"))
|
||||
throw runtime_error("Mixed multiplication not implemented for function-dependent preprocessing. "
|
||||
"Use '-E <protocol>' during compilation or state "
|
||||
"'program.use_mulm = False' at the beginning of your high-level program.");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) const
|
||||
{
|
||||
|
||||
@@ -30,8 +30,6 @@
|
||||
#include <atomic>
|
||||
using namespace std;
|
||||
|
||||
#include "OnlineOptions.hpp"
|
||||
|
||||
template<class sint, class sgf2n = NoShare<gf2n>>
|
||||
class Machine : public BaseMachine
|
||||
{
|
||||
@@ -52,6 +50,10 @@ class Machine : public BaseMachine
|
||||
|
||||
Player* P;
|
||||
|
||||
RunningTimer setup_timer;
|
||||
|
||||
NamedCommStats max_comm;
|
||||
|
||||
size_t load_program(const string& threadname, const string& filename);
|
||||
|
||||
void prepare(const string& progname_str);
|
||||
@@ -82,7 +84,7 @@ class Machine : public BaseMachine
|
||||
static void init_binary_domains(int security_parameter, int lg2);
|
||||
|
||||
Machine(Names& playerNames, bool use_encryption = true,
|
||||
const OnlineOptions opts = sint(), int lg2 = 0);
|
||||
const OnlineOptions opts = sint());
|
||||
~Machine();
|
||||
|
||||
const Names& get_N() { return N; }
|
||||
|
||||
@@ -53,7 +53,7 @@ void Machine<sint, sgf2n>::init_binary_domains(int security_parameter, int lg2)
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
const OnlineOptions opts, int lg2)
|
||||
const OnlineOptions opts)
|
||||
: my_number(playerNames.my_num()), N(playerNames),
|
||||
use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts),
|
||||
external_clients(my_number)
|
||||
@@ -81,7 +81,7 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
if (opts.prime)
|
||||
sint::clear::init_field(opts.prime);
|
||||
|
||||
init_binary_domains(opts.security_parameter, lg2);
|
||||
init_binary_domains(opts.security_parameter, opts.lg2);
|
||||
|
||||
// make directory for outputs if necessary
|
||||
mkdir_p(PREP_DIR);
|
||||
@@ -132,16 +132,19 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
auto memtype = opts.memtype;
|
||||
if (memtype.compare("old")==0)
|
||||
{
|
||||
ifstream inpf;
|
||||
inpf.open(memory_filename(), ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(memory_filename()); }
|
||||
inpf >> M2 >> Mp >> Mi;
|
||||
if (inpf.get() != 'M')
|
||||
if (sint::real_shares(*P))
|
||||
{
|
||||
cerr << "Invalid memory file. Run with '-m empty'." << endl;
|
||||
exit(1);
|
||||
ifstream inpf;
|
||||
inpf.open(memory_filename(), ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(memory_filename()); }
|
||||
inpf >> M2 >> Mp >> Mi;
|
||||
if (inpf.get() != 'M')
|
||||
{
|
||||
cerr << "Invalid memory file. Run with '-m empty'." << endl;
|
||||
exit(1);
|
||||
}
|
||||
inpf.close();
|
||||
}
|
||||
inpf.close();
|
||||
}
|
||||
else if (!(memtype.compare("empty")==0))
|
||||
{ cerr << "Invalid memory argument" << endl;
|
||||
@@ -177,9 +180,9 @@ void Machine<sint, sgf2n>::prepare(const string& progname_str)
|
||||
|
||||
/* Set up the threads */
|
||||
tinfo.resize(nthreads);
|
||||
threads.resize(nthreads);
|
||||
queues.resize(nthreads);
|
||||
join_timer.resize(nthreads);
|
||||
assert(threads.size() == size_t(old_n_threads));
|
||||
|
||||
for (int i = old_n_threads; i < nthreads; i++)
|
||||
{
|
||||
@@ -191,9 +194,18 @@ void Machine<sint, sgf2n>::prepare(const string& progname_str)
|
||||
tinfo[i].alphapi=&alphapi;
|
||||
tinfo[i].alpha2i=&alpha2i;
|
||||
tinfo[i].machine=this;
|
||||
pthread_create(&threads[i],NULL,thread_info<sint, sgf2n>::Main_Func,&tinfo[i]);
|
||||
pthread_t thread;
|
||||
int res = pthread_create(&thread, NULL,
|
||||
thread_info<sint, sgf2n>::Main_Func, &tinfo[i]);
|
||||
|
||||
if (res == 0)
|
||||
threads.push_back(thread);
|
||||
else
|
||||
throw runtime_error("cannot start thread");
|
||||
}
|
||||
|
||||
assert(queues.size() == threads.size());
|
||||
|
||||
// synchronize with clients before starting timer
|
||||
for (int i=old_n_threads; i<nthreads; i++)
|
||||
{
|
||||
@@ -415,8 +427,16 @@ void Machine<sint, sgf2n>::run_function(const string& name,
|
||||
result.check_type(return_type);
|
||||
|
||||
vector<int> arg_regs(arguments.size());
|
||||
for (auto& arg_reg : arg_regs)
|
||||
file >> arg_reg;
|
||||
vector<int> address_regs(arguments.size());
|
||||
for (size_t i = 0; i < arguments.size(); i++)
|
||||
{
|
||||
file >> arg_regs.at(i);
|
||||
if (arguments[i].get_memory())
|
||||
file >> address_regs.at(i);
|
||||
}
|
||||
|
||||
if (not file.good())
|
||||
throw runtime_error("error reading file for function " + name);
|
||||
|
||||
prepare(progname);
|
||||
auto& processor = *tinfo.at(0).processor;
|
||||
@@ -447,6 +467,8 @@ void Machine<sint, sgf2n>::run_function(const string& name,
|
||||
assert(arguments[i].has_reg_type("ci"));
|
||||
processor.write_Ci(arg_regs.at(i) + j, arguments[i].get_value<long>(j));
|
||||
}
|
||||
if (arguments[i].get_memory())
|
||||
processor.write_Ci(address_regs.at(i), arg_regs.at(i));
|
||||
}
|
||||
|
||||
run_tape(0, tape_number, 0, N.num_players());
|
||||
@@ -476,11 +498,16 @@ void Machine<sint, sgf2n>::run_function(const string& name,
|
||||
template<class sint, class sgf2n>
|
||||
pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
{
|
||||
// only stop actually running threads
|
||||
nthreads = threads.size();
|
||||
|
||||
// Tell all C-threads to stop
|
||||
for (int i=0; i<nthreads; i++)
|
||||
{
|
||||
//printf("Send kill signal to client\n");
|
||||
queues[i]->schedule(-1);
|
||||
auto queue = queues.at(i);
|
||||
assert(queue);
|
||||
queue->schedule(-1);
|
||||
}
|
||||
|
||||
// sum actual usage
|
||||
@@ -498,6 +525,7 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
}
|
||||
|
||||
auto comm_stats = total_comm();
|
||||
max_comm = queues.max_comm();
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
{
|
||||
@@ -509,9 +537,11 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
}
|
||||
|
||||
for (auto& queue : queues)
|
||||
delete queue;
|
||||
if (queue)
|
||||
delete queue;
|
||||
|
||||
queues.clear();
|
||||
threads.clear();
|
||||
|
||||
nthreads = 0;
|
||||
|
||||
@@ -523,6 +553,12 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
{
|
||||
prepare(progname);
|
||||
|
||||
if (opts.verbose and setup_timer.is_running())
|
||||
{
|
||||
cerr << "Setup took " << setup_timer.elapsed() << " seconds." << endl;
|
||||
setup_timer.stop();
|
||||
}
|
||||
|
||||
Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID);
|
||||
proc_timer.start();
|
||||
timer[0].start({});
|
||||
@@ -564,9 +600,9 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
{
|
||||
cerr << "Communication details";
|
||||
if (multithread)
|
||||
cerr << " (rounds in parallel threads counted double)";
|
||||
cerr << " (rounds and time in parallel threads counted double)";
|
||||
cerr << ":" << endl;
|
||||
comm_stats.print();
|
||||
comm_stats.print(false, max_comm);
|
||||
cerr << "CPU time = " << proc_timer.elapsed();
|
||||
if (multithread)
|
||||
cerr << " (overall core time)";
|
||||
@@ -601,17 +637,25 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
Mp.resize_s(max_size);
|
||||
}
|
||||
|
||||
// Write out the memory to use next time
|
||||
ofstream outf(memory_filename(), ios::out | ios::binary);
|
||||
outf << M2 << Mp << Mi;
|
||||
outf << 'M';
|
||||
outf.close();
|
||||
if (sint::real_shares(*P) and not opts.has_option("no_memory_output"))
|
||||
{
|
||||
RunningTimer timer;
|
||||
// Write out the memory to use next time
|
||||
ofstream outf(memory_filename(), ios::out | ios::binary);
|
||||
outf << M2 << Mp << Mi;
|
||||
outf << 'M';
|
||||
outf.close();
|
||||
|
||||
bit_memories.write_memory(N.my_num());
|
||||
bit_memories.write_memory(N.my_num());
|
||||
|
||||
if (opts.has_option("time_memory_output"))
|
||||
cerr << "Writing memory to disk took " << timer.elapsed() << " seconds"
|
||||
<< endl;
|
||||
}
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Actual cost of program:" << endl;
|
||||
cerr << "Actual preprocessing cost of program:" << endl;
|
||||
pos.print_cost();
|
||||
}
|
||||
|
||||
@@ -641,6 +685,14 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
<< "have you considered using " << alt << " instead?" << endl;
|
||||
}
|
||||
|
||||
if (nan_warning and sint::real_shares(*P))
|
||||
{
|
||||
cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. See ";
|
||||
cerr << "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix";
|
||||
cerr << " for details" << endl;
|
||||
nan_warning = false;
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "End of prog" << endl;
|
||||
#endif
|
||||
|
||||
@@ -45,7 +45,7 @@ int OfflineMachine<W>::run()
|
||||
{
|
||||
T::clear::init_default(this->online_opts.prime_length());
|
||||
Machine<T, U>::init_binary_domains(this->online_opts.security_parameter,
|
||||
this->lg2);
|
||||
this->get_lg2());
|
||||
auto binary_mac_key = read_generate_write_mac_key<
|
||||
typename T::bit_type::part_type>(P);
|
||||
typename T::bit_type::LivePrep bit_prep(usage);
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
#include "Protocols/LimitedPrep.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "GC/BitAdder.hpp"
|
||||
|
||||
#include <iostream>
|
||||
@@ -354,11 +355,17 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
|
||||
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
|
||||
|
||||
assert(Proc.share_thread.protocol);
|
||||
queues->timers["random"] = Proc.Procp.protocol.randomness_time()
|
||||
+ Proc.Proc2.protocol.randomness_time()
|
||||
+ Proc.share_thread.protocol->randomness_time();
|
||||
|
||||
NamedStats stats;
|
||||
stats["integer multiplications"] = Proc.Procp.protocol.counter;
|
||||
stats["integer multiplication rounds"] = Proc.Procp.protocol.rounds;
|
||||
stats["integer dot products"] = Proc.Procp.protocol.dot_counter;
|
||||
stats["probabilistic truncations"] = Proc.Procp.protocol.trunc_pr_counter;
|
||||
stats["probabilistic truncations (big gap)"] = Proc.Procp.protocol.trunc_pr_big_counter;
|
||||
stats["probabilistic truncation rounds"] = Proc.Procp.protocol.trunc_rounds;
|
||||
stats["ANDs"] = Proc.share_thread.protocol->bit_counter;
|
||||
stats["AND rounds"] = Proc.share_thread.protocol->rounds;
|
||||
|
||||
@@ -17,8 +17,6 @@ protected:
|
||||
const char** argv;
|
||||
OnlineOptions& online_opts;
|
||||
|
||||
int lg2;
|
||||
|
||||
Names playerNames;
|
||||
|
||||
bool use_encryption;
|
||||
@@ -43,7 +41,7 @@ public:
|
||||
|
||||
int get_lg2()
|
||||
{
|
||||
return lg2;
|
||||
return online_opts.lg2;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -12,26 +12,19 @@
|
||||
#include <string>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
template<class V>
|
||||
OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
OnlineOptions& online_opts, int nplayers, V) :
|
||||
argc(argc), argv(argv), online_opts(online_opts), lg2(0),
|
||||
argc(argc), argv(argv), online_opts(online_opts),
|
||||
use_encryption(false),
|
||||
opt(opt), nplayers(nplayers)
|
||||
{
|
||||
opt.add(
|
||||
to_string(V::default_degree()).c_str(), // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Bit length of GF(2^n) field (default: "
|
||||
+ to_string(V::default_degree()) + "; options are "
|
||||
+ V::options() + ")").c_str(), // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"5000", // Default.
|
||||
0, // Required?
|
||||
@@ -83,7 +76,6 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
opt.resetArgs();
|
||||
}
|
||||
|
||||
@@ -197,7 +189,7 @@ int OnlineMachine::run_with_error()
|
||||
try
|
||||
#endif
|
||||
{
|
||||
Machine<T, U>(playerNames, use_encryption, online_opts, lg2).run(
|
||||
Machine<T, U>(playerNames, use_encryption, online_opts).run(
|
||||
online_opts.progname);
|
||||
|
||||
if (online_opts.verbose)
|
||||
|
||||
@@ -23,6 +23,7 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
{
|
||||
interactive = false;
|
||||
lgp = gfp0::MAX_N_BITS;
|
||||
lg2 = 0;
|
||||
live_prep = true;
|
||||
batch_size = 1000;
|
||||
memtype = "empty";
|
||||
@@ -38,6 +39,7 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
opening_sum = 0;
|
||||
max_broadcast = 0;
|
||||
receive_threads = false;
|
||||
code_locations = false;
|
||||
#ifdef VERBOSE
|
||||
verbose = true;
|
||||
#else
|
||||
@@ -123,6 +125,14 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-o", // Flag token.
|
||||
"--options" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Output code locations of the most relevant protocols used", // Help description.
|
||||
"--code-locations" // Flag token.
|
||||
);
|
||||
|
||||
if (security)
|
||||
opt.add(
|
||||
@@ -130,7 +140,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Statistical ecurity parameter (default: " + to_string(security_parameter)
|
||||
("Statistical security parameter (default: " + to_string(security_parameter)
|
||||
+ ")").c_str(), // Help description.
|
||||
"-S", // Flag token.
|
||||
"--security" // Flag token.
|
||||
@@ -151,6 +161,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
|
||||
opt.get("--options")->getStrings(options);
|
||||
|
||||
code_locations = opt.isSet("--code-locations");
|
||||
|
||||
#ifdef THROW_EXCEPTIONS
|
||||
options.push_back("throw_exceptions");
|
||||
#endif
|
||||
@@ -164,6 +176,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
else
|
||||
security_parameter = 1000;
|
||||
|
||||
opt.resetArgs();
|
||||
|
||||
@@ -417,6 +431,23 @@ void OnlineOptions::finalize_with_error(ez::ezOptionParser& opt)
|
||||
lgp = prog_lgp;
|
||||
}
|
||||
|
||||
if (opt.get("--lg2"))
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
|
||||
int prog_lg2 = BaseMachine::gf2n_length_from_schedule(progname);
|
||||
if (prog_lg2)
|
||||
{
|
||||
if (prog_lg2 != lg2 and opt.isSet("lg2"))
|
||||
{
|
||||
cerr << "GF(2^n) mismatch between command line and program" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (verbose)
|
||||
cerr << "Using GF(2^" << prog_lg2 << ") as requested by program" << endl;
|
||||
lg2 = prog_lg2;
|
||||
}
|
||||
|
||||
set_trunc_error(opt);
|
||||
|
||||
auto o = opt.get("--opening-sum");
|
||||
@@ -457,9 +488,8 @@ void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt)
|
||||
if (opt.get("-E"))
|
||||
{
|
||||
opt.get("-E")->getInt(trunc_error);
|
||||
#ifdef VERBOSE
|
||||
cerr << "Truncation error probability 2^-" << trunc_error << endl;
|
||||
#endif
|
||||
if (verbose)
|
||||
cerr << "Truncation error probability 2^-" << trunc_error << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/bigint.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
class OnlineOptions
|
||||
{
|
||||
@@ -19,6 +20,7 @@ public:
|
||||
|
||||
bool interactive;
|
||||
int lgp;
|
||||
int lg2;
|
||||
bigint prime;
|
||||
bool live_prep;
|
||||
int playerno;
|
||||
@@ -41,6 +43,7 @@ public:
|
||||
vector<long> args;
|
||||
vector<string> options;
|
||||
string executable;
|
||||
bool code_locations;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
@@ -48,9 +51,9 @@ public:
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
int default_batch_size = 0, bool default_live_prep = true,
|
||||
bool variable_prime_length = false, bool security = true);
|
||||
template<class T>
|
||||
template<class T, class V = gf2n>
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T,
|
||||
bool default_live_prep = true);
|
||||
bool default_live_prep = true, V = {});
|
||||
template<class T>
|
||||
OnlineOptions(T);
|
||||
~OnlineOptions() {}
|
||||
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
#include "OnlineOptions.h"
|
||||
|
||||
template<class T>
|
||||
template<class T, class V>
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, T, bool default_live_prep) :
|
||||
const char** argv, T, bool default_live_prep, V) :
|
||||
OnlineOptions(opt, argc, argv, OnlineOptions(T()).batch_size,
|
||||
default_live_prep, T::clear::prime_field)
|
||||
default_live_prep, T::clear::prime_field,
|
||||
T::LivePrep::homomorphic or T::malicious)
|
||||
{
|
||||
if (T::has_trunc_pr)
|
||||
opt.add(
|
||||
@@ -58,6 +59,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
}
|
||||
|
||||
if (not T::clear::binary)
|
||||
{
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
@@ -69,6 +71,19 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"--disk-memory" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
to_string(V::default_degree()).c_str(), // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Bit length of GF(2^n) field (default: "
|
||||
+ to_string(V::default_degree()) + "; options are "
|
||||
+ V::options() + ")").c_str(), // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
if (T::variable_players)
|
||||
opt.add(
|
||||
T::dishonest_majority ? "2" : "3", // Default.
|
||||
|
||||
@@ -285,6 +285,7 @@ class Processor : public ArithmeticProcessor
|
||||
void convcintvec(const Instruction& instruction);
|
||||
void convcbit2s(const Instruction& instruction);
|
||||
void split(const Instruction& instruction);
|
||||
void unsplit(const Instruction& instruction);
|
||||
|
||||
// Access to external client sockets for reading clear/shared data
|
||||
void read_socket_ints(int client_id, const vector<int>& registers, int size);
|
||||
@@ -302,7 +303,7 @@ class Processor : public ArithmeticProcessor
|
||||
void fixinput(const Instruction& instruction);
|
||||
|
||||
// synchronize in asymmetric protocols
|
||||
long sync(long x) const;
|
||||
long sync(long x);
|
||||
|
||||
ofstream& get_public_output();
|
||||
ofstream& get_binary_output();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "GC/square64.h"
|
||||
#include "SpecificPrivateOutput.h"
|
||||
#include "Conv2dTuple.h"
|
||||
#include "Protocols/Replicated.h"
|
||||
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
@@ -111,6 +112,9 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
|
||||
secure_prng.ReSeed();
|
||||
shared_prng.SeedGlobally(P, false);
|
||||
vector<IntBase<octet>> seed(shared_prng.get_seed(), shared_prng.get_seed() + SEED_SIZE);
|
||||
Procp.protocol.forward_sync(seed);
|
||||
shared_prng.SetSeed((octet*) seed.data());
|
||||
|
||||
setup_redirection(P.my_num(), thread_num, opts, out, sint::real_shares(P));
|
||||
Procb.out = out;
|
||||
@@ -192,10 +196,11 @@ void Processor<sint, sgf2n>::dabit(const Instruction& instruction)
|
||||
{
|
||||
Procb.S[instruction.get_r(1) + i] = {};
|
||||
}
|
||||
auto a = Procp.get_S().iterator_for_size(instruction.get_r(0), size);
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
typename sint::bit_type tmp;
|
||||
Procp.DataF.get_dabit(Procp.get_S_ref(instruction.get_r(0) + i), tmp);
|
||||
Procp.DataF.get_dabit(*a++, tmp);
|
||||
Procb.S[instruction.get_r(1) + i / unit] ^= tmp << (i % unit);
|
||||
}
|
||||
}
|
||||
@@ -248,6 +253,12 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
|
||||
&read_Sp(instruction.get_r(0)), n_inputs, *share_thread.protocol);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::unsplit(const Instruction& instruction)
|
||||
{
|
||||
Procp.protocol.unsplit(Procp.S, Procb.S, instruction);
|
||||
}
|
||||
|
||||
|
||||
#include "Networking/sockets.h"
|
||||
#include "Math/Setup.h"
|
||||
@@ -487,25 +498,25 @@ template<class T>
|
||||
void SubProcessor<T>::muls(const vector<int>& reg)
|
||||
{
|
||||
assert(reg.size() % 4 == 0);
|
||||
int n = reg.size() / 4;
|
||||
|
||||
SubProcessor<T>& proc = *this;
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
auto& x = proc.S[reg[4 * i + 2] + j];
|
||||
auto& y = proc.S[reg[4 * i + 3] + j];
|
||||
protocol.prepare_mul(x, y);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (auto it = reg.begin(); it < reg.end(); it += 4)
|
||||
{
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul();
|
||||
}
|
||||
protocol.counter += reg[4 * i];
|
||||
for (int j = 1; j < 4; j++)
|
||||
assert(proc.S.begin() + *(it + j) <= proc.S.end());
|
||||
auto x = proc.S.begin() + *(it + 2);
|
||||
auto y = proc.S.begin() + *(it + 3);
|
||||
for (int j = 0; j < *it; j++)
|
||||
protocol.prepare_mul(*x++, *y++);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (auto it = reg.begin(); it < reg.end(); it += 4)
|
||||
{
|
||||
auto z = proc.S.begin() + *(it + 1);
|
||||
for (int j = 0; j < *it; j++)
|
||||
*z++ = protocol.finalize_mul();
|
||||
protocol.counter += *it;
|
||||
}
|
||||
|
||||
maybe_check();
|
||||
@@ -946,7 +957,7 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
|
||||
{
|
||||
input.reset_all(P);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
if (args[i + 1] == P.my_num())
|
||||
if (input.is_me(args[i + 1]))
|
||||
{
|
||||
auto begin = C.begin() + args[i + 3];
|
||||
auto end = begin + args[i];
|
||||
@@ -1047,8 +1058,14 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
throw runtime_error("unknown format for fixed-point input");
|
||||
}
|
||||
|
||||
if (not sint::real_shares(P))
|
||||
return;
|
||||
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
throw IO_Error(
|
||||
"Failure reading from " + binary_input_filename
|
||||
+ ". You might need to copy it "
|
||||
+ "from the location of compilation.");
|
||||
|
||||
if (binary_input.peek() == EOF)
|
||||
throw IO_Error("not enough inputs in " + binary_input_filename);
|
||||
@@ -1084,15 +1101,16 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
long Processor<sint, sgf2n>::sync(long x) const
|
||||
long Processor<sint, sgf2n>::sync(long x)
|
||||
{
|
||||
vector<Integer> tmp = {x};
|
||||
::sync<sint>(tmp, P);
|
||||
Procp.protocol.sync(tmp, P);
|
||||
return tmp[0].get();
|
||||
}
|
||||
|
||||
template<class sint>
|
||||
void sync(vector<Integer>& x, Player& P)
|
||||
template<class U>
|
||||
void ProtocolBase<sint>::sync(vector<U>& x, Player& P)
|
||||
{
|
||||
if (not sint::symmetric)
|
||||
{
|
||||
|
||||
@@ -32,7 +32,7 @@ void ProcessorBase::setup_redirection(int my_num, int thread_num,
|
||||
// only output on party 0 if not interactive
|
||||
bool always_stdout = opts.cmd_private_output_file == ".";
|
||||
bool output = my_num == 0 or opts.interactive or always_stdout;
|
||||
output &= real;
|
||||
output &= real or always_stdout;
|
||||
out.activate(output);
|
||||
|
||||
if (not (opts.cmd_private_output_file.empty() or always_stdout))
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
void Program::compute_constants()
|
||||
{
|
||||
bool debug = OnlineOptions::singleton.has_option("debug_alloc");
|
||||
|
||||
for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++)
|
||||
{
|
||||
max_reg[reg_type] = 0;
|
||||
@@ -18,8 +20,10 @@ void Program::compute_constants()
|
||||
unknown_usage = true;
|
||||
for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++)
|
||||
{
|
||||
max_reg[reg_type] = max(max_reg[reg_type],
|
||||
p[i].get_max_reg(reg_type));
|
||||
auto reg = p[i].get_max_reg(reg_type);
|
||||
if (debug and reg)
|
||||
cerr << i << ": " << reg << endl;
|
||||
max_reg[reg_type] = max(max_reg[reg_type], reg);
|
||||
max_mem[reg_type] = max(max_mem[reg_type],
|
||||
p[i].get_mem(RegType(reg_type)));
|
||||
}
|
||||
@@ -93,7 +97,11 @@ void Program::parse(istream& s)
|
||||
{
|
||||
instr.parse(s, p.size());
|
||||
}
|
||||
catch (bad_alloc&)
|
||||
catch (bytecode_error&)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (exception&)
|
||||
{
|
||||
fail = true;
|
||||
}
|
||||
|
||||
@@ -70,6 +70,9 @@ class Program
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void execute_with_errors(Processor<sint, sgf2n>& Proc) const;
|
||||
|
||||
template<class T>
|
||||
void mulm_check() const;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "Processor/RingOptions.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Protocols/Spdz2kPrep.h"
|
||||
#include "OnlineMachine.hpp"
|
||||
#include "OnlineOptions.hpp"
|
||||
|
||||
|
||||
@@ -112,5 +112,26 @@ void ThreadQueues::print_breakdown()
|
||||
<< " on the preprocessing/offline phase, and "
|
||||
<< sum("wait").full() << " idling." << endl;
|
||||
}
|
||||
|
||||
if (sum("random").elapsed())
|
||||
cerr << "Spent " << sum("random").full()
|
||||
<< " on correlated randomness generation." << endl;
|
||||
}
|
||||
}
|
||||
|
||||
NamedCommStats ThreadQueues::total_comm()
|
||||
{
|
||||
NamedCommStats res;
|
||||
for (auto& queue : *this)
|
||||
res += queue->get_comm_stats();
|
||||
return res;
|
||||
}
|
||||
|
||||
NamedCommStats ThreadQueues::max_comm()
|
||||
{
|
||||
NamedCommStats max;
|
||||
if (size() > 2)
|
||||
for (auto& queue : *this)
|
||||
max.imax(queue->get_comm_stats());
|
||||
return max;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,9 @@ public:
|
||||
TimerWithComm sum(const string& phase);
|
||||
|
||||
void print_breakdown();
|
||||
|
||||
NamedCommStats total_comm();
|
||||
NamedCommStats max_comm();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_THREADQUEUES_H_ */
|
||||
|
||||
@@ -11,9 +11,36 @@
|
||||
using namespace std;
|
||||
|
||||
#include "OnlineOptions.h"
|
||||
#include "GC/ArgTuples.h"
|
||||
|
||||
template<class T> class StackedVector;
|
||||
|
||||
void trunc_pr_check(int k, int m, int n_bits);
|
||||
|
||||
template<class T>
|
||||
class Range
|
||||
{
|
||||
typename T::iterator begin_, end_;
|
||||
|
||||
public:
|
||||
Range(T& whole, size_t start, size_t length)
|
||||
{
|
||||
begin_ = whole.begin() + start;
|
||||
end_ = begin_ + length;
|
||||
assert(end_ <= whole.end());
|
||||
}
|
||||
|
||||
typename T::iterator begin()
|
||||
{
|
||||
return begin_;
|
||||
}
|
||||
|
||||
typename T::iterator end()
|
||||
{
|
||||
return end_;
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TruncPrTuple
|
||||
{
|
||||
@@ -46,19 +73,30 @@ public:
|
||||
|
||||
T upper(T mask)
|
||||
{
|
||||
return (mask << (n_shift + 1)) >> (n_shift + m + 1);
|
||||
return (mask.cheap_lshift(n_shift + 1)) >> (n_shift + m + 1);
|
||||
}
|
||||
|
||||
T msb(T mask)
|
||||
{
|
||||
return (mask << (n_shift)) >> (T::N_BITS - 1);
|
||||
return (mask.cheap_lshift(n_shift)) >> (T::N_BITS - 1);
|
||||
}
|
||||
|
||||
T add_before()
|
||||
{
|
||||
return T(1).cheap_lshift(k - 1);
|
||||
}
|
||||
|
||||
T subtract_after()
|
||||
{
|
||||
return T(1).cheap_lshift(k - m - 1);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TruncPrTupleWithGap : public TruncPrTuple<T>
|
||||
{
|
||||
bool big_gap_;
|
||||
|
||||
public:
|
||||
TruncPrTupleWithGap(const vector<int>& regs, size_t base) :
|
||||
TruncPrTupleWithGap<T>(regs.begin() + base)
|
||||
@@ -68,6 +106,7 @@ public:
|
||||
TruncPrTupleWithGap(vector<int>::const_iterator it) :
|
||||
TruncPrTuple<T>(it)
|
||||
{
|
||||
big_gap_ = this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error;
|
||||
if (T::prime_field and small_gap())
|
||||
throw runtime_error("domain too small for chosen truncation error");
|
||||
}
|
||||
@@ -75,7 +114,7 @@ public:
|
||||
T upper(T mask)
|
||||
{
|
||||
if (big_gap())
|
||||
return mask >> this->m;
|
||||
return mask.signed_rshift(this->m);
|
||||
else
|
||||
return TruncPrTuple<T>::upper(mask);
|
||||
}
|
||||
@@ -88,7 +127,7 @@ public:
|
||||
|
||||
bool big_gap()
|
||||
{
|
||||
return this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error;
|
||||
return big_gap_;
|
||||
}
|
||||
|
||||
bool small_gap()
|
||||
@@ -97,4 +136,78 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TruncPrTupleWithRange : public TruncPrTupleWithGap<typename T::open_type>
|
||||
{
|
||||
typedef TruncPrTupleWithGap<typename T::open_type> super;
|
||||
|
||||
public:
|
||||
Range<StackedVector<T>> source_range, dest_range;
|
||||
|
||||
TruncPrTupleWithRange(super info, StackedVector<T>& S, size_t size) :
|
||||
super(info), source_range(S, info.source_base, size),
|
||||
dest_range(S, info.dest_base, size)
|
||||
{
|
||||
}
|
||||
|
||||
template<class U>
|
||||
U correction_shift(U bit)
|
||||
{
|
||||
return bit.cheap_lshift(T::open_type::N_BITS - this->m);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class TruncPrTupleList : public vector<TruncPrTupleWithRange<T>>
|
||||
{
|
||||
typedef TruncPrTupleWithGap<T> part_type;
|
||||
typedef TruncPrTupleList This;
|
||||
|
||||
public:
|
||||
TruncPrTupleList(const vector<int>& args, StackedVector<T>& S, size_t size)
|
||||
{
|
||||
ArgList<TruncPrTupleWithGap<typename T::open_type>> tmp(args);
|
||||
for (auto x : tmp)
|
||||
this->push_back(TruncPrTupleWithRange<T>(x, S, size));
|
||||
}
|
||||
|
||||
TruncPrTupleList()
|
||||
{
|
||||
}
|
||||
|
||||
bool have_big_gap()
|
||||
{
|
||||
for (auto info : *this)
|
||||
if (info.big_gap())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool have_small_gap()
|
||||
{
|
||||
for (auto info : *this)
|
||||
if (info.small_gap())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
This get_big_gap()
|
||||
{
|
||||
This res;
|
||||
for (auto info : *this)
|
||||
if (info.big_gap())
|
||||
res.push_back(info);
|
||||
return res;
|
||||
}
|
||||
|
||||
This get_small_gap()
|
||||
{
|
||||
This res;
|
||||
for (auto info : *this)
|
||||
if (info.small_gap())
|
||||
res.push_back(info);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_TRUNCPRTUPLE_H_ */
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
X(PICKS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1] + r[2]], \
|
||||
*dest++ = *op1; op1 += int(n)) \
|
||||
X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
auto op2 = &Procp.get_C()[r[2]]; mulm_check<sint>(), \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(MULC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
@@ -149,7 +149,7 @@
|
||||
auto op2 = &Proc2.get_S()[r[2]], \
|
||||
*dest++ = sgf2n::constant(*op1++, Proc.P.my_num(), Proc2.MC.get_alphai()) - *op2++) \
|
||||
X(GMULM, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \
|
||||
auto op2 = &Proc2.get_C()[r[2]], \
|
||||
auto op2 = &Proc2.get_C()[r[2]]; mulm_check<sgf2n>(), \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(GMULSI, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \
|
||||
typename sgf2n::clear op2 = int(n), \
|
||||
@@ -384,6 +384,16 @@
|
||||
X(ACTIVE, throw not_implemented(),) \
|
||||
X(FIXINPUT, throw not_implemented(),) \
|
||||
X(CONCATS, throw not_implemented(),) \
|
||||
X(ZIPS, throw not_implemented(),) \
|
||||
X(GMATMULS, throw not_implemented(),) \
|
||||
X(GMATMULSM, throw not_implemented(),) \
|
||||
X(PRINTREGPLAINS, throw not_implemented(),) \
|
||||
X(GPRINTREGPLAINS, throw not_implemented(),) \
|
||||
X(CALL_TAPE, throw not_implemented(),) \
|
||||
X(CMDLINEARG, throw not_implemented(),) \
|
||||
X(INITCLIENTCONNECTION, throw not_implemented(),) \
|
||||
X(GWRITEFILESHARE, throw not_implemented(),) \
|
||||
X(GREADFILESHARE, throw not_implemented(),) \
|
||||
|
||||
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
|
||||
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS
|
||||
|
||||
Reference in New Issue
Block a user