mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Multinode computation.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include "Instruction.hpp"
|
||||
#include "Protocols/ShuffleSacrifice.hpp"
|
||||
|
||||
#include <iostream>
|
||||
@@ -38,12 +39,27 @@ bool BaseMachine::has_program()
|
||||
}
|
||||
|
||||
int BaseMachine::edabit_bucket_size(int n_bits)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
int BaseMachine::triple_bucket_size(DataFieldType type)
|
||||
{
|
||||
size_t usage = 0;
|
||||
if (has_program())
|
||||
usage = s().progs[0].get_offline_data_used().files[type][DATA_TRIPLE];
|
||||
return bucket_size(usage);
|
||||
}
|
||||
|
||||
int BaseMachine::bucket_size(size_t usage)
|
||||
{
|
||||
int res = OnlineOptions::singleton.bucket_size;
|
||||
|
||||
if (has_program())
|
||||
if (usage)
|
||||
{
|
||||
auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits);
|
||||
for (int B = res; B <= 5; B++)
|
||||
if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9)
|
||||
break;
|
||||
@@ -91,7 +107,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
string threadname;
|
||||
for (int i=0; i<nprogs; i++)
|
||||
{ inpf >> threadname;
|
||||
size_t split = threadname.find(":");
|
||||
size_t split = threadname.find_last_of(":");
|
||||
long expected = -1;
|
||||
if (split != string::npos)
|
||||
{
|
||||
@@ -125,6 +141,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
getline(inpf, compiler);
|
||||
getline(inpf, domain);
|
||||
getline(inpf, relevant_opts);
|
||||
getline(inpf, security);
|
||||
inpf.close();
|
||||
}
|
||||
|
||||
@@ -184,17 +201,19 @@ string BaseMachine::memory_filename(const string& type_short, int my_number)
|
||||
|
||||
string BaseMachine::get_domain(string progname)
|
||||
{
|
||||
if (singleton)
|
||||
{
|
||||
assert(s().progname == progname);
|
||||
return s().domain;
|
||||
}
|
||||
return get_basics(progname).domain;
|
||||
}
|
||||
|
||||
assert(not singleton);
|
||||
BaseMachine BaseMachine::get_basics(string progname)
|
||||
{
|
||||
if (singleton and s().progname == progname)
|
||||
return s();
|
||||
|
||||
auto backup = singleton;
|
||||
BaseMachine machine;
|
||||
singleton = 0;
|
||||
singleton = backup;
|
||||
machine.load_schedule(progname, false);
|
||||
return machine.domain;
|
||||
return machine;
|
||||
}
|
||||
|
||||
int BaseMachine::ring_size_from_schedule(string progname)
|
||||
@@ -226,6 +245,15 @@ bigint BaseMachine::prime_from_schedule(string progname)
|
||||
return 0;
|
||||
}
|
||||
|
||||
int BaseMachine::security_from_schedule(string progname)
|
||||
{
|
||||
string sec = get_basics(progname).security;
|
||||
if (sec.substr(0, 4).compare("sec:") == 0)
|
||||
return stoi(sec.substr(4));
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
NamedCommStats BaseMachine::total_comm()
|
||||
{
|
||||
NamedCommStats res;
|
||||
|
||||
@@ -32,10 +32,13 @@ protected:
|
||||
string compiler;
|
||||
string domain;
|
||||
string relevant_opts;
|
||||
string security;
|
||||
|
||||
virtual size_t load_program(const string& threadname,
|
||||
const string& filename);
|
||||
|
||||
static BaseMachine get_basics(string progname);
|
||||
|
||||
public:
|
||||
static thread_local int thread_num;
|
||||
|
||||
@@ -58,12 +61,15 @@ public:
|
||||
static int ring_size_from_schedule(string progname);
|
||||
static int prime_length_from_schedule(string progname);
|
||||
static bigint prime_from_schedule(string progname);
|
||||
static int security_from_schedule(string progname);
|
||||
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
template<class T>
|
||||
static int edabit_batch_size(int n_bits, int buffer_size = 0);
|
||||
static int edabit_bucket_size(int n_bits);
|
||||
static int triple_bucket_size(DataFieldType type);
|
||||
static int bucket_size(size_t usage);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "Processor/ExternalClients.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
#include "Networking/ServerSocket.h"
|
||||
#include "Networking/ssl_sockets.h"
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <thread>
|
||||
@@ -25,6 +27,8 @@ ExternalClients::~ExternalClients()
|
||||
}
|
||||
if (ctx)
|
||||
delete ctx;
|
||||
for (auto it = peer_ctxs.begin(); it != peer_ctxs.end(); it++)
|
||||
delete it->second;
|
||||
}
|
||||
|
||||
void ExternalClients::start_listening(int portnum_base)
|
||||
@@ -32,8 +36,9 @@ void ExternalClients::start_listening(int portnum_base)
|
||||
ScopeLock _(lock);
|
||||
client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num());
|
||||
client_connection_servers[portnum_base]->init();
|
||||
cerr << "Start listening on thread " << this_thread::get_id() << endl;
|
||||
cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num())
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Party " << get_party_num() << " is listening on port "
|
||||
<< (portnum_base + get_party_num())
|
||||
<< " for external client connections." << endl;
|
||||
}
|
||||
|
||||
@@ -46,7 +51,6 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
cerr << "Thread " << this_thread::get_id() << " didn't find server." << endl;
|
||||
throw runtime_error("No connection on port " + to_string(portnum_base));
|
||||
}
|
||||
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
|
||||
int client_id, socket;
|
||||
string client;
|
||||
socket = client_connection_servers[portnum_base]->get_connection_socket(
|
||||
@@ -57,10 +61,38 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket,
|
||||
"C" + to_string(client_id), "P" + to_string(get_party_num()), false);
|
||||
client_ports[client_id] = portnum_base;
|
||||
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "Party " << get_party_num()
|
||||
<< " received external client connection from client id: " << dec
|
||||
<< client_id << endl;
|
||||
return client_id;
|
||||
}
|
||||
|
||||
int ExternalClients::init_client_connection(const string& host, int portnum,
|
||||
int my_client_id)
|
||||
{
|
||||
ScopeLock _(lock);
|
||||
int plain_socket;
|
||||
set_up_client_socket(plain_socket, host.c_str(), portnum);
|
||||
octetStream(to_string(my_client_id)).Send(plain_socket);
|
||||
string my_client_name = "C" + to_string(my_client_id);
|
||||
if (peer_ctxs.find(my_client_id) == peer_ctxs.end())
|
||||
peer_ctxs[my_client_id] = new client_ctx(my_client_name);
|
||||
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
|
||||
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
|
||||
true);
|
||||
if (party_num == 0)
|
||||
{
|
||||
octetStream specification;
|
||||
specification.Receive(socket);
|
||||
}
|
||||
int id = -1;
|
||||
if (not external_client_sockets.empty())
|
||||
id = min(id, external_client_sockets.begin()->first);
|
||||
external_client_sockets[id] = socket;
|
||||
return id;
|
||||
}
|
||||
|
||||
void ExternalClients::close_connection(int client_id)
|
||||
{
|
||||
ScopeLock _(lock);
|
||||
|
||||
@@ -32,6 +32,7 @@ class ExternalClients
|
||||
|
||||
ssl_service io_service;
|
||||
client_ctx* ctx;
|
||||
map<int, client_ctx*> peer_ctxs;
|
||||
|
||||
Lock lock;
|
||||
|
||||
@@ -43,6 +44,7 @@ class ExternalClients
|
||||
void start_listening(int portnum_base);
|
||||
|
||||
int get_client_connection(int portnum_base);
|
||||
int init_client_connection(const string& host, int portnum, int my_client_id);
|
||||
|
||||
void close_connection(int client_id);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "OnlineMachine.hpp"
|
||||
#include "OnlineOptions.hpp"
|
||||
|
||||
|
||||
template<template<class U> class T, class V>
|
||||
HonestMajorityFieldMachine<T, V>::HonestMajorityFieldMachine(int argc,
|
||||
const char **argv)
|
||||
@@ -34,6 +33,7 @@ template<template<class U> class T, template<class U> class V, class W, class X>
|
||||
FieldMachine<T, V, W, X>::FieldMachine(int argc, const char** argv,
|
||||
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
|
||||
{
|
||||
assert(nplayers or T<gfpvar>::variable_players);
|
||||
W machine(argc, argv, opt, online_opts, X(), nplayers);
|
||||
int n_limbs = online_opts.prime_limbs();
|
||||
switch (n_limbs)
|
||||
|
||||
@@ -10,11 +10,13 @@
|
||||
#include "Math/gf2n.h"
|
||||
#include "GC/instructions.h"
|
||||
|
||||
#include "Memory.hpp"
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
vector<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
{
|
||||
auto& C2 = registers;
|
||||
auto& M2C = memory;
|
||||
@@ -123,6 +125,6 @@ ostream& operator<<(ostream& s, const Instruction& instr)
|
||||
}
|
||||
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
|
||||
vector<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
|
||||
vector<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
|
||||
@@ -72,6 +72,7 @@ enum
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -153,7 +154,7 @@ enum
|
||||
LISTEN = 0x6c,
|
||||
ACCEPTCLIENTCONNECTION = 0x6d,
|
||||
CLOSECLIENTCONNECTION = 0x6e,
|
||||
READCLIENTPUBLICKEY = 0x6f,
|
||||
INITCLIENTCONNECTION = 0x6f,
|
||||
// Bitwise logic
|
||||
ANDC = 0x70,
|
||||
XORC = 0x71,
|
||||
@@ -197,6 +198,7 @@ enum
|
||||
PRINTREG = 0XB1,
|
||||
RAND = 0xB2,
|
||||
PRINTREGPLAIN = 0xB3,
|
||||
PRINTREGPLAINS = 0xEA,
|
||||
PRINTCHR = 0xB4,
|
||||
PRINTSTR = 0xB5,
|
||||
PUBINPUT = 0xB6,
|
||||
@@ -345,6 +347,7 @@ protected:
|
||||
int r[4]; // Fixed parameter registers
|
||||
size_t n; // Possible immediate value
|
||||
vector<int> start; // Values for a start/stop open
|
||||
string str;
|
||||
|
||||
public:
|
||||
virtual ~BaseInstruction() {};
|
||||
@@ -387,7 +390,7 @@ public:
|
||||
void execute(Processor<sint, sgf2n>& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
void execute_clear_gf2n(vector<cgf2n>& registers, vector<cgf2n>& memory,
|
||||
void execute_clear_gf2n(vector<cgf2n>& registers, MemoryPart<cgf2n>& memory,
|
||||
ArithmeticProcessor& Proc) const;
|
||||
|
||||
template<class cgf2n>
|
||||
|
||||
@@ -105,7 +105,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case STMCBI:
|
||||
case MOVC:
|
||||
case MOVS:
|
||||
case MOVSB:
|
||||
case MOVINT:
|
||||
case LDMINTI:
|
||||
case STMINTI:
|
||||
@@ -131,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case SHUFFLE:
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
case PREFIXSUMS:
|
||||
case CMDLINEARG:
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
// instructions with 1 register operand
|
||||
@@ -139,6 +139,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case RANDOMFULLS:
|
||||
case PRINTREGPLAIN:
|
||||
case PRINTREGPLAINB:
|
||||
case PRINTREGPLAINS:
|
||||
case LDTN:
|
||||
case LDARG:
|
||||
case STARG:
|
||||
@@ -316,13 +317,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
case CONV2DS:
|
||||
case MATMULS:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case MATMULS:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(3, start, s);
|
||||
break;
|
||||
case MATMULSM:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(9, start, s);
|
||||
@@ -358,7 +356,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITCLIENTCONNECTION:
|
||||
get_ints(r, s, 3);
|
||||
get_string(str, s);
|
||||
break;
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
throw runtime_error("VM-controlled encryption not supported any more");
|
||||
@@ -459,6 +460,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case CONVCBIT2S:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
case MOVSB:
|
||||
n = get_int(s);
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
@@ -566,7 +568,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case MOVINT:
|
||||
case READSOCKETINT:
|
||||
case WRITESOCKETINT:
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITCLIENTCONNECTION:
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
case LDARG:
|
||||
@@ -584,6 +586,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case INTOUTPUT:
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
case GENSECSHUFFLE:
|
||||
case CMDLINEARG:
|
||||
return INT;
|
||||
case PREP:
|
||||
case GPREP:
|
||||
@@ -723,6 +726,15 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return res;
|
||||
}
|
||||
case MATMULS:
|
||||
{
|
||||
int res = 0;
|
||||
for (auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
int tmp = *it + *(it + 3) * *(it + 5);
|
||||
res = max(res, tmp);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case MATMULSM:
|
||||
return r[0] + start[0] * start[2];
|
||||
case CONV2DS:
|
||||
@@ -817,7 +829,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
while (it < start.end())
|
||||
{
|
||||
int n = *it - n_prefix;
|
||||
int size = DIV_CEIL(*(it + 1), 64);
|
||||
size = max((long long) size, DIV_CEIL(*(it + 1), 64));
|
||||
it += n_prefix;
|
||||
assert(it + n <= start.end());
|
||||
for (int i = 0; i < n; i++)
|
||||
@@ -922,16 +934,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n));
|
||||
n++;
|
||||
break;
|
||||
case LDMCI:
|
||||
Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1])));
|
||||
break;
|
||||
case STMC:
|
||||
Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0]));
|
||||
n++;
|
||||
break;
|
||||
case STMCI:
|
||||
Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0]));
|
||||
break;
|
||||
case MOVC:
|
||||
Proc.write_Cp(r[0],Proc.read_Cp(r[1]));
|
||||
break;
|
||||
@@ -1089,10 +1095,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.POpen(*this);
|
||||
return;
|
||||
case MULS:
|
||||
Proc.Procp.muls(start, size);
|
||||
Proc.Procp.muls(start);
|
||||
return;
|
||||
case GMULS:
|
||||
Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size);
|
||||
Proc.Proc2.muls(start);
|
||||
return;
|
||||
case MULRS:
|
||||
Proc.Procp.mulrs(start);
|
||||
@@ -1107,7 +1113,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.dotprods(start, size);
|
||||
return;
|
||||
case MATMULS:
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]);
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
|
||||
return;
|
||||
case MATMULSM:
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
|
||||
@@ -1126,13 +1132,15 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Proc2.secure_shuffle(*this);
|
||||
return;
|
||||
case GENSECSHUFFLE:
|
||||
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this));
|
||||
Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this,
|
||||
Proc.machine.shuffle_store));
|
||||
return;
|
||||
case APPLYSHUFFLE:
|
||||
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)));
|
||||
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)),
|
||||
Proc.machine.shuffle_store);
|
||||
return;
|
||||
case DELSHUFFLE:
|
||||
Proc.Procp.delete_shuffle(Proc.read_Ci(r[0]));
|
||||
Proc.machine.shuffle_store.del(Proc.read_Ci(r[0]));
|
||||
return;
|
||||
case INVPERM:
|
||||
Proc.Procp.inverse_permutation(*this);
|
||||
@@ -1170,6 +1178,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case PRINTREGPLAIN:
|
||||
print(Proc.out, &Proc.read_Cp(r[0]));
|
||||
return;
|
||||
case PRINTREGPLAINS:
|
||||
Proc.out << Proc.read_Sp(r[0]);
|
||||
return;
|
||||
case CONDPRINTPLAIN:
|
||||
if (not Proc.read_Cp(r[0]).is_zero())
|
||||
{
|
||||
@@ -1237,6 +1248,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case PLAYERID:
|
||||
Proc.write_Ci(r[0], Proc.P.my_num());
|
||||
break;
|
||||
case CMDLINEARG:
|
||||
{
|
||||
size_t idx = Proc.read_Ci(r[1]);
|
||||
auto& args = OnlineOptions::singleton.args;
|
||||
if (idx < args.size())
|
||||
Proc.write_Ci(r[0], args[idx]);
|
||||
else
|
||||
{
|
||||
cerr << idx << "-th command-line argument not given" << endl;
|
||||
exit(1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// ***
|
||||
// TODO: read/write shared GF(2^n) data instructions
|
||||
// ***
|
||||
@@ -1255,11 +1279,17 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
octetStream os;
|
||||
os.store(int(sint::open_type::type_char()));
|
||||
sint::specification(os);
|
||||
sint::clear::specification(os);
|
||||
os.Send(Proc.external_clients.get_socket(client_handle));
|
||||
}
|
||||
Proc.write_Ci(r[0], client_handle);
|
||||
break;
|
||||
}
|
||||
case INITCLIENTCONNECTION:
|
||||
Proc.write_Ci(r[0],
|
||||
Proc.external_clients.init_client_connection(str,
|
||||
Proc.read_Ci(r[1]), Proc.read_Ci(r[2])));
|
||||
break;
|
||||
case CLOSECLIENTCONNECTION:
|
||||
Proc.external_clients.close_connection(Proc.read_Ci(r[0]));
|
||||
break;
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/ExecutionStats.h"
|
||||
|
||||
#include "Protocols/SecureShuffle.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <atomic>
|
||||
@@ -70,6 +72,8 @@ class Machine : public BaseMachine
|
||||
|
||||
ExternalClients external_clients;
|
||||
|
||||
typename sint::Protocol::Shuffler::store_type shuffle_store;
|
||||
|
||||
static void init_binary_domains(int security_parameter, int lg2);
|
||||
|
||||
Machine(Names& playerNames, bool use_encryption = true,
|
||||
|
||||
@@ -60,11 +60,21 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
|
||||
{
|
||||
OnlineOptions::singleton = opts;
|
||||
|
||||
if (N.num_players() == 1 and sint::is_real)
|
||||
int min_players = 3 - sint::dishonest_majority;
|
||||
if (sint::is_real)
|
||||
{
|
||||
cerr << "Need more than one player to run a protocol." << endl;
|
||||
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
|
||||
exit(1);
|
||||
if (N.num_players() == 1)
|
||||
{
|
||||
cerr << "Need more than one player to run a protocol." << endl;
|
||||
cerr << "Use 'emulate.x' for just running the virtual machine" << endl;
|
||||
exit(1);
|
||||
}
|
||||
else if (N.num_players() < min_players)
|
||||
{
|
||||
cerr << "Need at least " << min_players << " players for this protocol."
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Set the prime modulus from command line or program if applicable
|
||||
@@ -480,8 +490,10 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Communication details "
|
||||
"(rounds in parallel threads counted double):" << endl;
|
||||
cerr << "Communication details";
|
||||
if (multithread)
|
||||
cerr << " (rounds in parallel threads counted double)";
|
||||
cerr << ":" << endl;
|
||||
comm_stats.print();
|
||||
cerr << "CPU time = " << proc_timer.elapsed();
|
||||
if (multithread)
|
||||
@@ -547,6 +559,14 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
suggest_optimizations();
|
||||
|
||||
if (N.num_players() > 4)
|
||||
{
|
||||
string alt = sint::alt();
|
||||
if (alt.size())
|
||||
cerr << "This protocol doesn't scale well with the number of parties, "
|
||||
<< "have you considered using " << alt << " instead?" << endl;
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "End of prog" << endl;
|
||||
#endif
|
||||
|
||||
@@ -14,34 +14,90 @@ template<class T> istream& operator>>(istream& s,Memory<T>& M);
|
||||
|
||||
#include "Processor/Program.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
#include "Tools/DiskVector.h"
|
||||
|
||||
template<class T>
|
||||
class MemoryPart : public CheckVector<T>
|
||||
class MemoryPart
|
||||
{
|
||||
public:
|
||||
template<class U>
|
||||
static void check_index(const vector<U>& M, size_t i)
|
||||
virtual ~MemoryPart() {}
|
||||
|
||||
virtual size_t size() const = 0;
|
||||
virtual void resize(size_t) = 0;
|
||||
|
||||
virtual T* data() = 0;
|
||||
virtual const T* data() const = 0;
|
||||
|
||||
void check_index(size_t i) const
|
||||
{
|
||||
(void) M, (void) i;
|
||||
(void) i;
|
||||
#ifndef NO_CHECK_INDEX
|
||||
if (i >= M.size())
|
||||
throw overflow(U::type_string() + " memory", i, M.size());
|
||||
if (i >= this->size())
|
||||
throw overflow(T::type_string() + " memory", i, this->size());
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual T& operator[](size_t i) = 0;
|
||||
virtual const T& operator[](size_t i) const = 0;
|
||||
|
||||
virtual T& at(size_t i) = 0;
|
||||
virtual const T& at(size_t i) const = 0;
|
||||
|
||||
template<class U>
|
||||
void indirect_read(const Instruction& inst, vector<T>& regs,
|
||||
const U& indices);
|
||||
template<class U>
|
||||
void indirect_write(const Instruction& inst, vector<T>& regs,
|
||||
const U& indices);
|
||||
|
||||
void minimum_size(size_t size);
|
||||
};
|
||||
|
||||
template<class T, template<class> class V>
|
||||
class MemoryPartImpl : public MemoryPart<T>, public V<T>
|
||||
{
|
||||
public:
|
||||
size_t size() const
|
||||
{
|
||||
return V<T>::size();
|
||||
}
|
||||
|
||||
void resize(size_t size)
|
||||
{
|
||||
V<T>::resize(size);
|
||||
}
|
||||
|
||||
T* data()
|
||||
{
|
||||
return V<T>::data();
|
||||
}
|
||||
|
||||
const T* data() const
|
||||
{
|
||||
return V<T>::data();
|
||||
}
|
||||
|
||||
T& operator[](size_t i)
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
this->check_index(i);
|
||||
return V<T>::operator[](i);
|
||||
}
|
||||
|
||||
const T& operator[](size_t i) const
|
||||
{
|
||||
check_index(*this, i);
|
||||
return CheckVector<T>::operator[](i);
|
||||
this->check_index(i);
|
||||
return V<T>::operator[](i);
|
||||
}
|
||||
|
||||
void minimum_size(size_t size);
|
||||
T& at(size_t i)
|
||||
{
|
||||
return V<T>::at(i);
|
||||
}
|
||||
|
||||
const T& at(size_t i) const
|
||||
{
|
||||
return V<T>::at(i);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -49,8 +105,11 @@ class Memory
|
||||
{
|
||||
public:
|
||||
|
||||
MemoryPart<T> MS;
|
||||
MemoryPart<typename T::clear> MC;
|
||||
MemoryPart<T>& MS;
|
||||
MemoryPartImpl<typename T::clear, CheckVector> MC;
|
||||
|
||||
Memory();
|
||||
~Memory();
|
||||
|
||||
void resize_s(size_t sz)
|
||||
{ MS.resize(sz); }
|
||||
|
||||
@@ -3,6 +3,54 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_read(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto dest = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(dest + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
const T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
throw overflow(T::type_string() + " memory read", it->get(), size);
|
||||
#endif
|
||||
*dest++ = data[it->get()];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void MemoryPart<T>::indirect_write(const Instruction& inst,
|
||||
vector<T>& regs, const U& indices)
|
||||
{
|
||||
size_t n = inst.get_size();
|
||||
auto source = regs.begin() + inst.get_r(0);
|
||||
auto start = indices.begin() + inst.get_r(1);
|
||||
#ifdef CHECK_SIZE
|
||||
assert(start + n <= indices.end());
|
||||
assert(source + n <= regs.end());
|
||||
#endif
|
||||
long size = this->size();
|
||||
T* data = this->data();
|
||||
for (auto it = start; it < start + n; it++)
|
||||
{
|
||||
#ifndef NO_CHECK_SIZE
|
||||
if (*it >= size)
|
||||
throw overflow(T::type_string() + " memory write", it->get(), size);
|
||||
#endif
|
||||
data[it->get()] = *source++;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
|
||||
const Program &program, const string& threadname)
|
||||
@@ -29,6 +77,21 @@ void MemoryPart<T>::minimum_size(size_t size)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Memory<T>::Memory() :
|
||||
MS(
|
||||
*(OnlineOptions::singleton.disk_memory.size() ?
|
||||
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, DiskVector>) :
|
||||
static_cast<MemoryPart<T>*>(new MemoryPartImpl<T, CheckVector>)))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Memory<T>::~Memory()
|
||||
{
|
||||
delete &MS;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ostream& operator<<(ostream& s,const Memory<T>& M)
|
||||
{
|
||||
|
||||
@@ -71,18 +71,6 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op
|
||||
"--ip-file-name" // Flag token.
|
||||
);
|
||||
|
||||
if (nplayers == 0)
|
||||
opt.add(
|
||||
"2", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of players (default: 2). "
|
||||
"Ignored if external server is used.", // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
|
||||
@@ -22,12 +22,13 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
interactive = false;
|
||||
lgp = gfp0::MAX_N_BITS;
|
||||
live_prep = true;
|
||||
batch_size = 10000;
|
||||
batch_size = 1000;
|
||||
memtype = "empty";
|
||||
bits_from_squares = false;
|
||||
direct = false;
|
||||
bucket_size = 4;
|
||||
security_parameter = DEFAULT_SECURITY;
|
||||
use_security_parameter = false;
|
||||
cmd_private_input_file = "Player-Data/Input";
|
||||
cmd_private_output_file = "";
|
||||
file_prep_per_thread = false;
|
||||
@@ -46,6 +47,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, bool security) :
|
||||
OnlineOptions()
|
||||
{
|
||||
use_security_parameter = security;
|
||||
|
||||
opt.syntax = std::string(argv[0]) + " [OPTIONS] [<playerno>] <progname>";
|
||||
|
||||
opt.add(
|
||||
@@ -116,7 +119,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Security parameter (default: " + to_string(security_parameter)
|
||||
("Statistical ecurity parameter (default: " + to_string(security_parameter)
|
||||
+ ")").c_str(), // Help description.
|
||||
"-S", // Flag token.
|
||||
"--security" // Flag token.
|
||||
@@ -138,7 +141,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
if (security)
|
||||
{
|
||||
opt.get("-S")->getInt(security_parameter);
|
||||
cerr << "Using security parameter " << security_parameter << endl;
|
||||
if (security_parameter <= 0)
|
||||
{
|
||||
cerr << "Invalid security parameter: " << security_parameter << endl;
|
||||
@@ -280,7 +282,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
}
|
||||
|
||||
void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv)
|
||||
const char** argv, bool networking)
|
||||
{
|
||||
opt.resetArgs();
|
||||
opt.parse(argc, argv);
|
||||
@@ -292,17 +294,21 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
vector<string> badOptions;
|
||||
unsigned int i;
|
||||
|
||||
opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for documentation on the networking setup.\n";
|
||||
if (networking)
|
||||
opt.footer += "See also "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for documentation on the networking setup.\n\n";
|
||||
|
||||
if (allArgs.size() != 3u - opt.isSet("-p"))
|
||||
size_t name_index = 1 + networking - opt.isSet("-p");
|
||||
|
||||
if (allArgs.size() < name_index + 1)
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl;
|
||||
cerr << "Arguments given were:\n";
|
||||
for (unsigned int j = 1; j < allArgs.size(); j++)
|
||||
cout << "'" << *allArgs[j] << "'" << endl;
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(1);
|
||||
}
|
||||
else
|
||||
@@ -311,25 +317,25 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
opt.get("-p")->getInt(playerno);
|
||||
else
|
||||
sscanf((*allArgs[1]).c_str(), "%d", &playerno);
|
||||
progname = *allArgs[2 - opt.isSet("-p")];
|
||||
progname = *allArgs.at(name_index);
|
||||
}
|
||||
|
||||
if (!opt.gotRequired(badOptions))
|
||||
{
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (!opt.gotExpected(badOptions))
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
for (i = 0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Got unexpected number of arguments for option "
|
||||
<< badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -347,6 +353,22 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
prime = schedule_prime;
|
||||
}
|
||||
|
||||
for (size_t i = name_index + 1; i < allArgs.size(); i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
args.push_back(stol(*allArgs[i]));
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cerr << usage;
|
||||
cerr << "Additional argument has to be integer: " << *allArgs[i]
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// ignore program if length explicitly set from command line
|
||||
if (opt.get("-lgp") and not opt.isSet("-lgp"))
|
||||
{
|
||||
@@ -367,7 +389,29 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
if (o)
|
||||
o->getInt(max_broadcast);
|
||||
|
||||
o = opt.get("--disk-memory");
|
||||
if (o)
|
||||
o->getString(disk_memory);
|
||||
|
||||
receive_threads = opt.isSet("--threads");
|
||||
|
||||
if (use_security_parameter)
|
||||
{
|
||||
int program_sec = BaseMachine::security_from_schedule(progname);
|
||||
|
||||
if (program_sec > 0)
|
||||
{
|
||||
if (not opt.isSet("-S"))
|
||||
security_parameter = program_sec;
|
||||
if (program_sec < security_parameter)
|
||||
{
|
||||
cerr << "Security parameter used in compilation is insufficient" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "Using statistical security parameter " << security_parameter << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt)
|
||||
|
||||
@@ -27,6 +27,7 @@ public:
|
||||
bool direct;
|
||||
int bucket_size;
|
||||
int security_parameter;
|
||||
bool use_security_parameter;
|
||||
std::string cmd_private_input_file;
|
||||
std::string cmd_private_output_file;
|
||||
bool verbose;
|
||||
@@ -34,6 +35,8 @@ public:
|
||||
int trunc_error;
|
||||
int opening_sum, max_broadcast;
|
||||
bool receive_threads;
|
||||
std::string disk_memory;
|
||||
vector<long> args;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
@@ -48,7 +51,8 @@ public:
|
||||
OnlineOptions(T);
|
||||
~OnlineOptions() {}
|
||||
|
||||
void finalize(ez::ezOptionParser& opt, int argc, const char** argv);
|
||||
void finalize(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
bool networking = true);
|
||||
|
||||
void set_trunc_error(ez::ezOptionParser& opt);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
template<class T>
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, T, bool default_live_prep) :
|
||||
OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0,
|
||||
OnlineOptions(opt, argc, argv, OnlineOptions(T()).batch_size,
|
||||
default_live_prep, T::clear::prime_field)
|
||||
{
|
||||
if (T::has_trunc_pr)
|
||||
@@ -56,13 +56,39 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"--max-broadcast" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
if (not T::clear::binary)
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Use directory on disk for memory (container data structures) "
|
||||
"instead of RAM", // Help description.
|
||||
"-D", // Flag token.
|
||||
"--disk-memory" // Flag token.
|
||||
);
|
||||
|
||||
if (T::variable_players)
|
||||
opt.add(
|
||||
T::dishonest_majority ? "2" : "3", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Number of players (default: "
|
||||
+ (T::dishonest_majority ?
|
||||
to_string("2") : to_string("3")) + "). " +
|
||||
"Ignored if external server is used.").c_str(), // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OnlineOptions::OnlineOptions(T) : OnlineOptions()
|
||||
{
|
||||
if (T::dishonest_majority)
|
||||
batch_size = 1000;
|
||||
if (not T::dishonest_majority)
|
||||
batch_size = 10000;
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */
|
||||
|
||||
@@ -36,7 +36,7 @@ class SubProcessor
|
||||
|
||||
void resize(size_t size) { C.resize(size); S.resize(size); }
|
||||
|
||||
void matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
void matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b);
|
||||
void matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
typename vector<T>::iterator C);
|
||||
@@ -48,6 +48,8 @@ class SubProcessor
|
||||
|
||||
typedef typename T::bit_type::part_type BT;
|
||||
|
||||
typedef typename T::Protocol::Shuffler::store_type ShuffleStore;
|
||||
|
||||
public:
|
||||
ArithmeticProcessor* Proc;
|
||||
typename T::MAC_Check& MC;
|
||||
@@ -71,19 +73,19 @@ public:
|
||||
// Access to PO (via calls to POpen start/stop)
|
||||
void POpen(const Instruction& inst);
|
||||
|
||||
void muls(const vector<int>& reg, int size);
|
||||
void muls(const vector<int>& reg);
|
||||
void mulrs(const vector<int>& reg);
|
||||
void dotprods(const vector<int>& reg, int size);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, size_t a,
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction);
|
||||
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
void secure_shuffle(const Instruction& instruction);
|
||||
size_t generate_secure_shuffle(const Instruction& instruction);
|
||||
void apply_shuffle(const Instruction& instruction, int handle);
|
||||
void delete_shuffle(int handle);
|
||||
size_t generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store);
|
||||
void apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store);
|
||||
void inverse_permutation(const Instruction& instruction);
|
||||
|
||||
void input_personal(const vector<int>& args);
|
||||
@@ -116,7 +118,7 @@ public:
|
||||
class ArithmeticProcessor : public ProcessorBase
|
||||
{
|
||||
protected:
|
||||
CheckVector<long> Ci;
|
||||
CheckVector<Integer> Ci;
|
||||
|
||||
ofstream public_output;
|
||||
ofstream binary_output;
|
||||
@@ -162,13 +164,13 @@ public:
|
||||
return thread_num;
|
||||
}
|
||||
|
||||
const long& read_Ci(size_t i) const
|
||||
{ return Ci[i]; }
|
||||
long& get_Ci_ref(size_t i)
|
||||
long read_Ci(size_t i) const
|
||||
{ return Ci[i].get(); }
|
||||
Integer& get_Ci_ref(size_t i)
|
||||
{ return Ci[i]; }
|
||||
void write_Ci(size_t i, const long& x)
|
||||
{ Ci[i]=x; }
|
||||
CheckVector<long>& get_Ci()
|
||||
CheckVector<Integer>& get_Ci()
|
||||
{ return Ci; }
|
||||
|
||||
virtual ofstream& get_public_output()
|
||||
|
||||
@@ -379,9 +379,20 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id,
|
||||
client_timer.stop();
|
||||
client_stats.add(socket_stream.get_length());
|
||||
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
|
||||
int j, i;
|
||||
try
|
||||
{
|
||||
for (j = 0; j < size; j++)
|
||||
for (i = 0; i < m; i++)
|
||||
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
throw insufficient_shares(m * size, j * m + i, e);
|
||||
}
|
||||
|
||||
if (socket_stream.left())
|
||||
throw runtime_error("unexpected share data");
|
||||
}
|
||||
|
||||
|
||||
@@ -468,28 +479,29 @@ void SubProcessor<T>::POpen(const Instruction& inst)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::muls(const vector<int>& reg, int size)
|
||||
void SubProcessor<T>::muls(const vector<int>& reg)
|
||||
{
|
||||
assert(reg.size() % 3 == 0);
|
||||
int n = reg.size() / 3;
|
||||
assert(reg.size() % 4 == 0);
|
||||
int n = reg.size() / 4;
|
||||
|
||||
SubProcessor<T>& proc = *this;
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
auto& x = proc.S[reg[3 * i + 1] + j];
|
||||
auto& y = proc.S[reg[3 * i + 2] + j];
|
||||
auto& x = proc.S[reg[4 * i + 2] + j];
|
||||
auto& y = proc.S[reg[4 * i + 3] + j];
|
||||
protocol.prepare_mul(x, y);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (int i = 0; i < n; i++)
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
for (int j = 0; j < reg[4 * i]; j++)
|
||||
{
|
||||
proc.S[reg[3 * i] + j] = protocol.finalize_mul();
|
||||
proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul();
|
||||
}
|
||||
|
||||
protocol.counter += n * size;
|
||||
protocol.counter += n * reg[4 * i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -553,33 +565,46 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
const Instruction& instruction)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
auto A = source.begin() + a;
|
||||
auto B = source.begin() + b;
|
||||
auto C = S.begin() + (instruction.get_r(0));
|
||||
assert(A + dim[0] * dim[1] <= source.end());
|
||||
assert(B + dim[1] * dim[2] <= source.end());
|
||||
assert(C + dim[0] * dim[2] <= S.end());
|
||||
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
protocol.prepare_dotprod(*(A + i * dim[1] + k),
|
||||
*(B + k * dim[2] + j));
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
|
||||
auto& start = instruction.get_start();
|
||||
assert(start.size() % 6 == 0);
|
||||
|
||||
for(auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
auto dim = it + 3;
|
||||
auto A = source.begin() + *(it + 1);
|
||||
auto B = source.begin() + *(it + 2);
|
||||
assert(A + dim[0] * dim[1] <= source.end());
|
||||
assert(B + dim[1] * dim[2] <= source.end());
|
||||
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
protocol.prepare_dotprod(*(A + i * dim[1] + k),
|
||||
*(B + k * dim[2] + j));
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
}
|
||||
|
||||
protocol.exchange();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
|
||||
for(auto it = start.begin(); it < start.end(); it += 6)
|
||||
{
|
||||
auto C = S.begin() + *it;
|
||||
auto dim = it + 3;
|
||||
assert(C + dim[0] * dim[2] <= S.end());
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
@@ -592,7 +617,7 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
{
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i);
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i).get();
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
#ifdef DEBUG_MATMULSM
|
||||
@@ -628,16 +653,21 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b)
|
||||
{
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j);
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j).get();
|
||||
const T* base = source.data();
|
||||
size_t size = source.size();
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k);
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k);
|
||||
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
|
||||
source.at(b + ll * dim[8] + jj));
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k).get();
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k).get();
|
||||
auto aa = a + ii * dim[7] + kk;
|
||||
auto bb = b + ll * dim[8] + jj;
|
||||
assert(aa < size);
|
||||
assert(bb < size);
|
||||
protocol.prepare_dotprod(base[aa], base[bb]);
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
@@ -655,16 +685,22 @@ void SubProcessor<T>::matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
template<class T>
|
||||
void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
auto& args = instruction.get_start();
|
||||
vector<Conv2dTuple> tuples;
|
||||
for (size_t i = 0; i < args.size(); i += 15)
|
||||
tuples.push_back(Conv2dTuple(args, i));
|
||||
for (auto& tuple : tuples)
|
||||
tuple.pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (auto& tuple : tuples)
|
||||
tuple.post(S, protocol);
|
||||
size_t done = 0;
|
||||
while (done < tuples.size())
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
size_t i;
|
||||
for (i = done; i < tuples.size() and protocol.get_buffer_size() <
|
||||
OnlineOptions::singleton.batch_size; i++)
|
||||
tuples[i].pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (; done < i; done++)
|
||||
tuples[done].post(S, protocol);
|
||||
}
|
||||
}
|
||||
|
||||
inline
|
||||
@@ -766,25 +802,22 @@ void SubProcessor<T>::secure_shuffle(const Instruction& instruction)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction)
|
||||
size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
return shuffler.generate(instruction.get_n());
|
||||
return shuffler.generate(instruction.get_n(), shuffle_store);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle)
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
|
||||
instruction.get_start()[0], instruction.get_start()[1], handle,
|
||||
instruction.get_start()[0], instruction.get_start()[1],
|
||||
shuffle_store.get(handle),
|
||||
instruction.get_start()[4]);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::delete_shuffle(int handle)
|
||||
{
|
||||
shuffler.del(handle);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::inverse_permutation(const Instruction& instruction) {
|
||||
shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0],
|
||||
@@ -796,17 +829,26 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
|
||||
{
|
||||
input.reset_all(P);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
if (args[i + 1] == P.my_num())
|
||||
{
|
||||
if (args[i + 1] == P.my_num())
|
||||
input.add_mine(C[args[i + 3] + j]);
|
||||
else
|
||||
input.add_other(args[i + 1]);
|
||||
auto begin = C.begin() + args[i + 3];
|
||||
auto end = begin + args[i];
|
||||
assert(end <= C.end());
|
||||
for (auto it = begin; it < end; it++)
|
||||
input.add_mine(*it);
|
||||
}
|
||||
else
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
input.add_other(args[i + 1]);
|
||||
input.exchange();
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
S[args[i + 2] + j] = input.finalize(args[i + 1]);
|
||||
{
|
||||
auto begin = S.begin() + args[i + 2];
|
||||
auto end = begin + args[i];
|
||||
assert(end <= S.end());
|
||||
for (auto it = begin; it < end; it++)
|
||||
*it = input.finalize(args[i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -858,6 +900,16 @@ typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
|
||||
return inverses2m[m];
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
void fixinput_int(T& proc, const Instruction& instruction, U)
|
||||
{
|
||||
U* x = new U[instruction.get_size()];
|
||||
proc.binary_input.read((char*) x, sizeof(U) * instruction.get_size());
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
proc.write_Cp(instruction.get_r(0) + i, x[i]);
|
||||
delete[] x;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
{
|
||||
@@ -878,19 +930,24 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
throw runtime_error("unknown format for fixed-point input");
|
||||
}
|
||||
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
|
||||
if (binary_input.peek() == EOF)
|
||||
throw IO_Error("not enough inputs in " + binary_input_filename);
|
||||
|
||||
if (instruction.get_r(2) == 0)
|
||||
{
|
||||
if (binary_input.peek() == EOF)
|
||||
throw IO_Error("not enough inputs in " + binary_input_filename);
|
||||
double buf;
|
||||
if (instruction.get_r(2) == 0)
|
||||
{
|
||||
int64_t x;
|
||||
binary_input.read((char*) &x, sizeof(x));
|
||||
tmp = x;
|
||||
}
|
||||
if (instruction.get_r(1) == 1)
|
||||
fixinput_int(*this, instruction, int8_t());
|
||||
else
|
||||
fixinput_int(*this, instruction, int64_t());
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < instruction.get_size(); i++)
|
||||
{
|
||||
double buf;
|
||||
if (use_double)
|
||||
binary_input.read((char*) &buf, sizeof(double));
|
||||
else
|
||||
@@ -900,11 +957,12 @@ void Processor<sint, sgf2n>::fixinput(const Instruction& instruction)
|
||||
buf = x;
|
||||
}
|
||||
tmp = bigint::tmp = round(buf * exp2(instruction.get_r(1)));
|
||||
write_Cp(instruction.get_r(0) + i, tmp);
|
||||
}
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
write_Cp(instruction.get_r(0) + i, tmp);
|
||||
}
|
||||
|
||||
if (binary_input.fail())
|
||||
throw IO_Error("failure reading from " + binary_input_filename);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@ using namespace std;
|
||||
#include "Tools/ExecutionStats.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "OnlineOptions.h"
|
||||
#include "Math/Integer.h"
|
||||
|
||||
class ProcessorBase
|
||||
{
|
||||
// Stack
|
||||
stack<long> stacki;
|
||||
stack<Integer> stacki;
|
||||
|
||||
ifstream input_file;
|
||||
string input_filename;
|
||||
@@ -26,7 +27,7 @@ class ProcessorBase
|
||||
|
||||
protected:
|
||||
// Optional argument to tape
|
||||
int arg;
|
||||
Integer arg;
|
||||
|
||||
string get_parameterized_filename(int my_num, int thread_num,
|
||||
const string& prefix);
|
||||
@@ -38,15 +39,15 @@ public:
|
||||
|
||||
ProcessorBase();
|
||||
|
||||
void pushi(long x) { stacki.push(x); }
|
||||
void popi(long& x) { x = stacki.top(); stacki.pop(); }
|
||||
void pushi(Integer x) { stacki.push(x); }
|
||||
void popi(Integer& x) { x = stacki.top(); stacki.pop(); }
|
||||
|
||||
int get_arg() const
|
||||
Integer get_arg() const
|
||||
{
|
||||
return arg;
|
||||
}
|
||||
|
||||
void set_arg(int new_arg)
|
||||
void set_arg(Integer new_arg)
|
||||
{
|
||||
arg=new_arg;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ template<template<int L> class U, template<class T> class V, class W>
|
||||
RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
|
||||
ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers)
|
||||
{
|
||||
assert(nplayers or U<64>::variable_players);
|
||||
RingOptions opts(opt, argc, argv);
|
||||
W machine(argc, argv, opt, online_opts, gf2n(), nplayers);
|
||||
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);
|
||||
@@ -65,7 +66,7 @@ template<template<int K, int S> class U, template<class T> class V>
|
||||
HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecurity(
|
||||
int argc, const char** argv, ez::ezOptionParser& opt)
|
||||
{
|
||||
OnlineOptions online_opts(opt, argc, argv);
|
||||
OnlineOptions online_opts(opt, argc, argv, U<64, 40>());
|
||||
RingOptions opts(opt, argc, argv);
|
||||
HonestMajorityMachine machine(argc, argv, opt, online_opts);
|
||||
int R = opts.ring_size_from_opts_or_schedule(online_opts.progname);
|
||||
|
||||
@@ -18,10 +18,10 @@
|
||||
*dest++ = *source++) \
|
||||
X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = Proc.machine.Mp.read_S(*source++)) \
|
||||
X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
|
||||
Proc.machine.Mp.write_S(*dest++, *source++)) \
|
||||
X(LDMSI, Proc.machine.Mp.MS.indirect_read(instruction, Procp.get_S(), Proc.get_Ci()),) \
|
||||
X(STMSI, Proc.machine.Mp.MS.indirect_write(instruction, Procp.get_S(), Proc.get_Ci()),) \
|
||||
X(LDMCI, Proc.machine.Mp.MC.indirect_read(instruction, Procp.get_C(), Proc.get_Ci()),) \
|
||||
X(STMCI, Proc.machine.Mp.MC.indirect_write(instruction, Procp.get_C(), Proc.get_Ci()),) \
|
||||
X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
@@ -121,10 +121,8 @@
|
||||
*dest++ = *source++) \
|
||||
X(GSTMS, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.machine.M2.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(GLDMSI, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = Proc.machine.M2.read_S(*source++)) \
|
||||
X(GSTMSI, auto source = &Proc2.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
|
||||
Proc.machine.M2.write_S(*dest++, *source++)) \
|
||||
X(GLDMSI, Proc.machine.M2.MS.indirect_read(instruction, Proc2.get_S(), Proc.get_Ci()),) \
|
||||
X(GSTMSI, Proc.machine.M2.MS.indirect_write(instruction, Proc2.get_S(), Proc.get_Ci()),) \
|
||||
X(GMOVS, auto dest = &Proc2.get_S()[r[0]]; auto source = &Proc2.get_S()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(GADDS, auto dest = &Proc2.get_S()[r[0]]; auto op1 = &Proc2.get_S()[r[1]]; \
|
||||
@@ -171,10 +169,8 @@
|
||||
*dest++ = (*source).get(); source++) \
|
||||
X(STMINT, auto dest = &Mi[n]; auto source = &Proc.get_Ci()[r[0]], \
|
||||
*dest++ = *source++) \
|
||||
X(LDMINTI, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = Mi[*source].get(); source++) \
|
||||
X(STMINTI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &Ci[r[0]], \
|
||||
Mi[*dest] = *source++; dest++) \
|
||||
X(LDMINTI, Mi.indirect_read(*this, Proc.get_Ci(), Proc.get_Ci()),) \
|
||||
X(STMINTI, Mi.indirect_write(*this, Proc.get_Ci(), Proc.get_Ci()),) \
|
||||
X(MOVINT, auto dest = &Proc.get_Ci()[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(PUSHINT, Proc.pushi(Ci[r[0]]),) \
|
||||
@@ -213,7 +209,7 @@
|
||||
X(SHUFFLE, shuffle(Proc),) \
|
||||
X(BITDECINT, bitdecint(Proc),) \
|
||||
X(RAND, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], \
|
||||
*dest++ = Proc.shared_prng.get_uint() % (1 << *source++)) \
|
||||
*dest++ = Proc.shared_prng.get_uint() % (1 << (*source++).get())) \
|
||||
|
||||
#define CLEAR_GF2N_INSTRUCTIONS \
|
||||
X(GLDI, auto dest = &C2[r[0]]; cgf2n tmp = int(n), \
|
||||
@@ -222,10 +218,8 @@
|
||||
*dest++ = (*source).get(); source++) \
|
||||
X(GSTMC, auto dest = &M2C[n]; auto source = &C2[r[0]], \
|
||||
*dest++ = *source++) \
|
||||
X(GLDMCI, auto dest = &C2[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = M2C[*source++]) \
|
||||
X(GSTMCI, auto dest = &Proc.get_Ci()[r[1]]; auto source = &C2[r[0]], \
|
||||
M2C[*dest++] = *source++) \
|
||||
X(GLDMCI, M2C.indirect_read(*this, C2, Proc.get_Ci()),) \
|
||||
X(GSTMCI, M2C.indirect_write(*this, C2, Proc.get_Ci()),) \
|
||||
X(GMOVC, auto dest = &C2[r[0]]; auto source = &C2[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(GADDC, auto dest = &C2[r[0]]; auto op1 = &C2[r[1]]; \
|
||||
@@ -288,9 +282,7 @@
|
||||
#define REMAINING_INSTRUCTIONS \
|
||||
X(CONVMODP, throw not_implemented(),) \
|
||||
X(LDMC, throw not_implemented(),) \
|
||||
X(LDMCI, throw not_implemented(),) \
|
||||
X(STMC, throw not_implemented(),) \
|
||||
X(STMCI, throw not_implemented(),) \
|
||||
X(MOVC, throw not_implemented(),) \
|
||||
X(DIVC, throw not_implemented(),) \
|
||||
X(GDIVC, throw not_implemented(),) \
|
||||
@@ -390,6 +382,8 @@
|
||||
X(APPLYSHUFFLE, throw not_implemented(),) \
|
||||
X(DELSHUFFLE, throw not_implemented(),) \
|
||||
X(ACTIVE, throw not_implemented(),) \
|
||||
X(FIXINPUT, throw not_implemented(),) \
|
||||
X(CONCATS, throw not_implemented(),) \
|
||||
|
||||
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
|
||||
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS
|
||||
|
||||
Reference in New Issue
Block a user