mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 05:27:56 -05:00
Various improvements.
This commit is contained in:
@@ -42,7 +42,6 @@ CommonFakeParty::~CommonFakeParty()
|
||||
|
||||
CommonParty::~CommonParty()
|
||||
{
|
||||
cerr << "Total time: " << timer.elapsed() << endl;
|
||||
#ifdef VERBOSE
|
||||
cerr << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl;
|
||||
cerr << "CPU time: " << cpu_timer.elapsed() << endl;
|
||||
@@ -50,6 +49,7 @@ CommonParty::~CommonParty()
|
||||
cerr << "Second phase time: " << timers[1].elapsed() << endl;
|
||||
cerr << "Number of gates: " << gate_counter << endl;
|
||||
#endif
|
||||
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
|
||||
}
|
||||
|
||||
void CommonParty::check(int n_parties)
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
#include "Party.h"
|
||||
|
||||
#include "GC/ShareSecret.hpp"
|
||||
|
||||
template<class T>
|
||||
ProgramPartySpec<T>* ProgramPartySpec<T>::singleton = 0;
|
||||
|
||||
|
||||
@@ -9,15 +9,32 @@
|
||||
#include "Register.h"
|
||||
|
||||
template<class T> class RealProgramParty;
|
||||
template<class T> class RealGarbleWire;
|
||||
|
||||
template<class T>
|
||||
class GarbleInputter
|
||||
{
|
||||
public:
|
||||
RealProgramParty<T>& party;
|
||||
|
||||
Bundle<octetStream> oss;
|
||||
PointerVector<pair<RealGarbleWire<T>*, int>> tuples;
|
||||
|
||||
GarbleInputter();
|
||||
void exchange();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class RealGarbleWire : public PRFRegister
|
||||
{
|
||||
friend class RealProgramParty<T>;
|
||||
friend class GarbleInputter<T>;
|
||||
|
||||
T mask;
|
||||
|
||||
public:
|
||||
typedef GarbleInputter<T> Input;
|
||||
|
||||
static void store(NoMemory& dest,
|
||||
const vector<GC::WriteAccess<GC::Secret<RealGarbleWire>>>& accesses);
|
||||
static void load(vector<GC::ReadAccess<GC::Secret<RealGarbleWire>>>& accesses,
|
||||
@@ -26,6 +43,11 @@ public:
|
||||
static void convcbit(Integer& dest, const GC::Clear& source,
|
||||
GC::Processor<GC::Secret<RealGarbleWire>>& processor);
|
||||
|
||||
static void inputb(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
|
||||
const vector<int>& args);
|
||||
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
|
||||
ProcessorBase& input_processor, const vector<int>& args);
|
||||
|
||||
RealGarbleWire(const Register& reg) : PRFRegister(reg) {}
|
||||
|
||||
void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
|
||||
@@ -37,6 +59,10 @@ public:
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
|
||||
void my_input(Input& Inputter, bool value, int n_bits);
|
||||
void other_input(Input& Inputter, int from);
|
||||
void finalize_input(Input&, int from, int);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -94,36 +94,81 @@ void RealGarbleWire<T>::XOR(const RealGarbleWire<T>& left, const RealGarbleWire<
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::input(party_id_t from, char input)
|
||||
void RealGarbleWire<T>::inputb(
|
||||
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
GarbleInputter<T> inputter;
|
||||
processor.inputb(inputter, processor, args,
|
||||
inputter.party.P->my_num());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::inputbvec(
|
||||
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
|
||||
ProcessorBase& input_processor, const vector<int>& args)
|
||||
{
|
||||
GarbleInputter<T> inputter;
|
||||
processor.inputbvec(inputter, input_processor, args,
|
||||
inputter.party.P->my_num());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
GarbleInputter<T>::GarbleInputter() :
|
||||
party(RealProgramParty<T>::s()), oss(*party.P)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::my_input(Input& inputter, bool, int n_bits)
|
||||
{
|
||||
assert(n_bits == 1);
|
||||
inputter.tuples.push_back({this, inputter.party.P->my_num()});
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::other_input(Input& inputter, int from)
|
||||
{
|
||||
inputter.tuples.push_back({this, from});
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GarbleInputter<T>::exchange()
|
||||
{
|
||||
PRFRegister::input(from, input);
|
||||
auto& party = RealProgramParty<T>::s();
|
||||
assert(party.shared_proc != 0);
|
||||
auto& inputter = party.shared_proc->input;
|
||||
inputter.reset(from - 1);
|
||||
if (from == party.get_id())
|
||||
inputter.reset_all(*party.P);
|
||||
for (auto& tuple : tuples)
|
||||
{
|
||||
char my_mask;
|
||||
my_mask = party.prng.get_bit();
|
||||
party.input_masks.serialize(my_mask);
|
||||
inputter.add_mine(my_mask);
|
||||
inputter.send_mine();
|
||||
mask = inputter.finalize_mine();
|
||||
int from = tuple.second;
|
||||
party_id_t from_id = from + 1;
|
||||
tuple.first->PRFRegister::input(from_id, -1);
|
||||
if (from_id == party.get_id())
|
||||
{
|
||||
char my_mask;
|
||||
my_mask = party.prng.get_bit();
|
||||
party.garble_input_masks.serialize(my_mask);
|
||||
inputter.add_mine(my_mask);
|
||||
#ifdef DEBUG_MASK
|
||||
cout << "my mask: " << (int)my_mask << endl;
|
||||
cout << "my mask: " << (int)my_mask << endl;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
inputter.add_other(from);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
inputter.add_other(from - 1);
|
||||
octetStream os;
|
||||
party.P->receive_player(from - 1, os, true);
|
||||
inputter.finalize_other(from - 1, mask, os);
|
||||
}
|
||||
|
||||
inputter.exchange();
|
||||
|
||||
for (auto& tuple : tuples)
|
||||
tuple.first->mask = (inputter.finalize(tuple.second));
|
||||
|
||||
// important to make sure that mask is a bit
|
||||
try
|
||||
{
|
||||
mask.force_to_bit();
|
||||
for (auto& tuple : tuples)
|
||||
tuple.first->mask.force_to_bit();
|
||||
}
|
||||
catch (not_implemented& e)
|
||||
{
|
||||
@@ -131,16 +176,36 @@ void RealGarbleWire<T>::input(party_id_t from, char input)
|
||||
assert(party.MC != 0);
|
||||
auto& protocol = party.shared_proc->protocol;
|
||||
protocol.init_mul(party.shared_proc);
|
||||
protocol.prepare_mul(mask, T::constant(1, party.P->my_num(), party.mac_key) - mask);
|
||||
for (auto& tuple : tuples)
|
||||
protocol.prepare_mul(tuple.first->mask,
|
||||
T::constant(1, party.P->my_num(), party.mac_key)
|
||||
- tuple.first->mask);
|
||||
protocol.exchange();
|
||||
if (party.MC->open(protocol.finalize_mul(), *party.P) != 0)
|
||||
vector<T> to_check;
|
||||
to_check.reserve(tuples.size());
|
||||
for (size_t i = 0; i < tuples.size(); i++)
|
||||
{
|
||||
to_check.push_back(protocol.finalize_mul());
|
||||
}
|
||||
try
|
||||
{
|
||||
party.MC->CheckFor(0, to_check, *party.P);
|
||||
}
|
||||
catch (mac_fail&)
|
||||
{
|
||||
throw runtime_error("input mask not a bit");
|
||||
}
|
||||
}
|
||||
#ifdef DEBUG_MASK
|
||||
cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::finalize_input(GarbleInputter<T>&, int, int)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void RealGarbleWire<T>::public_input(bool value)
|
||||
{
|
||||
@@ -169,7 +234,7 @@ void RealGarbleWire<T>::output()
|
||||
assert(party.MC != 0);
|
||||
assert(party.P != 0);
|
||||
auto m = party.MC->open(mask, *party.P);
|
||||
party.output_masks.push_back(m.get_bit(0));
|
||||
party.garble_output_masks.push_back(m.get_bit(0));
|
||||
party.taint();
|
||||
#ifdef DEBUG_MASK
|
||||
cout << "output mask: " << m << endl;
|
||||
|
||||
@@ -19,6 +19,7 @@ class RealProgramParty : public ProgramPartySpec<T>
|
||||
typedef typename T::Input Inputter;
|
||||
|
||||
friend class RealGarbleWire<T>;
|
||||
friend class GarbleInputter<T>;
|
||||
friend class GarbleJob<T>;
|
||||
|
||||
static RealProgramParty* singleton;
|
||||
@@ -40,9 +41,15 @@ class RealProgramParty : public ProgramPartySpec<T>
|
||||
|
||||
GC::BreakType next;
|
||||
|
||||
bool one_shot;
|
||||
|
||||
size_t data_sent;
|
||||
|
||||
public:
|
||||
static RealProgramParty& s();
|
||||
|
||||
LocalBuffer garble_input_masks, garble_output_masks;
|
||||
|
||||
RealProgramParty(int argc, const char** argv);
|
||||
~RealProgramParty();
|
||||
|
||||
|
||||
@@ -44,10 +44,20 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Evaluate only after garbling.", // Help description.
|
||||
"-O", // Flag token.
|
||||
"--oneshot" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
int nparties;
|
||||
opt.get("-N")->getInt(nparties);
|
||||
this->check(nparties);
|
||||
one_shot = opt.isSet("-O");
|
||||
|
||||
NetworkOptions network_opts(opt, argc, argv);
|
||||
OnlineOptions& online_opts = OnlineOptions::singleton;
|
||||
@@ -90,7 +100,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
mac_key.randomize(prng);
|
||||
if (T::needs_ot)
|
||||
BaseMachine::s().ot_setups.push_back({*P, true});
|
||||
prep = Preprocessing<T>::get_live_prep(0, usage);
|
||||
prep = new typename T::TriplePrep(0, usage);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -122,6 +132,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
for (int i = 0; i < SPDZ_OP_N; i++)
|
||||
this->spdz_wires[i].push_back({});
|
||||
|
||||
this->timer.reset();
|
||||
do
|
||||
{
|
||||
next = GC::TIME_BREAK;
|
||||
@@ -129,7 +140,14 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
try
|
||||
{
|
||||
this->online_timer.start();
|
||||
this->start_online_round();
|
||||
if (one_shot)
|
||||
this->start_online_round();
|
||||
else
|
||||
{
|
||||
this->load_garbled_circuit();
|
||||
next = this->second_phase(program, this->processor,
|
||||
this->machine, this->dynamic_memory);
|
||||
}
|
||||
this->online_timer.stop();
|
||||
}
|
||||
catch (needs_cleaning& e)
|
||||
@@ -139,6 +157,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
while (next != GC::DONE_BREAK);
|
||||
|
||||
MC->Check(*P);
|
||||
data_sent = P->comm_stats.total_data() + prep->data_sent();
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
@@ -152,7 +171,7 @@ void RealProgramParty<T>::garble()
|
||||
auto& program = this->program;
|
||||
auto& MC = this->MC;
|
||||
|
||||
while (next == GC::TIME_BREAK)
|
||||
do
|
||||
{
|
||||
garble_jobs.clear();
|
||||
garble_inputter->reset_all(*P);
|
||||
@@ -178,13 +197,15 @@ void RealProgramParty<T>::garble()
|
||||
vector<typename T::clear> opened;
|
||||
MC->POpen(opened, wires, *P);
|
||||
|
||||
LocalBuffer garbled_circuit;
|
||||
for (auto& x : opened)
|
||||
this->garbled_circuit.serialize(x);
|
||||
garbled_circuit.serialize(x);
|
||||
|
||||
this->garbled_circuits.push_and_clear(this->garbled_circuit);
|
||||
this->input_masks_store.push_and_clear(this->input_masks);
|
||||
this->output_masks_store.push_and_clear(this->output_masks);
|
||||
this->garbled_circuits.push_and_clear(garbled_circuit);
|
||||
this->input_masks_store.push_and_clear(garble_input_masks);
|
||||
this->output_masks_store.push_and_clear(garble_output_masks);
|
||||
}
|
||||
while (one_shot and next == GC::TIME_BREAK);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -194,6 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
|
||||
delete prep;
|
||||
delete garble_inputter;
|
||||
delete garble_protocol;
|
||||
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -601,37 +601,64 @@ void EvalRegister::input_helper(char value, octetStream& os)
|
||||
os.serialize(get_external());
|
||||
}
|
||||
|
||||
void EvalRegister::input(party_id_t from, char value)
|
||||
EvalInputter::EvalInputter() :
|
||||
party(ProgramParty::s()), oss(*party.P)
|
||||
{
|
||||
auto& party = ProgramParty::s();
|
||||
}
|
||||
|
||||
void EvalRegister::my_input(EvalInputter& inputter, bool input, int n_bits)
|
||||
{
|
||||
assert(n_bits == 1);
|
||||
auto& party = inputter.party;
|
||||
party.load_wire(*this);
|
||||
octetStream os;
|
||||
if (from == party.get_id())
|
||||
{
|
||||
if (value and (1 - value))
|
||||
throw runtime_error("invalid input");
|
||||
input_helper(value, os);
|
||||
party.P->send_all(os, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
party.P->receive_player(from - 1, os);
|
||||
char ext;
|
||||
os.unserialize(ext);
|
||||
set_external(ext);
|
||||
}
|
||||
vector<octetStream> oss(party.get_n_parties());
|
||||
size_t id = party.get_id() - 1;
|
||||
oss[id].serialize(garbled_entry[id]);
|
||||
#ifdef DEBUG_COMM
|
||||
cout << "send " << garbled_entry[id] << ", "
|
||||
<< oss[id].get_length() << " bytes from " << id << endl;
|
||||
#endif
|
||||
input_helper(input, inputter.oss.mine);
|
||||
inputter.tuples.push_back({this, party.P->my_num()});
|
||||
}
|
||||
|
||||
void EvalRegister::other_input(EvalInputter& inputter, int from)
|
||||
{
|
||||
auto& party = inputter.party;
|
||||
party.load_wire(*this);
|
||||
inputter.tuples.push_back({this, from});
|
||||
}
|
||||
|
||||
void EvalInputter::exchange()
|
||||
{
|
||||
party.P->Broadcast_Receive(oss, true);
|
||||
for (auto& tuple : tuples)
|
||||
{
|
||||
if (tuple.from != party.P->my_num())
|
||||
{
|
||||
char ext;
|
||||
oss[tuple.from].unserialize(ext);
|
||||
tuple.reg->set_external(ext);
|
||||
}
|
||||
}
|
||||
|
||||
size_t id = party.get_id() - 1;
|
||||
for (auto& os : oss)
|
||||
os.reset_write_head();
|
||||
|
||||
for (auto& tuple : tuples)
|
||||
{
|
||||
oss[id].serialize(tuple.reg->get_garbled_entry()[id]);
|
||||
#ifdef DEBUG_COMM
|
||||
cout << "send " << garbled_entry[id] << ", "
|
||||
<< oss[id].get_length() << " bytes from " << id << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
party.P->Broadcast_Receive(oss, true);
|
||||
}
|
||||
|
||||
void EvalRegister::finalize_input(EvalInputter& inputter, int, int)
|
||||
{
|
||||
auto& party = inputter.party;
|
||||
size_t id = party.get_id() - 1;
|
||||
for (size_t i = 0; i < (size_t)party.get_n_parties(); i++)
|
||||
{
|
||||
if (i != id)
|
||||
oss[i].unserialize(garbled_entry[i]);
|
||||
inputter.oss[i].unserialize(garbled_entry[i]);
|
||||
}
|
||||
keys[external] = garbled_entry;
|
||||
#ifdef DEBUG
|
||||
@@ -934,4 +961,4 @@ void KeyTuple<I>::print(int wire_id, party_id_t pid)
|
||||
}
|
||||
|
||||
template class KeyTuple<2>;
|
||||
template class KeyTuple<4>;
|
||||
template class KeyTuple<4> ;
|
||||
|
||||
@@ -20,6 +20,8 @@ using namespace std;
|
||||
#include "GC/ArgTuples.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
#include "Tools/PointerVector.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
//#define PAD_TO_8(n) (n+8-n%8)
|
||||
#define PAD_TO_8(n) (n)
|
||||
@@ -201,6 +203,8 @@ public:
|
||||
inline BlackHole& endl(BlackHole& b) { return b; }
|
||||
inline BlackHole& flush(BlackHole& b) { return b; }
|
||||
|
||||
class ProcessorBase;
|
||||
|
||||
class Phase
|
||||
{
|
||||
public:
|
||||
@@ -209,6 +213,8 @@ public:
|
||||
typedef BlackHole out_type;
|
||||
static BlackHole out;
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
template <class T>
|
||||
static void store_clear_in_dynamic(T& mem, const vector<GC::ClearWriteAccess>& accesses)
|
||||
{ (void)mem; (void)accesses; }
|
||||
@@ -231,6 +237,9 @@ public:
|
||||
template <class T>
|
||||
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
|
||||
template <class T>
|
||||
static void inputbvec(T&, ProcessorBase&, const vector<int>&)
|
||||
{ throw not_implemented(); }
|
||||
template <class T>
|
||||
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
|
||||
{ return T::input(from, processor.get_input(n_bits), n_bits); }
|
||||
|
||||
@@ -244,9 +253,24 @@ public:
|
||||
void output() {}
|
||||
};
|
||||
|
||||
class NoOpInputter
|
||||
{
|
||||
public:
|
||||
PointerVector<char> inputs;
|
||||
|
||||
void exchange()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
class ProgramRegister : public Phase, public Register
|
||||
{
|
||||
public:
|
||||
typedef NoOpInputter Input;
|
||||
|
||||
// only true for evaluation
|
||||
static const bool actual_inputs = false;
|
||||
|
||||
static Register new_reg();
|
||||
static Register tmp_reg() { return new_reg(); }
|
||||
static Register and_reg() { return new_reg(); }
|
||||
@@ -255,11 +279,18 @@ public:
|
||||
static void store(NoMemory& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
|
||||
|
||||
template <class T>
|
||||
static void inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
|
||||
// most BMR phases don't need actual input
|
||||
template<class T>
|
||||
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
|
||||
{ (void)processor; return T::input(args.from + 1, 0, args.n_bits); }
|
||||
|
||||
void my_input(Input&, bool, int) {}
|
||||
void other_input(Input&, int) {}
|
||||
|
||||
char get_output() { return 0; }
|
||||
|
||||
ProgramRegister(const Register& reg) : Register(reg) {}
|
||||
@@ -282,6 +313,37 @@ public:
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
|
||||
void finalize_input(NoOpInputter&, int from, int)
|
||||
{ input(from + 1, -1); }
|
||||
};
|
||||
|
||||
class ProgramParty;
|
||||
class EvalRegister;
|
||||
|
||||
class EvalInputter
|
||||
{
|
||||
class Tuple
|
||||
{
|
||||
public:
|
||||
EvalRegister* reg;
|
||||
int from;
|
||||
|
||||
Tuple(EvalRegister* reg, int from) :
|
||||
reg(reg), from(from)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ProgramParty& party;
|
||||
|
||||
Bundle<octetStream> oss;
|
||||
vector<Tuple> tuples;
|
||||
|
||||
EvalInputter();
|
||||
void add_other(int from);
|
||||
void exchange();
|
||||
};
|
||||
|
||||
class EvalRegister : public ProgramRegister
|
||||
@@ -289,9 +351,13 @@ class EvalRegister : public ProgramRegister
|
||||
public:
|
||||
static string name() { return "Evaluation"; }
|
||||
|
||||
typedef EvalInputter Input;
|
||||
|
||||
typedef ostream& out_type;
|
||||
static ostream& out;
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
template<class T, class U>
|
||||
static void store(GC::Memory<U>& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses);
|
||||
@@ -303,6 +369,9 @@ public:
|
||||
static void andrs(T& processor, const vector<int>& args);
|
||||
template <class T>
|
||||
static void inputb(T& processor, const vector<int>& args);
|
||||
template <class T>
|
||||
static void inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
|
||||
template <class T>
|
||||
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
|
||||
@@ -330,6 +399,10 @@ public:
|
||||
static void check_input(long long input, int n_bits);
|
||||
void input(party_id_t from, char value = -1);
|
||||
void input_helper(char value, octetStream& os);
|
||||
|
||||
void my_input(EvalInputter& inputter, bool input, int);
|
||||
void other_input(EvalInputter& inputter, int from);
|
||||
void finalize_input(EvalInputter& inputter, int from, int);
|
||||
};
|
||||
|
||||
class GarbleRegister : public ProgramRegister
|
||||
@@ -349,6 +422,9 @@ public:
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output() {}
|
||||
|
||||
void finalize_input(NoOpInputter&, int from, int)
|
||||
{ input(from + 1, -1); }
|
||||
};
|
||||
|
||||
class RandomRegister : public ProgramRegister
|
||||
@@ -374,6 +450,9 @@ public:
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
|
||||
void finalize_input(NoOpInputter&, int from, int)
|
||||
{ input(from + 1, -1); }
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,31 @@
|
||||
#include "Register.h"
|
||||
#include "Party.h"
|
||||
|
||||
template<class T>
|
||||
void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
NoOpInputter inputter;
|
||||
int my_num = -1;
|
||||
try
|
||||
{
|
||||
my_num = ProgramParty::s().P->my_num();
|
||||
}
|
||||
catch (exception&)
|
||||
{
|
||||
}
|
||||
processor.inputbvec(inputter, input_processor, args, my_num);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
EvalInputter inputter;
|
||||
processor.inputbvec(inputter, input_processor, args,
|
||||
ProgramParty::s().P->my_num());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void PRFRegister::load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const NoMemory& source)
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Program.hpp"
|
||||
#include "GC/Instruction.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.1.9 (Aug 24, 2020)
|
||||
|
||||
- Streamline inputs to binary circuits
|
||||
- Improved private output
|
||||
- Emulator for arithmetic circuits
|
||||
- Efficient dot product with Shamir's secret sharing
|
||||
- Lower memory usage for TensorFlow inference
|
||||
- This version breaks bytecode compatibilty.
|
||||
|
||||
## 0.1.8 (June 15, 2020)
|
||||
|
||||
- Half-gate garbling
|
||||
|
||||
@@ -37,6 +37,7 @@ opcodes = dict(
|
||||
STMSBI = 0x243,
|
||||
MOVSB = 0x244,
|
||||
INPUTB = 0x246,
|
||||
INPUTBVEC = 0x247,
|
||||
SPLIT = 0x248,
|
||||
CONVCBIT2S = 0x249,
|
||||
XORCBI = 0x210,
|
||||
@@ -269,6 +270,24 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
|
||||
arg_format = tools.cycle(['p','int','int','sbw'])
|
||||
is_vec = lambda self: True
|
||||
|
||||
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
|
||||
base.Mergeable):
|
||||
__slots__ = []
|
||||
code = opcodes['INPUTBVEC']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.arg_format = []
|
||||
i = 0
|
||||
while i < len(args):
|
||||
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (args[i] - 3)
|
||||
i += args[i]
|
||||
assert i == len(args)
|
||||
super(inputbvec, self).__init__(*args, **kwargs)
|
||||
|
||||
def merge(self, other):
|
||||
self.args += other.args
|
||||
self.arg_format += other.arg_format
|
||||
|
||||
class print_regb(base.VectorInstruction, base.IOInstruction):
|
||||
code = opcodes['PRINTREGB']
|
||||
arg_format = ['cb','i']
|
||||
|
||||
@@ -88,7 +88,8 @@ class bits(Tape.Register, _structure, _bit):
|
||||
if mem_type == 'sd':
|
||||
return cls.load_dynamic_mem(address)
|
||||
else:
|
||||
cls.load_inst[util.is_constant(address)](res, address)
|
||||
for i in range(res.size):
|
||||
cls.load_inst[util.is_constant(address)](res[i], address + i)
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, int)](self, address)
|
||||
@@ -120,17 +121,19 @@ class bits(Tape.Register, _structure, _bit):
|
||||
assert(other.size == math.ceil(self.n / self.unit))
|
||||
for i, (x, y) in enumerate(zip(self, other)):
|
||||
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
|
||||
elif isinstance(self, type(other)) or isinstance(other, type(self)):
|
||||
assert(self.n == other.n)
|
||||
elif (isinstance(self, type(other)) or isinstance(other, type(self))) \
|
||||
and self.n == other.n:
|
||||
for i in range(math.ceil(self.n / self.unit)):
|
||||
self.mov(self[i], other[i])
|
||||
else:
|
||||
try:
|
||||
other = self.bit_compose(other.bit_decompose())
|
||||
bits = other.bit_decompose()
|
||||
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
|
||||
other = self.bit_compose(bits)
|
||||
self.load_other(other)
|
||||
except:
|
||||
raise CompilerError('cannot convert from %s to %s' % \
|
||||
(type(other), type(self)))
|
||||
raise CompilerError('cannot convert %s/%s from %s to %s' % \
|
||||
(str(other), repr(other), type(other), type(self)))
|
||||
def long_one(self):
|
||||
return 2**self.n - 1 if self.n != None else None
|
||||
def __repr__(self):
|
||||
@@ -160,6 +163,9 @@ class cbits(bits):
|
||||
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
|
||||
conv_cint_vec = inst.convcintvec
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
return sum(bit << i for i, bit in enumerate(bits))
|
||||
@classmethod
|
||||
def conv_regint_by_bit(cls, n, res, other):
|
||||
assert n == res.n
|
||||
assert n == other.size
|
||||
@@ -459,46 +465,68 @@ class sbits(bits):
|
||||
class sbitvec(_vec):
|
||||
@classmethod
|
||||
def get_type(cls, n):
|
||||
class sbitvecn(cls):
|
||||
class sbitvecn(cls, _structure):
|
||||
@staticmethod
|
||||
def malloc(size):
|
||||
return sbits.malloc(size * n)
|
||||
return sbit.malloc(size * n)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return n
|
||||
@classmethod
|
||||
def get_input_from(cls, player):
|
||||
return cls.from_vec(
|
||||
sbits.get_input_from(player, n).bit_decompose(n))
|
||||
res = cls.from_vec(sbit() for i in range(n))
|
||||
inst.inputbvec(n + 3, 0, player, *res.v)
|
||||
return res
|
||||
get_raw_input_from = get_input_from
|
||||
def __init__(self, other=None):
|
||||
if other is not None:
|
||||
self.v = sbits(other, n=n).bit_decompose(n)
|
||||
if util.is_constant(other):
|
||||
self.v = [sbit((other >> i) & 1) for i in range(n)]
|
||||
else:
|
||||
self.v = sbits(other, n=n).bit_decompose(n)
|
||||
@classmethod
|
||||
def load_mem(cls, address):
|
||||
try:
|
||||
assert len(address) == n
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
return cls.from_vec(sbit.load_mem(x) for x in address)
|
||||
except:
|
||||
else:
|
||||
return cls.from_vec(sbit.load_mem(address + i)
|
||||
for i in range(n))
|
||||
def store_in_mem(self, address):
|
||||
assert self.v[0].n == 1
|
||||
try:
|
||||
assert len(address) == n
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
for x, y in zip(self.v, address):
|
||||
x.store_in_mem(y)
|
||||
except:
|
||||
else:
|
||||
for i in range(n):
|
||||
self.v[i].store_in_mem(address + i)
|
||||
def reveal(self):
|
||||
return self.elements()[0].reveal()
|
||||
revealed = [cbit() for i in range(len(self))]
|
||||
for i in range(len(self)):
|
||||
inst.reveal(1, revealed[i], self.v[i])
|
||||
return cbits.get_type(len(self)).bit_compose(revealed)
|
||||
@classmethod
|
||||
def two_power(cls, nn):
|
||||
return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1))
|
||||
def coerce(self, other):
|
||||
if util.is_constant(other):
|
||||
return self.from_vec(util.bit_decompose(other, n))
|
||||
else:
|
||||
return super(sbitvecn, self).coerce(other)
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
if len(bits) < n:
|
||||
bits += [0] * (n - len(bits))
|
||||
assert len(bits) == n
|
||||
return cls.from_vec(bits)
|
||||
def __str__(self):
|
||||
return 'sbitvec(%d)' % n
|
||||
return sbitvecn
|
||||
@classmethod
|
||||
def from_vec(cls, vector):
|
||||
res = cls()
|
||||
res.v = list(vector)
|
||||
return res
|
||||
compose = from_vec
|
||||
@classmethod
|
||||
def combine(cls, vectors):
|
||||
res = cls()
|
||||
@@ -512,11 +540,7 @@ class sbitvec(_vec):
|
||||
if length:
|
||||
assert isinstance(elements, sint)
|
||||
if Program.prog.use_split():
|
||||
n = Program.prog.use_split()
|
||||
columns = [[sbits.get_type(elements.size)()
|
||||
for i in range(n)] for i in range(length)]
|
||||
inst.split(n, elements, *sum(columns, []))
|
||||
x = sbitint.wallace_tree_without_finish(columns, False)
|
||||
x = elements.split_to_two_summands(length)
|
||||
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
|
||||
else:
|
||||
assert Program.prog.options.ring
|
||||
@@ -569,11 +593,14 @@ class sbitvec(_vec):
|
||||
@property
|
||||
def size(self):
|
||||
return self.v[0].n
|
||||
@property
|
||||
def n_bits(self):
|
||||
return len(self.v)
|
||||
def store_in_mem(self, address):
|
||||
for i, x in enumerate(self.elements()):
|
||||
x.store_in_mem(address + i)
|
||||
def bit_decompose(self):
|
||||
return self.v
|
||||
def bit_decompose(self, n_bits=None):
|
||||
return self.v[:n_bits]
|
||||
bit_compose = from_vec
|
||||
def reveal(self):
|
||||
assert len(self) == 1
|
||||
@@ -673,7 +700,41 @@ cbits.dynamic_array = Array
|
||||
def _complement_two_extend(bits, k):
|
||||
return bits + [bits[-1]] * (k - len(bits))
|
||||
|
||||
class sbitint(_bitint, _number, sbits):
|
||||
class _sbitintbase:
|
||||
def extend(self, n):
|
||||
bits = self.bit_decompose()
|
||||
bits += [bits[-1]] * (n - len(bits))
|
||||
return self.get_type(n).bit_compose(bits)
|
||||
def cast(self, n):
|
||||
bits = self.bit_decompose()[:n]
|
||||
bits += [bits[-1]] * (n - len(bits))
|
||||
return self.get_type(n).bit_compose(bits)
|
||||
def round(self, k, m, kappa=None, nearest=None, signed=None):
|
||||
bits = self.bit_decompose()
|
||||
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
|
||||
return self.get_type(k - m).compose(res_bits)
|
||||
def int_div(self, other, bit_length=None):
|
||||
k = bit_length or max(self.n, other.n)
|
||||
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
|
||||
def Norm(self, k, f, kappa=None, simplex_flag=False):
|
||||
absolute_val = abs(self)
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k)[::-1]
|
||||
suffixes = floatingpoint.PreOR(bits)[::-1]
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
z[i] = suffixes[i] - suffixes[i+1]
|
||||
z[k - 1] = suffixes[k-1]
|
||||
z.reverse()
|
||||
t2k = self.get_type(2 * k)
|
||||
acc = t2k.bit_compose(z)
|
||||
sign = self.bit_decompose()[-1]
|
||||
signed_acc = util.if_else(sign, -acc, acc)
|
||||
absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose())
|
||||
part_reciprocal = absolute_val_2k * acc
|
||||
return part_reciprocal, signed_acc
|
||||
|
||||
class sbitint(_bitint, _number, sbits, _sbitintbase):
|
||||
n_bits = None
|
||||
bin_type = None
|
||||
types = {}
|
||||
@@ -728,43 +789,11 @@ class sbitint(_bitint, _number, sbits):
|
||||
res_bits = product.bit_decompose()[m:k]
|
||||
t = self.combo_type(other)
|
||||
return t.bit_compose(res_bits)
|
||||
def Norm(self, k, f, kappa=None, simplex_flag=False):
|
||||
absolute_val = abs(self)
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k)[::-1]
|
||||
suffixes = floatingpoint.PreOR(bits)[::-1]
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
z[i] = suffixes[i] - suffixes[i+1]
|
||||
z[k - 1] = suffixes[k-1]
|
||||
z.reverse()
|
||||
t2k = self.get_type(2 * k)
|
||||
acc = t2k.bit_compose(z)
|
||||
sign = self.bit_decompose()[-1]
|
||||
signed_acc = util.if_else(sign, -acc, acc)
|
||||
absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose())
|
||||
part_reciprocal = absolute_val_2k * acc
|
||||
return part_reciprocal, signed_acc
|
||||
def extend(self, n):
|
||||
bits = self.bit_decompose()
|
||||
bits += [bits[-1]] * (n - len(bits))
|
||||
return self.get_type(n).bit_compose(bits)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbitintvec):
|
||||
return other * self
|
||||
else:
|
||||
return super(sbitint, self).__mul__(other)
|
||||
def cast(self, n):
|
||||
bits = self.bit_decompose()[:n]
|
||||
bits += [bits[-1]] * (n - len(bits))
|
||||
return self.get_type(n).bit_compose(bits)
|
||||
def int_div(self, other, bit_length=None):
|
||||
k = bit_length or max(self.n, other.n)
|
||||
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
|
||||
def round(self, k, m, kappa=None, nearest=None, signed=None):
|
||||
bits = self.bit_decompose()
|
||||
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
|
||||
return self.get_type(k - m).compose(res_bits)
|
||||
@classmethod
|
||||
def get_bit_matrix(cls, self_bits, other):
|
||||
n = len(self_bits)
|
||||
@@ -782,10 +811,11 @@ class sbitint(_bitint, _number, sbits):
|
||||
res.append([(x & bit) for x in other.bit_decompose(n - i)])
|
||||
return res
|
||||
|
||||
class sbitintvec(sbitvec, _number):
|
||||
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
|
||||
def __add__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
other = self.coerce(other)
|
||||
assert(len(self.v) == len(other.v))
|
||||
v = sbitint.bit_adder(self.v, other.v)
|
||||
return self.from_vec(v)
|
||||
@@ -797,7 +827,7 @@ class sbitintvec(sbitvec, _number):
|
||||
if isinstance(other, sbits):
|
||||
return self.from_vec(other * x for x in self.v)
|
||||
matrix = []
|
||||
for i, b in enumerate(other.bit_decompose()):
|
||||
for i, b in enumerate(util.bit_decompose(other)):
|
||||
matrix.append([x * b for x in self.v[:len(self.v)-i]])
|
||||
v = sbitint.wallace_tree_from_matrix(matrix)
|
||||
return self.from_vec(v[:len(self.v)])
|
||||
|
||||
@@ -12,6 +12,51 @@ import operator
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
class BlockAllocator:
|
||||
""" Manages freed memory blocks. """
|
||||
def __init__(self):
|
||||
self.by_logsize = [defaultdict(set) for i in range(32)]
|
||||
self.by_address = {}
|
||||
|
||||
def by_size(self, size):
|
||||
return self.by_logsize[int(math.log(size, 2))][size]
|
||||
|
||||
def push(self, address, size):
|
||||
end = address + size
|
||||
if end in self.by_address:
|
||||
next_size = self.by_address.pop(end)
|
||||
self.by_size(next_size).remove(end)
|
||||
size += next_size
|
||||
self.by_size(size).add(address)
|
||||
self.by_address[address] = size
|
||||
|
||||
def pop(self, size):
|
||||
if len(self.by_size(size)) > 0:
|
||||
block_size = size
|
||||
else:
|
||||
logsize = int(math.log(size, 2))
|
||||
for block_size, addresses in self.by_logsize[logsize].items():
|
||||
if block_size >= size and len(addresses) > 0:
|
||||
break
|
||||
else:
|
||||
done = False
|
||||
for x in self.by_logsize[logsize + 1:]:
|
||||
for block_size, addresses in x.items():
|
||||
if len(addresses) > 0:
|
||||
done = True
|
||||
break
|
||||
if done:
|
||||
break
|
||||
else:
|
||||
block_size = 0
|
||||
if block_size >= size:
|
||||
addr = self.by_size(block_size).pop()
|
||||
del self.by_address[addr]
|
||||
diff = block_size - size
|
||||
if diff:
|
||||
self.by_size(diff).add(addr + size)
|
||||
self.by_address[addr + size] = diff
|
||||
return addr
|
||||
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
@@ -509,3 +554,47 @@ class Merger:
|
||||
for i in range(self.G.n):
|
||||
print('%d: %s' % (self.depths[i], self.instructions[i]), file=f)
|
||||
f.close()
|
||||
|
||||
class RegintOptimizer:
|
||||
def __init__(self):
|
||||
self.cache = util.dict_by_id()
|
||||
|
||||
def run(self, instructions):
|
||||
for i, inst in enumerate(instructions):
|
||||
if isinstance(inst, ldint_class):
|
||||
self.cache[inst.args[0]] = inst.args[1]
|
||||
elif isinstance(inst, IntegerInstruction):
|
||||
if inst.args[1] in self.cache and inst.args[2] in self.cache:
|
||||
res = inst.op(self.cache[inst.args[1]],
|
||||
self.cache[inst.args[2]])
|
||||
if abs(res) < 2 ** 31:
|
||||
self.cache[inst.args[0]] = res
|
||||
instructions[i] = ldint(inst.args[0], res,
|
||||
add_to_prog=False)
|
||||
elif isinstance(inst, addint_class):
|
||||
if inst.args[1] in self.cache and \
|
||||
self.cache[inst.args[1]] == 0:
|
||||
instructions[i] = inst.args[0].link(inst.args[2])
|
||||
elif inst.args[2] in self.cache and \
|
||||
self.cache[inst.args[2]] == 0:
|
||||
instructions[i] = inst.args[0].link(inst.args[1])
|
||||
elif isinstance(inst, IndirectMemoryInstruction):
|
||||
if inst.args[1] in self.cache:
|
||||
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
|
||||
elif isinstance(inst, convint_class):
|
||||
if inst.args[1] in self.cache:
|
||||
res = self.cache[inst.args[1]]
|
||||
self.cache[inst.args[0]] = res
|
||||
if abs(res) < 2 ** 31:
|
||||
instructions[i] = ldi(inst.args[0], res,
|
||||
add_to_prog=False)
|
||||
elif isinstance(inst, mulm_class):
|
||||
if inst.args[2] in self.cache:
|
||||
op = self.cache[inst.args[2]]
|
||||
if op == 0:
|
||||
instructions[i] = ldsi(inst.args[0], 0,
|
||||
add_to_prog=False)
|
||||
elif op == 1:
|
||||
instructions[i] = None
|
||||
inst.args[0].link(inst.args[1])
|
||||
instructions[:] = filter(lambda x: x is not None, instructions)
|
||||
|
||||
@@ -79,7 +79,10 @@ def LTZ(s, a, k, kappa):
|
||||
from .types import sint, _bitint
|
||||
from .GC.types import sbitvec
|
||||
if program.use_split():
|
||||
movs(s, sint.conv(sbitvec(a, k).v[-1]))
|
||||
summands = a.split_to_two_summands(k)
|
||||
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
|
||||
msb = carry ^ summands[0][-1] ^ summands[1][-1]
|
||||
movs(s, sint.conv(msb))
|
||||
return
|
||||
elif program.options.ring:
|
||||
from . import floatingpoint
|
||||
@@ -132,9 +135,31 @@ def Trunc(d, a, k, m, kappa, signed):
|
||||
mulm(d, t, c[2])
|
||||
|
||||
def TruncRing(d, a, k, m, signed):
|
||||
a_prime = Mod2mRing(None, a, k, m, signed)
|
||||
a -= a_prime
|
||||
res = TruncLeakyInRing(a, k, m, signed)
|
||||
if program.use_split() == 3:
|
||||
from Compiler.types import sint
|
||||
from .GC.types import sbitint
|
||||
length = int(program.options.ring)
|
||||
summands = a.split_to_n_summands(length, 3)
|
||||
x = sbitint.wallace_tree_without_finish(summands, True)
|
||||
if m == 1:
|
||||
low = x[1][1]
|
||||
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
|
||||
sint.conv(x[0][-1])
|
||||
else:
|
||||
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
|
||||
low = sint.conv(mid_carry) + sint.conv(x[0][m])
|
||||
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
|
||||
for xx, yy in zip(x[1][m:-1],
|
||||
x[0][m:-1])))
|
||||
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
|
||||
high = top_carry + sint.conv(x[0][-1])
|
||||
shifted = sint()
|
||||
shrsi(shifted, a, m)
|
||||
res = shifted + sint.conv(low) - (high << (length - m))
|
||||
else:
|
||||
a_prime = Mod2mRing(None, a, k, m, signed)
|
||||
a -= a_prime
|
||||
res = TruncLeakyInRing(a, k, m, signed)
|
||||
if d is not None:
|
||||
movs(d, res)
|
||||
return res
|
||||
@@ -281,7 +306,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
"""
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
from .types import sint
|
||||
if program.use_edabit() and m > 1:
|
||||
if program.use_edabit() and m > 1 and not const_rounds:
|
||||
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
|
||||
tmp, b[:] = sint.get_edabit(m, True)
|
||||
movs(r_prime, tmp)
|
||||
@@ -290,7 +315,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
t[0][1] = b[-1]
|
||||
PRandInt(r_dprime, k + kappa - m)
|
||||
# r_dprime is always multiplied by 2^m
|
||||
if use_dabit and program.use_dabit and m > 1:
|
||||
if use_dabit and program.use_dabit and m > 1 and not const_rounds:
|
||||
r, b[:] = zip(*(sint.get_dabit() for i in range(m)))
|
||||
r = sint.bit_compose(r)
|
||||
movs(r_prime, r)
|
||||
@@ -383,7 +408,7 @@ def BitLTC1(u, a, b, kappa):
|
||||
Mod2(u, t[4][k-1], k, kappa, False)
|
||||
return p, a_bits, d, s, t, c, b, pre_input
|
||||
|
||||
def carry(b, a, compute_p):
|
||||
def carry(b, a, compute_p=True):
|
||||
""" Carry propogation:
|
||||
return (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
|
||||
"""
|
||||
|
||||
@@ -9,14 +9,14 @@ import time
|
||||
import sys
|
||||
|
||||
|
||||
def run(args, options, param=-1, merge_opens=True,
|
||||
def run(args, options, merge_opens=True,
|
||||
reallocate=True, debug=False):
|
||||
""" Compile a file and output a Program object.
|
||||
|
||||
If merge_opens is set to True, will attempt to merge any parallelisable open
|
||||
instructions. """
|
||||
|
||||
prog = Program(args, options, param)
|
||||
prog = Program(args, options)
|
||||
instructions.program = prog
|
||||
instructions_base.program = prog
|
||||
types.program = prog
|
||||
@@ -24,7 +24,7 @@ def run(args, options, param=-1, merge_opens=True,
|
||||
prog.DEBUG = debug
|
||||
VARS['program'] = prog
|
||||
if options.binary:
|
||||
VARS['sint'] = GC_types.sbitint.get_type(int(options.binary))
|
||||
VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary))
|
||||
VARS['sfix'] = GC_types.sbitfix
|
||||
comparison.set_variant(options)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ USER_MEM = 8192
|
||||
P_VALUES = { 32: 2147565569, \
|
||||
64: 9223372036855103489, \
|
||||
128: 170141183460469231731687303715885907969, \
|
||||
192: 3138550867693340381917894711603833208051177722232017256453, \
|
||||
256: 57896044618658097711785492504343953926634992332820282019728792003956566065153, \
|
||||
512: 6703903964971298549787012499102923063739682910296196688861780721860882015036773488400937149083451713845015929093243025426876941405973284973216824503566337 }
|
||||
|
||||
@@ -19,7 +20,7 @@ BIT_LENGTHS = { -1: 64,
|
||||
512: 64 }
|
||||
|
||||
|
||||
COST = defaultdict(lambda: defaultdict(lambda: 0),
|
||||
COST = defaultdict(lambda: defaultdict(lambda: 0),
|
||||
{ 'modp': defaultdict(lambda: 0,
|
||||
{ 'triple': 0.00020652622883106154,
|
||||
'square': 0.00020652622883106154,
|
||||
|
||||
@@ -104,11 +104,15 @@ def PreORC(a, kappa=None, m=None, raw=False):
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return [a[0]]
|
||||
prog = program.Program.prog
|
||||
kappa = kappa or prog.security
|
||||
m = m or k
|
||||
if isinstance(a[0], types.sgf2n):
|
||||
max_k = program.Program.prog.galois_length - 1
|
||||
else:
|
||||
max_k = int(log(program.Program.prog.P) / log(2)) - kappa
|
||||
# assume prime length is power of two
|
||||
prime_length = 2 ** int(ceil(log(prog.bit_length + kappa, 2)))
|
||||
max_k = prime_length - kappa - 2
|
||||
assert(max_k > 0)
|
||||
if k <= max_k:
|
||||
p = [None] * m
|
||||
@@ -132,7 +136,8 @@ def PreORC(a, kappa=None, m=None, raw=False):
|
||||
# not constant-round anymore
|
||||
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
|
||||
return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])
|
||||
return sum(([or_op(x, y) for x in si]
|
||||
for si,y in zip(s[1:],t)), s[0])[-m:]
|
||||
|
||||
def PreOpL(op, items):
|
||||
"""
|
||||
@@ -549,9 +554,10 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
trunc_pr(res, a, k, m)
|
||||
else:
|
||||
# extra bit to mask overflow
|
||||
if program.Program.prog.use_edabit():
|
||||
lower = sint.get_edabit(m, True)[0]
|
||||
upper = sint.get_edabit(k - m, True)[0]
|
||||
prog = program.Program.prog
|
||||
if prog.use_edabit() or prog.use_split() == 3:
|
||||
lower = sint.get_random_int(m)
|
||||
upper = sint.get_random_int(k - m)
|
||||
msb = sint.get_random_bit()
|
||||
r = (msb << k) + (upper << m) + lower
|
||||
else:
|
||||
|
||||
@@ -91,64 +91,74 @@ class stmint(base.DirectMemoryWriteInstruction):
|
||||
|
||||
# must have seperate instructions because address is always modp
|
||||
@base.vectorize
|
||||
class ldmci(base.ReadMemoryInstruction):
|
||||
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
|
||||
code = base.opcodes['LDMCI']
|
||||
arg_format = ['cw','ci']
|
||||
direct = staticmethod(ldmc)
|
||||
|
||||
@base.vectorize
|
||||
class ldmsi(base.ReadMemoryInstruction):
|
||||
class ldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
|
||||
code = base.opcodes['LDMSI']
|
||||
arg_format = ['sw','ci']
|
||||
direct = staticmethod(ldms)
|
||||
|
||||
@base.vectorize
|
||||
class stmci(base.WriteMemoryInstruction):
|
||||
class stmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
|
||||
code = base.opcodes['STMCI']
|
||||
arg_format = ['c','ci']
|
||||
direct = staticmethod(stmc)
|
||||
|
||||
@base.vectorize
|
||||
class stmsi(base.WriteMemoryInstruction):
|
||||
class stmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
|
||||
code = base.opcodes['STMSI']
|
||||
arg_format = ['s','ci']
|
||||
direct = staticmethod(stms)
|
||||
|
||||
@base.vectorize
|
||||
class ldminti(base.ReadMemoryInstruction):
|
||||
class ldminti(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """
|
||||
code = base.opcodes['LDMINTI']
|
||||
arg_format = ['ciw','ci']
|
||||
direct = staticmethod(ldmint)
|
||||
|
||||
@base.vectorize
|
||||
class stminti(base.WriteMemoryInstruction):
|
||||
class stminti(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """
|
||||
code = base.opcodes['STMINTI']
|
||||
arg_format = ['ci','ci']
|
||||
direct = staticmethod(stmint)
|
||||
|
||||
@base.vectorize
|
||||
class gldmci(base.ReadMemoryInstruction):
|
||||
class gldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
|
||||
code = base.opcodes['LDMCI'] + 0x100
|
||||
arg_format = ['cgw','ci']
|
||||
direct = staticmethod(gldmc)
|
||||
|
||||
@base.vectorize
|
||||
class gldmsi(base.ReadMemoryInstruction):
|
||||
class gldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
|
||||
code = base.opcodes['LDMSI'] + 0x100
|
||||
arg_format = ['sgw','ci']
|
||||
direct = staticmethod(gldms)
|
||||
|
||||
@base.vectorize
|
||||
class gstmci(base.WriteMemoryInstruction):
|
||||
class gstmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
|
||||
code = base.opcodes['STMCI'] + 0x100
|
||||
arg_format = ['cg','ci']
|
||||
direct = staticmethod(gstmc)
|
||||
|
||||
@base.vectorize
|
||||
class gstmsi(base.WriteMemoryInstruction):
|
||||
class gstmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
|
||||
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
|
||||
code = base.opcodes['STMSI'] + 0x100
|
||||
arg_format = ['sg','ci']
|
||||
direct = staticmethod(gstms)
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -268,7 +278,7 @@ class use_edabit(base.Instruction):
|
||||
class run_tape(base.Instruction):
|
||||
r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """
|
||||
code = base.opcodes['RUN_TAPE']
|
||||
arg_format = ['int','int','int']
|
||||
arg_format = tools.cycle(['int','int','int'])
|
||||
|
||||
class join_tape(base.Instruction):
|
||||
r""" Join thread $c_i$. """
|
||||
@@ -661,6 +671,12 @@ class shrci(base.ClearShiftInstruction):
|
||||
code = base.opcodes['SHRCI']
|
||||
op = '__rshift__'
|
||||
|
||||
@base.vectorize
|
||||
class shrsi(base.ClearShiftInstruction):
|
||||
r""" Secret bitwise shift right by immediate value. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['SHRSI']
|
||||
arg_format = ['sw','s','i']
|
||||
|
||||
###
|
||||
### Data access instructions
|
||||
@@ -742,6 +758,14 @@ class sedabit(base.Instruction):
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('sedabit', len(self.args) - 1), self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
class randoms(base.Instruction):
|
||||
""" Random share """
|
||||
__slots__ = []
|
||||
code = base.opcodes['RANDOMS']
|
||||
arg_format = ['sw','int']
|
||||
field_type = 'modp'
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class square(base.DataInstruction):
|
||||
@@ -781,6 +805,17 @@ class inputmask(base.Instruction):
|
||||
req_node.increment((self.field_type, 'input', self.args[1]), \
|
||||
self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
class inputmaskreg(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['INPUTMASKREG']
|
||||
arg_format = ['sw', 'cw', 'ci']
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
# player 0 as proxy
|
||||
req_node.increment((self.field_type, 'input', 0), float('inf'))
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class prep(base.Instruction):
|
||||
@@ -1165,21 +1200,25 @@ class ldint(base.Instruction):
|
||||
class addint(base.IntegerInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['ADDINT']
|
||||
op = operator.add
|
||||
|
||||
@base.vectorize
|
||||
class subint(base.IntegerInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['SUBINT']
|
||||
op = operator.sub
|
||||
|
||||
@base.vectorize
|
||||
class mulint(base.IntegerInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['MULINT']
|
||||
op = operator.mul
|
||||
|
||||
@base.vectorize
|
||||
class divint(base.IntegerInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['DIVINT']
|
||||
op = operator.floordiv
|
||||
|
||||
@base.vectorize
|
||||
class bitdecint(base.Instruction):
|
||||
@@ -1228,18 +1267,21 @@ class ltc(base.IntegerInstruction):
|
||||
r""" Clear comparison $c_i = (c_j \stackrel{?}{<} c_k)$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['LTC']
|
||||
op = operator.lt
|
||||
|
||||
@base.vectorize
|
||||
class gtc(base.IntegerInstruction):
|
||||
r""" Clear comparison $c_i = (c_j \stackrel{?}{>} c_k)$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['GTC']
|
||||
op = operator.gt
|
||||
|
||||
@base.vectorize
|
||||
class eqc(base.IntegerInstruction):
|
||||
r""" Clear comparison $c_i = (c_j \stackrel{?}{==} c_k)$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['EQC']
|
||||
op = operator.eq
|
||||
|
||||
|
||||
###
|
||||
|
||||
@@ -106,10 +106,12 @@ opcodes = dict(
|
||||
GBITTRIPLE = 0x154,
|
||||
GBITGF2NTRIPLE = 0x155,
|
||||
INPUTMASK = 0x56,
|
||||
INPUTMASKREG = 0x5C,
|
||||
PREP = 0x57,
|
||||
DABIT = 0x58,
|
||||
EDABIT = 0x59,
|
||||
SEDABIT = 0x5A,
|
||||
RANDOMS = 0x5B,
|
||||
# Input
|
||||
INPUT = 0x60,
|
||||
INPUTFIX = 0xF0,
|
||||
@@ -143,6 +145,7 @@ opcodes = dict(
|
||||
SHRC = 0x81,
|
||||
SHLCI = 0x82,
|
||||
SHRCI = 0x83,
|
||||
SHRSI = 0x84,
|
||||
# Branching and comparison
|
||||
JMP = 0x90,
|
||||
JMPNZ = 0x91,
|
||||
@@ -446,10 +449,11 @@ def cisc(function):
|
||||
assert int(program.options.max_parallel_open) == 0, \
|
||||
'merging restriction not compatible with ' \
|
||||
'mergeable CISC instructions'
|
||||
merger.longest_paths_merge()
|
||||
n_rounds = merger.longest_paths_merge()
|
||||
filtered = filter(lambda x: x is not None, block.instructions)
|
||||
self.instructions[self.merge_id()] = list(filtered), args
|
||||
template, args = self.instructions[self.merge_id()]
|
||||
self.instructions[self.merge_id()] = list(filtered), args, \
|
||||
n_rounds
|
||||
template, args, self.n_rounds = self.instructions[self.merge_id()]
|
||||
subs = util.dict_by_id()
|
||||
for arg, reg in zip(args, regs):
|
||||
subs[arg] = reg
|
||||
@@ -486,6 +490,9 @@ def cisc(function):
|
||||
base += reg.size
|
||||
return block.instructions
|
||||
|
||||
def expanded_rounds(self):
|
||||
return self.n_rounds - 1
|
||||
|
||||
MergeCISC.__name__ = function.__name__
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -649,7 +656,7 @@ class String(ArgFormat):
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return arg + '\0' * (cls.length - len(arg))
|
||||
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
|
||||
|
||||
ArgFormats = {
|
||||
'c': ClearModpAF,
|
||||
@@ -683,6 +690,7 @@ class Instruction(object):
|
||||
"""
|
||||
__slots__ = ['args', 'arg_format', 'code', 'caller']
|
||||
count = 0
|
||||
code_length = 10
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
""" Create an instruction and append it to the program list. """
|
||||
@@ -701,7 +709,7 @@ class Instruction(object):
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
|
||||
def get_code(self, prefix=0):
|
||||
return (prefix << 10) + self.code
|
||||
return (prefix << self.code_length) + self.code
|
||||
|
||||
def get_encoding(self):
|
||||
enc = int_to_bytes(self.get_code())
|
||||
@@ -713,7 +721,10 @@ class Instruction(object):
|
||||
return enc
|
||||
|
||||
def get_bytes(self):
|
||||
return bytearray(self.get_encoding())
|
||||
try:
|
||||
return bytearray(self.get_encoding())
|
||||
except TypeError:
|
||||
raise CompilerError('cannot encode %s/%s' % (self, self.get_encoding()))
|
||||
|
||||
def check_args(self):
|
||||
""" Check the args match up with that specified in arg_format """
|
||||
@@ -787,6 +798,9 @@ class Instruction(object):
|
||||
def expand_merged(self):
|
||||
return [self]
|
||||
|
||||
def expanded_rounds(self):
|
||||
return 0
|
||||
|
||||
def get_new_args(self, size, subs):
|
||||
new_args = []
|
||||
for arg, f in zip(self.args, self.arg_format):
|
||||
@@ -864,6 +878,12 @@ class DirectMemoryInstruction(Instruction):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
class IndirectMemoryInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def get_direct(self, address):
|
||||
return self.direct(self.args[0], address, add_to_prog=False)
|
||||
|
||||
class ReadMemoryInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
@@ -875,7 +895,7 @@ class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
|
||||
__slots__ = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
if program.curr_tape.prevent_direct_memory_write:
|
||||
raise CompilerError('Direct memory writing prevented')
|
||||
raise CompilerError('Direct memory writing prevented in threads')
|
||||
super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
###
|
||||
|
||||
@@ -6,6 +6,7 @@ in particularly providing flow control and output.
|
||||
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc
|
||||
from Compiler.instructions import *
|
||||
from Compiler.util import tuplify,untuplify,is_zero
|
||||
from Compiler.allocator import RegintOptimizer
|
||||
from Compiler import instructions,instructions_base,comparison,program,util
|
||||
import inspect,math
|
||||
import random
|
||||
@@ -54,6 +55,9 @@ def set_instruction_type(function):
|
||||
return instruction_typed_function
|
||||
|
||||
|
||||
def _expand_to_print(val):
|
||||
return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val)
|
||||
|
||||
def print_str(s, *args):
|
||||
""" Print a string, with optional args for adding
|
||||
variables/registers with ``%s``. """
|
||||
@@ -90,7 +94,7 @@ def print_str(s, *args):
|
||||
elif isinstance(val, cfloat):
|
||||
val.print_float_plain()
|
||||
elif isinstance(val, (list, tuple, Array)):
|
||||
print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val)
|
||||
print_str(*_expand_to_print(val))
|
||||
else:
|
||||
try:
|
||||
val.output()
|
||||
@@ -127,6 +131,9 @@ def print_ln_if(cond, ss, *args):
|
||||
|
||||
print_ln_if(get_player_id() == 0, 'Player 0 here')
|
||||
"""
|
||||
print_str_if(cond, ss + '\n', *args)
|
||||
|
||||
def print_str_if(cond, ss, *args):
|
||||
if util.is_constant(cond):
|
||||
if cond:
|
||||
print_ln(ss, *args)
|
||||
@@ -138,9 +145,14 @@ def print_ln_if(cond, ss, *args):
|
||||
cond = cint.conv(cond)
|
||||
for i, s in enumerate(subs):
|
||||
if i != 0:
|
||||
args[i - 1].output_if(cond)
|
||||
if i == len(args):
|
||||
s += '\n'
|
||||
val = args[i - 1]
|
||||
try:
|
||||
val.output_if(cond)
|
||||
except:
|
||||
if isinstance(val, (list, tuple, Array)):
|
||||
print_str_if(cond, *_expand_to_print(val))
|
||||
else:
|
||||
print_str_if(cond, str(val))
|
||||
s += '\0' * ((-len(s)) % 4)
|
||||
while s:
|
||||
cond.print_if(s[:4])
|
||||
@@ -162,7 +174,15 @@ def print_ln_to(player, ss, *args):
|
||||
new_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, personal):
|
||||
assert arg.player == player
|
||||
if util.is_constant(arg.player) ^ util.is_constant(player):
|
||||
match = False
|
||||
else:
|
||||
if util.is_constant(player):
|
||||
match = arg.player == player
|
||||
else:
|
||||
match = id(arg.player) == id(player)
|
||||
if not match:
|
||||
raise CompilerError('player mismatch in personal printing')
|
||||
new_args.append(arg._v)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
@@ -968,6 +988,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
del blocks[-n_to_merge + 1:]
|
||||
del get_tape().req_node.children[-1]
|
||||
merged.children = []
|
||||
RegintOptimizer().run(merged.instructions)
|
||||
get_tape().active_basicblock = merged
|
||||
else:
|
||||
req_node = get_tape().req_node.children[-1].nodes[0]
|
||||
@@ -1117,7 +1138,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
else:
|
||||
return loop_body(base + i)
|
||||
prog = get_program()
|
||||
threads = []
|
||||
thread_args = []
|
||||
if not util.is_zero(thread_rounds):
|
||||
tape = prog.new_tape(f, (0,), 'multithread')
|
||||
for i in range(n_threads - remainder):
|
||||
@@ -1125,7 +1146,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
args[remainder + i][0] = i * thread_rounds
|
||||
if len(mem_state):
|
||||
args[remainder + i][1] = mem_state.address
|
||||
threads.append(prog.run_tape(tape, remainder + i))
|
||||
thread_args.append((tape, remainder + i))
|
||||
if remainder:
|
||||
tape1 = prog.new_tape(f, (1,), 'multithread1')
|
||||
for i in range(remainder):
|
||||
@@ -1133,7 +1154,8 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
args[i][0] = (n_threads - remainder + i) * thread_rounds + i
|
||||
if len(mem_state):
|
||||
args[i][1] = mem_state.address
|
||||
threads.append(prog.run_tape(tape1, i))
|
||||
thread_args.append((tape1, i))
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
if state:
|
||||
@@ -1625,6 +1647,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
y = y.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
return y
|
||||
@@ -1659,7 +1682,7 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k, kappa)[::-1]
|
||||
suffixes = PreOR(bits)[::-1]
|
||||
suffixes = PreOR(bits, kappa)[::-1]
|
||||
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
|
||||
@@ -112,6 +112,12 @@ def lse_0_from_e_x(x, e_x):
|
||||
def lse_0(x):
|
||||
return lse_0_from_e_x(x, exp(x))
|
||||
|
||||
def approx_lse_0(x, n=3):
|
||||
assert n != 5
|
||||
a = x < -0.5
|
||||
b = x > 0.5
|
||||
return a.if_else(0, b.if_else(x, 0.5 * (x + 0.5) ** 2)) - x
|
||||
|
||||
def relu_prime(x):
|
||||
""" ReLU derivative. """
|
||||
return (0 <= x)
|
||||
@@ -145,10 +151,19 @@ class Tensor(MultiArray):
|
||||
kwargs['alloc'] = False
|
||||
super(Tensor, self).__init__(*args, **kwargs)
|
||||
|
||||
def input_from(self, *args, **kwargs):
|
||||
self.alloc()
|
||||
super(Tensor, self).input_from(*args, **kwargs)
|
||||
|
||||
def __getitem__(self, *args):
|
||||
self.alloc()
|
||||
return super(Tensor, self).__getitem__(*args)
|
||||
|
||||
class Layer:
|
||||
n_threads = 1
|
||||
inputs = []
|
||||
input_bias = True
|
||||
thetas = lambda self: ()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
@@ -193,14 +208,13 @@ class Output(Layer):
|
||||
self.approx = approx
|
||||
|
||||
nablas = lambda self: ()
|
||||
thetas = lambda self: ()
|
||||
reset = lambda self: None
|
||||
|
||||
def divisor(self, divisor, size):
|
||||
return cfix(1.0 / divisor, size=size)
|
||||
|
||||
def forward(self, batch):
|
||||
if self.approx:
|
||||
if self.approx == 5:
|
||||
self.l.write(999)
|
||||
return
|
||||
N = len(batch)
|
||||
@@ -209,10 +223,12 @@ class Output(Layer):
|
||||
def _(base, size):
|
||||
x = self.X.get_vector(base, size)
|
||||
y = self.Y.get(batch.get_vector(base, size))
|
||||
if self.approx:
|
||||
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
|
||||
return
|
||||
e_x = exp(-x)
|
||||
self.e_x.assign(e_x, base)
|
||||
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
||||
e_x = self.e_x.get_vector(0, N)
|
||||
self.l.write(sum(lse) * \
|
||||
self.divisor(N, 1))
|
||||
|
||||
@@ -323,7 +339,7 @@ class Dense(DenseBase):
|
||||
|
||||
self.X = MultiArray([N, d, d_in], sfix)
|
||||
self.Y = MultiArray([N, d, d_out], sfix)
|
||||
self.W = sfix.Matrix(d_in, d_out)
|
||||
self.W = Tensor([d_in, d_out], sfix)
|
||||
self.b = sfix.Array(d_out)
|
||||
|
||||
self.nabla_Y = MultiArray([N, d, d_out], sfix)
|
||||
@@ -546,12 +562,12 @@ class MaxPool(NoVariableLayer):
|
||||
for x in strides, ksize:
|
||||
for i in 0, 3:
|
||||
assert x[i] == 1
|
||||
self.X = MultiArray(shape, sfix)
|
||||
self.X = Tensor(shape, sfix)
|
||||
if padding == 'SAME':
|
||||
output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)]
|
||||
else:
|
||||
output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)]
|
||||
self.Y = MultiArray(output_shape, sfix)
|
||||
self.Y = Tensor(output_shape, sfix)
|
||||
self.strides = strides
|
||||
self.ksize = ksize
|
||||
|
||||
@@ -741,6 +757,7 @@ class ConvBase(BaseLayer):
|
||||
use_conv2ds = False
|
||||
temp_weights = None
|
||||
temp_inputs = None
|
||||
thetas = lambda self: (self.weights, self.bias)
|
||||
|
||||
@classmethod
|
||||
def init_temp(cls, layers):
|
||||
@@ -780,7 +797,7 @@ class ConvBase(BaseLayer):
|
||||
self.weight_squant = self.new_squant()
|
||||
self.bias_squant = self.new_squant()
|
||||
|
||||
self.weights = MultiArray(weight_shape, self.weight_squant)
|
||||
self.weights = Tensor(weight_shape, self.weight_squant)
|
||||
self.bias = Array(output_shape[-1], self.bias_squant)
|
||||
|
||||
self.unreduced = Tensor(self.output_shape, sint)
|
||||
@@ -1158,7 +1175,8 @@ class Optimizer:
|
||||
layer.last_used = list(filter(lambda x: x not in used, layer.inputs))
|
||||
used.update(layer.inputs)
|
||||
|
||||
def forward(self, N=None, batch=None, keep_intermediate=True):
|
||||
def forward(self, N=None, batch=None, keep_intermediate=True,
|
||||
model_from=None):
|
||||
""" Compute graph.
|
||||
|
||||
:param N: batch size (used if batch not given)
|
||||
@@ -1172,12 +1190,16 @@ class Optimizer:
|
||||
if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None:
|
||||
layer._X.address = layer.inputs[0].Y.address
|
||||
layer.Y.alloc()
|
||||
if model_from is not None:
|
||||
layer.input_from(model_from)
|
||||
break_point()
|
||||
layer.forward(batch=batch)
|
||||
break_point()
|
||||
if not keep_intermediate:
|
||||
for l in layer.last_used:
|
||||
l.Y.delete()
|
||||
for theta in layer.thetas():
|
||||
theta.delete()
|
||||
|
||||
def eval(self, data):
|
||||
""" Compute evaluation after training. """
|
||||
@@ -1207,6 +1229,7 @@ class Optimizer:
|
||||
else:
|
||||
N = self.layers[0].N
|
||||
i = MemValue(0)
|
||||
n_iterations = MemValue(0)
|
||||
@do_while
|
||||
def _():
|
||||
if self.X_by_label is None:
|
||||
@@ -1216,6 +1239,7 @@ class Optimizer:
|
||||
n = N // len(self.X_by_label)
|
||||
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
|
||||
self.X_by_label) / n))
|
||||
n_iterations.iadd(n_per_epoch)
|
||||
print('%d runs per epoch' % n_per_epoch)
|
||||
indices_by_label = []
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
@@ -1235,7 +1259,7 @@ class Optimizer:
|
||||
self.backward(batch=batch)
|
||||
self.update(i)
|
||||
loss = self.layers[-1].l
|
||||
if self.report_loss and not self.layers[-1].approx:
|
||||
if self.report_loss and self.layers[-1].approx != 5:
|
||||
print_ln('loss after epoch %s: %s', i, loss.reveal())
|
||||
else:
|
||||
print_ln('done with epoch %s', i)
|
||||
@@ -1245,7 +1269,7 @@ class Optimizer:
|
||||
if self.tol > 0:
|
||||
res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
|
||||
return res
|
||||
print_ln('finished after %s epochs', i)
|
||||
print_ln('finished after %s epochs and %s iterations', i, n_iterations)
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, layers, n_epochs):
|
||||
|
||||
@@ -3,6 +3,7 @@ from Compiler.exceptions import *
|
||||
from Compiler.instructions_base import RegType
|
||||
import Compiler.instructions
|
||||
import Compiler.instructions_base
|
||||
import Compiler.instructions_base as inst_base
|
||||
from . import compilerLib
|
||||
from . import allocator as al
|
||||
from . import util
|
||||
@@ -40,22 +41,20 @@ class Program(object):
|
||||
|
||||
These are created by executing a file containing appropriate instructions
|
||||
and threads. """
|
||||
def __init__(self, args, options, param=-1):
|
||||
def __init__(self, args, options):
|
||||
self.options = options
|
||||
self.verbose = options.verbose
|
||||
self.args = args
|
||||
self.init_names(args)
|
||||
self.P = P_VALUES[param]
|
||||
self.param = param
|
||||
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
|
||||
if sum(x != 0 for x in(options.ring, options.field,
|
||||
options.binary)) > 1:
|
||||
raise CompilerError('can only use one out of -p, -B, -R, -F')
|
||||
raise CompilerError('can only use one out of -B, -R, -F')
|
||||
if options.ring:
|
||||
self.bit_length = int(options.ring) - 1
|
||||
else:
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if not self.bit_length:
|
||||
self.bit_length = BIT_LENGTHS[param]
|
||||
self.bit_length = 64
|
||||
print('Default bit length:', self.bit_length)
|
||||
self.security = 40
|
||||
print('Default security parameter:', self.security)
|
||||
@@ -67,7 +66,7 @@ class Program(object):
|
||||
self._curr_tape = None
|
||||
self.DEBUG = False
|
||||
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
|
||||
self.free_mem_blocks = defaultdict(lambda: defaultdict(set))
|
||||
self.free_mem_blocks = defaultdict(al.BlockAllocator)
|
||||
self.allocated_mem_blocks = {}
|
||||
self.saved = 0
|
||||
self.req_num = None
|
||||
@@ -100,6 +99,7 @@ class Program(object):
|
||||
self._edabit = options.edabit
|
||||
self._split = False
|
||||
self._square = False
|
||||
self._always_raw = False
|
||||
Program.prog = self
|
||||
|
||||
def get_args(self):
|
||||
@@ -165,20 +165,28 @@ class Program(object):
|
||||
return tape_index
|
||||
|
||||
def run_tape(self, tape_index, arg):
|
||||
return self.run_tapes([[tape_index, arg]])[0]
|
||||
|
||||
def run_tapes(self, args):
|
||||
if self.curr_tape is not self.tapes[0]:
|
||||
raise CompilerError('Compiler does not support ' \
|
||||
'recursive spawning of threads')
|
||||
if self.free_threads:
|
||||
thread_number = min(self.free_threads)
|
||||
self.free_threads.remove(thread_number)
|
||||
else:
|
||||
thread_number = self.n_threads
|
||||
self.n_threads += 1
|
||||
thread_numbers = []
|
||||
while len(thread_numbers) < len(args):
|
||||
if self.free_threads:
|
||||
thread_numbers.append(min(self.free_threads))
|
||||
self.free_threads.remove(thread_numbers[-1])
|
||||
else:
|
||||
thread_numbers.append(self.n_threads)
|
||||
self.n_threads += 1
|
||||
self.curr_tape.start_new_basicblock(name='pre-run_tape')
|
||||
Compiler.instructions.run_tape(thread_number, arg, tape_index)
|
||||
Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in
|
||||
zip(thread_numbers, args)), []))
|
||||
self.curr_tape.start_new_basicblock(name='post-run_tape')
|
||||
self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree)
|
||||
return thread_number
|
||||
for arg in args:
|
||||
self.curr_tape.req_node.children.append(
|
||||
self.tapes[arg[0]].req_tree)
|
||||
return thread_numbers
|
||||
|
||||
def join_tape(self, thread_number):
|
||||
self.curr_tape.start_new_basicblock(name='pre-join_tape')
|
||||
@@ -250,19 +258,9 @@ class Program(object):
|
||||
mem_type = mem_type.reg_type
|
||||
elif reg_type is not None:
|
||||
self.types[mem_type] = reg_type
|
||||
block_size = 0
|
||||
blocks = self.free_mem_blocks[mem_type]
|
||||
if len(blocks[size]) > 0:
|
||||
block_size = size
|
||||
else:
|
||||
for block_size, addresses in blocks.items():
|
||||
if block_size >= size and len(addresses) > 0:
|
||||
break
|
||||
else:
|
||||
block_size = 0
|
||||
if block_size >= size:
|
||||
addr = self.free_mem_blocks[mem_type][block_size].pop()
|
||||
self.free_mem_blocks[mem_type][block_size - size].add(addr + size)
|
||||
addr = blocks.pop(size)
|
||||
if addr is not None:
|
||||
self.saved += size
|
||||
else:
|
||||
addr = self.allocated_mem[mem_type]
|
||||
@@ -278,7 +276,7 @@ class Program(object):
|
||||
is not self.curr_tape.basicblocks[0].alloc_pool:
|
||||
raise CompilerError('Cannot free memory within function block')
|
||||
size = self.allocated_mem_blocks.pop((addr,mem_type))
|
||||
self.free_mem_blocks[mem_type][size].add(addr)
|
||||
self.free_mem_blocks[mem_type].push(addr, size)
|
||||
|
||||
def finalize_memory(self):
|
||||
from . import library
|
||||
@@ -341,6 +339,20 @@ class Program(object):
|
||||
else:
|
||||
self._square = change
|
||||
|
||||
def always_raw(self, change=None):
|
||||
if change is None:
|
||||
return self._always_raw
|
||||
else:
|
||||
self._always_raw = change
|
||||
|
||||
def options_from_args(self):
|
||||
if 'trunc_pr' in self.args:
|
||||
self.use_trunc_pr = True
|
||||
if 'split' in self.args or 'split3' in self.args:
|
||||
self.use_split(3)
|
||||
if 'raw' in self.args:
|
||||
self.always_raw(True)
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
def __init__(self, name, program):
|
||||
@@ -445,13 +457,14 @@ class Tape:
|
||||
instructions = self.instructions
|
||||
for inst in instructions:
|
||||
inst.add_usage(req_node)
|
||||
req_node.num['all', 'round'] = self.n_rounds
|
||||
req_node.num['all', 'inv'] = self.n_to_merge
|
||||
req_node.num['all', 'round'] += self.n_rounds
|
||||
req_node.num['all', 'inv'] += self.n_to_merge
|
||||
|
||||
def expand_cisc(self):
|
||||
new_instructions = []
|
||||
for inst in self.instructions:
|
||||
new_instructions.extend(inst.expand_merged())
|
||||
self.n_rounds += inst.expanded_rounds()
|
||||
self.instructions = new_instructions
|
||||
|
||||
def __str__(self):
|
||||
@@ -503,9 +516,6 @@ class Tape:
|
||||
def unpurged(function):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.purged:
|
||||
if self.program.verbose:
|
||||
print('%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name))
|
||||
return
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
@@ -585,7 +595,8 @@ class Tape:
|
||||
reg_counts = self.count_regs()
|
||||
if not options.noreallocate:
|
||||
if self.program.verbose:
|
||||
print('Tape register usage:', dict(reg_counts))
|
||||
print('Tape register usage before re-allocation:',
|
||||
dict(reg_counts))
|
||||
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
|
||||
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
|
||||
print('Re-allocating...')
|
||||
@@ -613,6 +624,8 @@ class Tape:
|
||||
alloc_loop(block.exit_block.scope)
|
||||
allocator.process(block.instructions, block.alloc_pool)
|
||||
allocator.finalize(options)
|
||||
if self.program.verbose:
|
||||
print('Tape register usage:', dict(allocator.usage))
|
||||
|
||||
# offline data requirements
|
||||
if self.program.verbose:
|
||||
@@ -847,10 +860,6 @@ class Tape:
|
||||
if t == 'p':
|
||||
self.req_bit_length[t] = max(bit_length + 1, \
|
||||
self.req_bit_length[t])
|
||||
if self.program.param != -1 and bit_length >= self.program.param:
|
||||
raise CompilerError('Inadequate bit length %d for prime, ' \
|
||||
'program requires %d bits' % \
|
||||
(self.program.param, self.req_bit_length['p']))
|
||||
else:
|
||||
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
|
||||
|
||||
@@ -862,6 +871,7 @@ class Tape:
|
||||
__slots__ = ["reg_type", "program", "absolute_i", "relative_i", \
|
||||
"size", "vector", "vectorbase", "caller", \
|
||||
"can_eliminate", "duplicates"]
|
||||
maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1
|
||||
|
||||
def __init__(self, reg_type, program, size=None, i=None):
|
||||
""" Creates a new register.
|
||||
@@ -875,6 +885,8 @@ class Tape:
|
||||
self.program = program
|
||||
if size is None:
|
||||
size = Compiler.instructions_base.get_global_vector_size()
|
||||
if size is not None and size > self.maximum_size:
|
||||
raise CompilerError('vector too large')
|
||||
self.size = size
|
||||
self.vectorbase = self
|
||||
self.relative_i = 0
|
||||
|
||||
@@ -147,7 +147,8 @@ def vectorize(operation):
|
||||
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
|
||||
and not isinstance(args[0], bits) \
|
||||
and args[0].size != self.size:
|
||||
raise CompilerError('Different vector sizes of operands')
|
||||
raise CompilerError('Different vector sizes of operands: %d/%d'
|
||||
% (self.size, args[0].size))
|
||||
set_global_vector_size(self.size)
|
||||
res = operation(self, *args, **kwargs)
|
||||
reset_global_vector_size()
|
||||
@@ -789,7 +790,7 @@ class cint(_clear, _int):
|
||||
bit_length = 1 + int(math.ceil(math.log(abs(val))))
|
||||
if program.options.ring:
|
||||
assert(bit_length <= int(program.options.ring))
|
||||
elif program.param != -1 or program.options.field:
|
||||
elif program.options.field:
|
||||
program.curr_tape.require_bit_length(bit_length)
|
||||
if self.in_immediate_range(val):
|
||||
ldi(self, val)
|
||||
@@ -1731,6 +1732,15 @@ class sint(_secret, _int):
|
||||
""" Secret random n-bit number according to security model.
|
||||
|
||||
:param bits: compile-time integer (int) """
|
||||
if program.use_split() == 3:
|
||||
tmp = sint()
|
||||
randoms(tmp, bits)
|
||||
x = tmp.split_to_two_summands(bits, True)
|
||||
overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \
|
||||
sint.conv(x[0][-1])
|
||||
return tmp - (overflow << bits)
|
||||
elif program.use_edabit():
|
||||
return sint.get_edabit(bits, True)[0]
|
||||
res = sint()
|
||||
comparison.PRandInt(res, bits)
|
||||
return res
|
||||
@@ -2068,6 +2078,37 @@ class sint(_secret, _int):
|
||||
def two_power(n):
|
||||
return floatingpoint.two_power(n)
|
||||
|
||||
def split_to_n_summands(self, length, n):
|
||||
from .GC.types import sbits
|
||||
from .GC.instructions import split
|
||||
columns = [[sbits.get_type(self.size)()
|
||||
for i in range(n)] for i in range(length)]
|
||||
split(n, self, *sum(columns, []))
|
||||
return columns
|
||||
|
||||
def split_to_two_summands(self, length, get_carry=False):
|
||||
n = program.use_split()
|
||||
assert n
|
||||
columns = self.split_to_n_summands(length, n)
|
||||
return _bitint.wallace_tree_without_finish(columns, get_carry)
|
||||
|
||||
@vectorize
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
Result potentially written to ``Player-Data/Private-Output-P<player>.``
|
||||
|
||||
:param player: public integer (int/regint/cint):
|
||||
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
|
||||
"""
|
||||
if not util.is_constant(player) or self.size > 1:
|
||||
secret_mask = sint()
|
||||
player_mask = cint()
|
||||
inputmaskreg(secret_mask, player_mask, player)
|
||||
return personal(player,
|
||||
(self + secret_mask).reveal() - player_mask)
|
||||
else:
|
||||
return super(sint, self).reveal_to(player)
|
||||
|
||||
class sgf2n(_secret, _gf2n):
|
||||
""" Secret GF(2^n) value. """
|
||||
__slots__ = []
|
||||
@@ -2255,7 +2296,8 @@ class _bitint(object):
|
||||
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
|
||||
get_carry=False):
|
||||
lower = []
|
||||
for (ai,bi) in zip(a,b):
|
||||
a, b = a[:], b[:]
|
||||
for (ai, bi) in zip(a[:], b[:]):
|
||||
if is_zero(ai) or is_zero(bi):
|
||||
lower.append(ai + bi)
|
||||
a.pop(0)
|
||||
@@ -2404,6 +2446,7 @@ class _bitint(object):
|
||||
@classmethod
|
||||
def wallace_tree_without_finish(cls, columns, get_carry=True):
|
||||
self = cls
|
||||
columns = [col[:] for col in columns]
|
||||
while max(len(c) for c in columns) > 2:
|
||||
new_columns = [[] for i in range(len(columns) + 1)]
|
||||
for i,col in enumerate(columns):
|
||||
@@ -2439,11 +2482,12 @@ class _bitint(object):
|
||||
raise CompilerError('Unclear subtraction')
|
||||
a = self.bit_decompose()
|
||||
b = util.bit_decompose(other, self.n_bits)
|
||||
d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)]
|
||||
d = [(reduce(util.bit_xor, (ai, bi, 1)), (1 - ai) * bi)
|
||||
for (ai,bi) in zip(a,b)]
|
||||
borrow = lambda y,x,*args: \
|
||||
(x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1]))
|
||||
borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1]
|
||||
return self.compose(ai + bi + borrow \
|
||||
return self.compose(reduce(util.bit_xor, (ai, bi, borrow)) \
|
||||
for (ai,bi,borrow) in zip(a,b,borrows))
|
||||
|
||||
def __rsub__(self, other):
|
||||
@@ -2715,10 +2759,15 @@ class cfix(_number, _structure):
|
||||
def set_precision(cls, f, k = None):
|
||||
""" Set the precision of the integer representation. Note that some
|
||||
operations are undefined when the precision of :py:class:`sfix` and
|
||||
:py:class:`cfix` differs.
|
||||
:py:class:`cfix` differs. The initial defaults are chosen to
|
||||
allow the best optimization of probabilistic truncation in
|
||||
computation modulo 2^64 (2*k < 64). Generally, 2*k must be at
|
||||
most the integer length for rings and at most m-s-1 for
|
||||
computation modulo an m-bit prime and statistical security s
|
||||
(default 40).
|
||||
|
||||
:param f: bit length of decimal part
|
||||
:param k: whole bit length of fixed point, defaults to twice :py:obj:`f`.
|
||||
:param f: bit length of decimal part (initial default 16)
|
||||
:param k: whole bit length of fixed point, defaults to twice :py:obj:`f` if not given (initial default 31)
|
||||
|
||||
"""
|
||||
cls.f = f
|
||||
@@ -2789,6 +2838,10 @@ class cfix(_number, _structure):
|
||||
else:
|
||||
raise CompilerError('cannot initialize cfix with %s' % v)
|
||||
|
||||
def __iter__(self):
|
||||
for x in self.v:
|
||||
yield type(self)(x, self.k, self.f)
|
||||
|
||||
@vectorize
|
||||
def load_int(self, v):
|
||||
self.v = cint(v) * (2 ** self.f)
|
||||
@@ -3141,7 +3194,6 @@ class _fix(_single):
|
||||
""" Secret fixed point type. """
|
||||
__slots__ = ['v', 'f', 'k', 'size']
|
||||
|
||||
@classmethod
|
||||
def set_precision(cls, f, k = None):
|
||||
cls.f = f
|
||||
# default bitlength = 2*precision
|
||||
@@ -3152,6 +3204,7 @@ class _fix(_single):
|
||||
raise CompilerError('bit length cannot be less than precision')
|
||||
cls.k = k
|
||||
set_precision.__doc__ = cfix.set_precision.__doc__
|
||||
set_precision = classmethod(set_precision)
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
@@ -3377,9 +3430,10 @@ class sfix(_fix):
|
||||
|
||||
def reveal_to(self, player):
|
||||
""" Reveal secret value to :py:obj:`player`.
|
||||
Raw representation written to ``Player-Data/Private-Output-P<player>``
|
||||
Raw representation possibly written to
|
||||
``Player-Data/Private-Output-P<player>.``
|
||||
|
||||
:param player: int
|
||||
:param player: public integer (int/regint/cint)
|
||||
:returns: value to be used with :py:func:`Compiler.library.print_ln_to`
|
||||
"""
|
||||
return personal(player, cfix(self.v.reveal_to(player)._v,
|
||||
@@ -3419,17 +3473,8 @@ class unreduced_sfix(_single):
|
||||
|
||||
sfix.unreduced_type = unreduced_sfix
|
||||
|
||||
# this is for 20 bit decimal precision
|
||||
# with 40 bitlength of entire number
|
||||
# these constants have been chosen for multiplications to fit in 128 bit prime field
|
||||
# (precision n1) 41 + (precision n2) 41 + (stat_sec) 40 = 82 + 40 = 122 <= 128
|
||||
# with statistical security of 40
|
||||
|
||||
fixed_lower = 20
|
||||
fixed_upper = 40
|
||||
|
||||
sfix.set_precision(fixed_lower, fixed_upper)
|
||||
cfix.set_precision(fixed_lower, fixed_upper)
|
||||
sfix.set_precision(16, 31)
|
||||
cfix.set_precision(16, 31)
|
||||
|
||||
class squant(_single):
|
||||
""" Quantization as in ArXiv:1712.05877v1 """
|
||||
@@ -4222,8 +4267,10 @@ class Array(object):
|
||||
:param value: convertible to basic type """
|
||||
if conv:
|
||||
value = self.value_type.conv(value)
|
||||
if value.size != 1:
|
||||
raise CompilerError('cannot assign vector to all elements')
|
||||
mem_value = MemValue(value)
|
||||
self.address = MemValue(self.address)
|
||||
self.address = MemValue.if_necessary(self.address)
|
||||
n_threads = 8 if use_threads and len(self) > 2**20 else 1
|
||||
@library.for_range_multithread(n_threads, 1024, len(self))
|
||||
def f(i):
|
||||
@@ -4242,7 +4289,7 @@ class Array(object):
|
||||
|
||||
def get(self, indices):
|
||||
return self.value_type.load_mem(
|
||||
regint(self.address, size=len(indices)) + indices,
|
||||
regint.inc(len(indices), self.address, 0) + indices,
|
||||
size=len(indices))
|
||||
|
||||
def expand_to_vector(self, index, size):
|
||||
@@ -4258,13 +4305,13 @@ class Array(object):
|
||||
""" Fill with inputs from player if supported by type.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
if raw:
|
||||
if raw or program.always_raw():
|
||||
input_from = self.value_type.get_raw_input_from
|
||||
else:
|
||||
input_from = self.value_type.get_input_from
|
||||
try:
|
||||
self.assign(input_from(player, size=len(self)))
|
||||
except:
|
||||
except TypeError:
|
||||
@library.for_range_opt(len(self), budget=budget)
|
||||
def _(i):
|
||||
self[i] = input_from(player)
|
||||
@@ -4318,6 +4365,12 @@ class Array(object):
|
||||
:returns: Array of relevant clear type. """
|
||||
return Array.create_from(x.reveal() for x in self)
|
||||
|
||||
def reveal_list(self):
|
||||
""" Reveal as list. """
|
||||
return list(self.get_vector().reveal())
|
||||
|
||||
reveal_nested = reveal_list
|
||||
|
||||
sint.dynamic_array = Array
|
||||
sgf2n.dynamic_array = Array
|
||||
|
||||
@@ -4454,16 +4507,17 @@ class SubMultiArray(object):
|
||||
""" Fill with inputs from player if supported by type.
|
||||
|
||||
:param player: public (regint/cint/int) """
|
||||
budget = budget or 2 ** 21
|
||||
budget = budget or Tape.Register.maximum_size
|
||||
if (self.total_size() < budget) and \
|
||||
self.value_type.n_elements() == 1:
|
||||
if raw:
|
||||
if raw or program.always_raw():
|
||||
input_from = self.value_type.get_raw_input_from
|
||||
else:
|
||||
input_from = self.value_type.get_input_from
|
||||
self.assign_vector(input_from(player, size=self.total_size()))
|
||||
else:
|
||||
@library.for_range_opt(self.sizes[0], budget=budget)
|
||||
@library.for_range_opt(self.sizes[0],
|
||||
budget=budget / self[0].total_size())
|
||||
def _(i):
|
||||
self[i].input_from(player, budget=budget, raw=raw)
|
||||
|
||||
@@ -4687,6 +4741,21 @@ class SubMultiArray(object):
|
||||
library.break_point()
|
||||
return res
|
||||
|
||||
def reveal_list(self):
|
||||
""" Reveal as list. """
|
||||
return list(self.get_vector().reveal())
|
||||
|
||||
def reveal_nested(self):
|
||||
""" Reveal as nested list. """
|
||||
flat = iter(self.get_vector().reveal())
|
||||
res = []
|
||||
def f(sizes):
|
||||
if len(sizes) == 1:
|
||||
return [next(flat) for i in range(sizes[0])]
|
||||
else:
|
||||
return [f(sizes[1:]) for i in range(sizes[0])]
|
||||
return f(self.sizes)
|
||||
|
||||
class MultiArray(SubMultiArray):
|
||||
""" Multidimensional array. """
|
||||
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
|
||||
@@ -4836,10 +4905,12 @@ class MemValue(_mem):
|
||||
self.value_type = type(value)
|
||||
self.deleted = False
|
||||
if address is None:
|
||||
self.address = self.value_type.malloc(1)
|
||||
self.address = self.value_type.malloc(value.size)
|
||||
self.size = value.size
|
||||
self.write(value)
|
||||
else:
|
||||
self.address = address
|
||||
self.size = 1
|
||||
|
||||
def delete(self):
|
||||
self.value_type.free(self.address)
|
||||
@@ -4869,6 +4940,8 @@ class MemValue(_mem):
|
||||
elif isinstance(value, int):
|
||||
self.register = self.value_type(value)
|
||||
else:
|
||||
if value.size != self.size:
|
||||
raise CompilerError('size mismatch')
|
||||
self.register = value
|
||||
if not isinstance(self.register, self.value_type):
|
||||
raise CompilerError('Mismatch in register type, cannot write \
|
||||
|
||||
@@ -46,8 +46,10 @@ def right_shift(a, b, bits):
|
||||
else:
|
||||
return a.right_shift(b, bits)
|
||||
|
||||
def bit_decompose(a, bits):
|
||||
def bit_decompose(a, bits=None):
|
||||
if isinstance(a, int):
|
||||
if bits is None:
|
||||
bits = int_len(a)
|
||||
return [int((a >> i) & 1) for i in range(bits)]
|
||||
else:
|
||||
return a.bit_decompose(bits)
|
||||
@@ -122,6 +124,15 @@ def or_op(a, b):
|
||||
|
||||
OR = or_op
|
||||
|
||||
def bit_xor(a, b):
|
||||
if is_constant(a):
|
||||
if is_constant(b):
|
||||
return a ^ b
|
||||
else:
|
||||
return b.bit_xor(a)
|
||||
else:
|
||||
return a.bit_xor(b)
|
||||
|
||||
def pow2(bits):
|
||||
powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)]
|
||||
return tree_reduce(operator.mul, powers)
|
||||
|
||||
@@ -92,8 +92,8 @@ bool Proof::check_bounds(T& z, X& t, int i) const
|
||||
unsigned int j,k;
|
||||
|
||||
// Check Bound 1 and Bound 2
|
||||
AbsoluteBoundChecker<fixint<2>> plain_checker(plain_check * n_proofs);
|
||||
AbsoluteBoundChecker<fixint<2>> rand_checker(rand_check * n_proofs);
|
||||
AbsoluteBoundChecker<fixint<GFP_MOD_SZ>> plain_checker(plain_check * n_proofs);
|
||||
AbsoluteBoundChecker<fixint<GFP_MOD_SZ>> rand_checker(rand_check * n_proofs);
|
||||
for (j=0; j<phim; j++)
|
||||
{
|
||||
auto& te = z[j];
|
||||
|
||||
@@ -252,8 +252,10 @@ void SummingEncCommit<FD>::create_more()
|
||||
this->c.unpack(ciphertexts, this->pk);
|
||||
commitments.unpack(ciphertexts, this->pk);
|
||||
|
||||
#ifdef VERBOSE_HE
|
||||
cout << "Tree-wise sum of ciphertexts with "
|
||||
<< 1e-9 * ciphertexts.get_length() << " GB" << endl;
|
||||
#endif
|
||||
this->timers["Exchanging ciphertexts"].start();
|
||||
tree_sum.run(this->c, P);
|
||||
tree_sum.run(commitments, P);
|
||||
|
||||
@@ -13,9 +13,12 @@ template <class T>
|
||||
class ArgIter
|
||||
{
|
||||
vector<int>::const_iterator it;
|
||||
vector<int>::const_iterator end;
|
||||
|
||||
public:
|
||||
ArgIter(const vector<int>::const_iterator it) : it(it)
|
||||
ArgIter(const vector<int>::const_iterator it,
|
||||
const vector<int>::const_iterator end) :
|
||||
it(it), end(end)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -27,8 +30,10 @@ public:
|
||||
ArgIter<T> operator++()
|
||||
{
|
||||
auto res = it;
|
||||
it += T::n;
|
||||
return res;
|
||||
it += T(res).n;
|
||||
if (it > end)
|
||||
throw runtime_error("wrong number of args");
|
||||
return {res, end};
|
||||
}
|
||||
|
||||
bool operator!=(const ArgIter<T>& other)
|
||||
@@ -46,18 +51,16 @@ public:
|
||||
ArgList(const vector<int>& args) :
|
||||
args(args)
|
||||
{
|
||||
if (args.size() % T::n != 0)
|
||||
throw runtime_error("wrong number of args");
|
||||
}
|
||||
|
||||
ArgIter<T> begin()
|
||||
{
|
||||
return args.begin();
|
||||
return {args.begin(), args.end()};
|
||||
}
|
||||
|
||||
ArgIter<T> end()
|
||||
{
|
||||
return args.end();
|
||||
return {args.end(), args.end()};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -81,11 +84,12 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class InputArgList : public ArgList<InputArgs>
|
||||
template<class T>
|
||||
class InputArgListBase : public ArgList<T>
|
||||
{
|
||||
public:
|
||||
InputArgList(const vector<int>& args) :
|
||||
ArgList<InputArgs>(args)
|
||||
InputArgListBase(const vector<int>& args) :
|
||||
ArgList<T>(args)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -97,7 +101,55 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
int n_input_bits()
|
||||
{
|
||||
int res = 0;
|
||||
for (auto x : *this)
|
||||
res += x.n_bits;
|
||||
return res;
|
||||
}
|
||||
|
||||
int n_interactive_inputs_from_me(int my_num);
|
||||
};
|
||||
|
||||
class InputArgList : public InputArgListBase<InputArgs>
|
||||
{
|
||||
public:
|
||||
InputArgList(const vector<int>& args) :
|
||||
InputArgListBase<InputArgs>(args)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
class InputVecArgs
|
||||
{
|
||||
public:
|
||||
int from;
|
||||
int n;
|
||||
int& n_bits;
|
||||
int& n_shift;
|
||||
int params[2];
|
||||
vector<int> dest;
|
||||
|
||||
InputVecArgs(vector<int>::const_iterator it) : n_bits(params[0]), n_shift(params[1])
|
||||
{
|
||||
n = *it++;
|
||||
n_bits = n - 3;
|
||||
n_shift = *it++;
|
||||
from = *it++;
|
||||
dest.resize(n);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
dest[i] = *it++;
|
||||
}
|
||||
};
|
||||
|
||||
class InputVecArgList : public InputArgListBase<InputVecArgs>
|
||||
{
|
||||
public:
|
||||
InputVecArgList(const vector<int>& args) :
|
||||
InputArgListBase<InputVecArgs>(args)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* GC_ARGTUPLES_H_ */
|
||||
|
||||
@@ -12,9 +12,8 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
int FakeSecret::default_length = 128;
|
||||
|
||||
ostream& FakeSecret::out = cout;
|
||||
SwitchableOutput FakeSecret::out;
|
||||
const int FakeSecret::default_length;
|
||||
|
||||
void FakeSecret::load_clear(int n, const Integer& x)
|
||||
{
|
||||
@@ -60,12 +59,9 @@ void FakeSecret::store_clear_in_dynamic(Memory<DynamicType>& mem,
|
||||
void FakeSecret::ands(Processor<FakeSecret>& processor,
|
||||
const vector<int>& regs)
|
||||
{
|
||||
processor.check_args(regs, 4);
|
||||
for (size_t i = 0; i < regs.size(); i += 4)
|
||||
processor.S[regs[i + 1]] = processor.S[regs[i + 2]].a & processor.S[regs[i + 3]].a;
|
||||
processor.ands(regs);
|
||||
}
|
||||
|
||||
|
||||
void FakeSecret::trans(Processor<FakeSecret>& processor, int n_outputs,
|
||||
const vector<int>& args)
|
||||
{
|
||||
@@ -82,15 +78,13 @@ FakeSecret FakeSecret::input(GC::Processor<FakeSecret>& processor, const InputAr
|
||||
return input(args.from, processor.get_input(args.params), args.n_bits);
|
||||
}
|
||||
|
||||
FakeSecret FakeSecret::input(int from, const int128& input, int n_bits)
|
||||
FakeSecret FakeSecret::input(int from, word input, int n_bits)
|
||||
{
|
||||
(void)from;
|
||||
(void)n_bits;
|
||||
FakeSecret res;
|
||||
res.a = ((__uint128_t)input.get_upper() << 64) + input.get_lower();
|
||||
if (res.a > ((__uint128_t)1 << n_bits))
|
||||
if (n_bits < 64 and input > word(1) << n_bits)
|
||||
throw out_of_range("input too large");
|
||||
return res;
|
||||
return input;
|
||||
}
|
||||
|
||||
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
|
||||
@@ -99,7 +93,7 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,
|
||||
if (repeat)
|
||||
return andrs(n, x, y);
|
||||
else
|
||||
throw runtime_error("call static FakeSecret::ands()");
|
||||
*this = BitVec(x & y).mask(n);
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -12,8 +12,14 @@
|
||||
#include "GC/ArgTuples.h"
|
||||
|
||||
#include "Math/gf2nlong.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "Protocols/FakePrep.h"
|
||||
#include "Protocols/FakeMC.h"
|
||||
#include "Protocols/FakeProtocol.h"
|
||||
#include "Protocols/FakeInput.h"
|
||||
#include "Protocols/ShareInterface.h"
|
||||
|
||||
#include <random>
|
||||
#include <fstream>
|
||||
@@ -26,21 +32,41 @@ class Processor;
|
||||
template <class T>
|
||||
class Machine;
|
||||
|
||||
class FakeSecret
|
||||
class FakeSecret : public ShareInterface, public BitVec
|
||||
{
|
||||
__uint128_t a;
|
||||
|
||||
public:
|
||||
typedef FakeSecret DynamicType;
|
||||
typedef Memory<FakeSecret> DynamicMemory;
|
||||
|
||||
typedef BitVec mac_key_type;
|
||||
typedef BitVec clear;
|
||||
typedef BitVec open_type;
|
||||
|
||||
typedef FakeSecret part_type;
|
||||
typedef FakeSecret small_type;
|
||||
typedef NoShare bit_type;
|
||||
|
||||
typedef FakePrep<FakeSecret> LivePrep;
|
||||
typedef FakeMC<FakeSecret> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef MC Direct_MC;
|
||||
typedef FakeProtocol<FakeSecret> Protocol;
|
||||
typedef FakeInput<FakeSecret> Input;
|
||||
|
||||
static string type_string() { return "fake secret"; }
|
||||
static string phase_name() { return "Faking"; }
|
||||
|
||||
static int default_length;
|
||||
static const int default_length = 64;
|
||||
|
||||
typedef ostream& out_type;
|
||||
static ostream& out;
|
||||
static const bool is_real = true;
|
||||
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
static SwitchableOutput out;
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2; }
|
||||
|
||||
static MC* new_mc(mac_key_type key) { return new MC(key); }
|
||||
|
||||
static void store_clear_in_dynamic(Memory<DynamicType>& mem,
|
||||
const vector<GC::ClearWriteAccess>& accesses);
|
||||
@@ -59,6 +85,11 @@ public:
|
||||
static void inputb(T& processor, const vector<int>& args)
|
||||
{ processor.input(args); }
|
||||
template <class T>
|
||||
static void inputb(T& processor, ArithmeticProcessor&, const vector<int>& args)
|
||||
{ processor.input(args); }
|
||||
template <class T, class U>
|
||||
static void inputbvec(T&, U&, const vector<int>&) { throw not_implemented(); }
|
||||
template <class T>
|
||||
static void reveal_inst(T& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
|
||||
@@ -69,11 +100,13 @@ public:
|
||||
static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; }
|
||||
|
||||
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
|
||||
static FakeSecret input(int from, const int128& input, int n_bits);
|
||||
static FakeSecret input(int from, word input, int n_bits);
|
||||
|
||||
FakeSecret() : a(0) {}
|
||||
FakeSecret(const Integer& x) : a(x.get()) {}
|
||||
FakeSecret(__uint128_t x) : a(x) {}
|
||||
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
|
||||
|
||||
FakeSecret() {}
|
||||
template <class T>
|
||||
FakeSecret(T other) : BitVec(other) {}
|
||||
|
||||
__uint128_t operator>>(const FakeSecret& other) const { return a >> other.a; }
|
||||
__uint128_t operator<<(const FakeSecret& other) const { return a << other.a; }
|
||||
@@ -90,17 +123,19 @@ public:
|
||||
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
|
||||
|
||||
template <class T>
|
||||
void xor_(int n, const FakeSecret& x, const T& y) { (void)n; a = x.a ^ y.a; }
|
||||
void xor_(int n, const FakeSecret& x, const T& y)
|
||||
{ *this = BitVec(x.a ^ y.a).mask(n); }
|
||||
void and_(int n, const FakeSecret& x, const FakeSecret& y, bool repeat);
|
||||
void andrs(int n, const FakeSecret& x, const FakeSecret& y) { (void)n; a = x.a * y.a; }
|
||||
void andrs(int n, const FakeSecret& x, const FakeSecret& y)
|
||||
{ *this = BitVec(x.a * (y.a & 1)).mask(n); }
|
||||
|
||||
void invert(int, const FakeSecret& x) { *this = ~x.a; }
|
||||
void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); }
|
||||
|
||||
void random_bit() { a = random() % 2; }
|
||||
|
||||
void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; }
|
||||
|
||||
int size() { return -1; }
|
||||
void invert(FakeSecret) { throw not_implemented(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -65,6 +65,7 @@ enum
|
||||
STMSBI = 0x243,
|
||||
MOVSB = 0x244,
|
||||
INPUTB = 0x246,
|
||||
INPUTBVEC = 0x247,
|
||||
SPLIT = 0x248,
|
||||
CONVCBIT2S = 0x249,
|
||||
// write to clear
|
||||
|
||||
@@ -67,7 +67,7 @@ unsigned Instruction::get_mem(RegType reg_type) const
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return BaseInstruction::get_mem(reg_type, MAX_SECRECY_TYPE);
|
||||
return BaseInstruction::get_mem(reg_type);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -58,6 +58,7 @@ public:
|
||||
void stop_timer() { timer[0].stop(); }
|
||||
void reset_timer() { timer[0].reset(); }
|
||||
|
||||
void run_tapes(const vector<int>& args);
|
||||
void run_tape(int thread_number, int tape_number, int arg);
|
||||
void join_tape(int thread_numer);
|
||||
};
|
||||
|
||||
@@ -61,8 +61,8 @@ template <class T>
|
||||
template <class U>
|
||||
void Memories<T>::reset(const U& program)
|
||||
{
|
||||
MS.resize_min(*program.direct_mem(SBIT), "memory");
|
||||
MC.resize_min(*program.direct_mem(CBIT), "memory");
|
||||
MS.resize_min(program.direct_mem(SBIT), "memory");
|
||||
MC.resize_min(program.direct_mem(CBIT), "memory");
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -70,7 +70,7 @@ template <class U>
|
||||
void Machine<T>::reset(const U& program)
|
||||
{
|
||||
Memories<T>::reset(program);
|
||||
MI.resize_min(*program.direct_mem(INT), "memory");
|
||||
MI.resize_min(program.direct_mem(INT), "memory");
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -78,12 +78,20 @@ template <class U, class V>
|
||||
void Machine<T>::reset(const U& program, V& MD)
|
||||
{
|
||||
reset(program);
|
||||
MD.resize_min(*program.direct_mem(DYN_SBIT), "dynamic memory");
|
||||
MD.resize_min(program.direct_mem(DYN_SBIT), "dynamic memory");
|
||||
#ifdef DEBUG_MEMORY
|
||||
cerr << "reset dynamic mem to " << program.direct_mem(DYN_SBIT) << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Machine<T>::run_tapes(const vector<int>& args)
|
||||
{
|
||||
assert(args.size() % 3 == 0);
|
||||
for (unsigned i = 0; i < args.size(); i++)
|
||||
run_tape(args[i], args[i + 1], args[i + 2]);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Machine<T>::run_tape(int thread_number, int tape_number, int arg)
|
||||
{
|
||||
|
||||
43
GC/NoShare.h
43
GC/NoShare.h
@@ -6,15 +6,23 @@
|
||||
#ifndef GC_NOSHARE_H_
|
||||
#define GC_NOSHARE_H_
|
||||
|
||||
#include "BMR/Register.h"
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
|
||||
class InputArgs;
|
||||
class ArithmeticProcessor;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class Processor;
|
||||
class Clear;
|
||||
|
||||
class NoValue : public ValueInterface
|
||||
{
|
||||
public:
|
||||
typedef NoValue Scalar;
|
||||
|
||||
const static int n_bits = 0;
|
||||
const static int MAX_N_BITS = 0;
|
||||
|
||||
@@ -28,6 +36,11 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int length()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void fail()
|
||||
{
|
||||
throw runtime_error("VM does not support binary circuits");
|
||||
@@ -43,9 +56,13 @@ public:
|
||||
int operator<<(int) const { fail(); return 0; }
|
||||
void operator+=(int) { fail(); }
|
||||
|
||||
bool operator!=(NoValue) const { fail(); return 0; }
|
||||
|
||||
bool get_bit(int) { fail(); return 0; }
|
||||
|
||||
void randomize(PRNG&) { fail(); }
|
||||
|
||||
void invert() { fail(); }
|
||||
};
|
||||
|
||||
inline ostream& operator<<(ostream& o, NoValue)
|
||||
@@ -53,18 +70,11 @@ inline ostream& operator<<(ostream& o, NoValue)
|
||||
return o;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline bool operator!=(const T&, NoValue&)
|
||||
{
|
||||
NoValue::fail();
|
||||
return true;
|
||||
}
|
||||
|
||||
class NoShare : public Phase
|
||||
class NoShare
|
||||
{
|
||||
public:
|
||||
typedef DummyMC<NoShare> MC;
|
||||
typedef DummyProtocol Protocol;
|
||||
typedef DummyProtocol<NoShare> Protocol;
|
||||
typedef NotImplementedInput<NoShare> Input;
|
||||
typedef DummyLivePrep<NoShare> LivePrep;
|
||||
typedef DummyMC<NoShare> MAC_Check;
|
||||
@@ -83,6 +93,8 @@ public:
|
||||
static const bool expensive_triples = false;
|
||||
static const bool is_real = false;
|
||||
|
||||
static SwitchableOutput out;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
return new MC;
|
||||
@@ -118,13 +130,16 @@ public:
|
||||
NoValue::fail();
|
||||
}
|
||||
|
||||
static void inputb(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void inputb(Processor<NoShare>&, ArithmeticProcessor&, const vector<int>&) { fail(); }
|
||||
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
|
||||
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static NoShare constant(GC::Clear, int, mac_key_type) { fail(); return {}; }
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
|
||||
@@ -146,6 +161,8 @@ public:
|
||||
void operator^=(NoShare) { fail(); }
|
||||
|
||||
NoShare operator+(const NoShare&) const { fail(); return {}; }
|
||||
NoShare operator-(NoShare) const { fail(); return 0; }
|
||||
NoShare operator*(NoValue) const { fail(); return 0; }
|
||||
|
||||
NoShare operator+(int) const { fail(); return {}; }
|
||||
NoShare operator&(int) const { fail(); return {}; }
|
||||
@@ -153,6 +170,8 @@ public:
|
||||
|
||||
NoShare lsb() const { fail(); return {}; }
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void invert(int, NoShare) { fail(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -20,17 +20,6 @@ using namespace std;
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class ExecutionStats : public map<int, size_t>
|
||||
{
|
||||
public:
|
||||
ExecutionStats& operator+=(const ExecutionStats& other)
|
||||
{
|
||||
for (auto it : other)
|
||||
(*this)[it.first] += it.second;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class Processor : public ::ProcessorBase, public GC::RuntimeBranching
|
||||
{
|
||||
@@ -53,8 +42,6 @@ public:
|
||||
Memory<Clear> C;
|
||||
Memory<Integer> I;
|
||||
|
||||
ExecutionStats stats;
|
||||
|
||||
Timer xor_timer;
|
||||
|
||||
Processor(Machine<T>& machine);
|
||||
@@ -102,9 +89,12 @@ public:
|
||||
void ands(const vector<int>& args) { and_(args, false); }
|
||||
|
||||
void input(const vector<int>& args);
|
||||
void reveal(const vector<int>& args);
|
||||
void inputb(typename T::Input& input, ProcessorBase& input_processor,
|
||||
const vector<int>& args, int my_num);
|
||||
void inputbvec(typename T::Input& input, ProcessorBase& input_processor,
|
||||
const vector<int>& args, int my_num);
|
||||
|
||||
void reveal(const ::BaseInstruction& instruction);
|
||||
void reveal(const vector<int>& args);
|
||||
|
||||
void print_reg(int reg, int n, int size);
|
||||
void print_reg_plain(Clear& value);
|
||||
|
||||
@@ -79,6 +79,8 @@ template<class U>
|
||||
U GC::Processor<T>::get_long_input(const int* params,
|
||||
ProcessorBase& input_proc, bool interactive)
|
||||
{
|
||||
if (not T::actual_inputs)
|
||||
return {};
|
||||
U res = input_proc.get_input<FixInput_<U>>(interactive,
|
||||
¶ms[1]).items[0];
|
||||
int n_bits = *params;
|
||||
@@ -251,8 +253,12 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
|
||||
check_args(args, 4);
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
assert(args[i] <= T::default_length);
|
||||
S[args[i+1]].and_(args[i], S[args[i+2]], S[args[i+3]], repeat);
|
||||
for (int j = 0; j < DIV_CEIL(args[i], T::default_length); j++)
|
||||
{
|
||||
int n = min(T::default_length, args[i] - j * T::default_length);
|
||||
S[args[i + 1] + j].and_(n, S[args[i + 2] + j],
|
||||
S[args[i + 3] + (repeat ? 0 : j)], repeat);
|
||||
}
|
||||
complexity += args[i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ class Program
|
||||
unsigned num_reg(RegType reg_type) const
|
||||
{ return max_reg[reg_type]; }
|
||||
|
||||
const unsigned* direct_mem(RegType reg_type) const
|
||||
{ return &max_mem[reg_type]; }
|
||||
unsigned direct_mem(RegType reg_type) const
|
||||
{ return max_mem[reg_type]; }
|
||||
|
||||
template<class T, class U>
|
||||
BreakType execute(Processor<T>& Proc, U& dynamic_memory, int PC = -1) const;
|
||||
|
||||
15
GC/Secret.h
15
GC/Secret.h
@@ -60,6 +60,8 @@ public:
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
typedef typename T::Input Input;
|
||||
|
||||
static string type_string() { return "evaluation secret"; }
|
||||
static string phase_name() { return T::name(); }
|
||||
|
||||
@@ -71,6 +73,8 @@ public:
|
||||
|
||||
static const bool is_real = true;
|
||||
|
||||
static const bool actual_inputs = T::actual_inputs;
|
||||
|
||||
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
|
||||
static Secret<T> input(Processor<Secret<T>>& processor, const InputArgs& args);
|
||||
void random(int n_bits, int128 share);
|
||||
@@ -102,6 +106,10 @@ public:
|
||||
const vector<int>& args)
|
||||
{ T::inputb(processor, input_proc, args); }
|
||||
template<class U>
|
||||
static void inputbvec(Processor<U>& processor, ProcessorBase& input_proc,
|
||||
const vector<int>& args)
|
||||
{ T::inputbvec(processor, input_proc, args); }
|
||||
template<class U>
|
||||
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
|
||||
@@ -143,6 +151,13 @@ public:
|
||||
template <class U>
|
||||
void reveal(size_t n_bits, U& x);
|
||||
|
||||
template <class U>
|
||||
void my_input(U& inputter, BitVec value, int n_bits);
|
||||
template <class U>
|
||||
void other_input(U& inputter, int from, int n_bits);
|
||||
template <class U>
|
||||
void finalize_input(U& inputter, int from, int n_bits);
|
||||
|
||||
int size() const { return registers.size(); }
|
||||
RegVector& get_regs() { return registers; }
|
||||
const RegVector& get_regs() const { return registers; }
|
||||
|
||||
@@ -57,6 +57,33 @@ Secret<T> Secret<T>::input(party_id_t from, const int128& input, int n_bits)
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void GC::Secret<T>::my_input(U& inputter, BitVec value, int n_bits)
|
||||
{
|
||||
resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
get_reg(i).my_input(inputter, value.get_bit(i), 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void GC::Secret<T>::other_input(U& inputter, int from, int n_bits)
|
||||
{
|
||||
resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
get_reg(i).other_input(inputter, from);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void GC::Secret<T>::finalize_input(U& inputter, int from, int n_bits)
|
||||
{
|
||||
resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
get_reg(i).finalize_input(inputter, from, 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Secret<T>::random(int n_bits, int128 share)
|
||||
{
|
||||
|
||||
@@ -41,9 +41,8 @@ public:
|
||||
typedef Memory<U> DynamicMemory;
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
|
||||
static const bool is_real = true;
|
||||
static const bool actual_inputs = true;
|
||||
|
||||
static SwitchableOutput out;
|
||||
|
||||
@@ -63,6 +62,8 @@ public:
|
||||
{ inputb(processor, processor, args); }
|
||||
static void inputb(Processor<U>& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
static void inputbvec(Processor<U>& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
static void reveal_inst(Processor<U>& processor, const vector<int>& args);
|
||||
|
||||
template<class T>
|
||||
@@ -75,6 +76,13 @@ public:
|
||||
void invert(int n, const U& x);
|
||||
|
||||
void random_bit();
|
||||
|
||||
template<class T>
|
||||
void my_input(T& inputter, BitVec value, int n_bits);
|
||||
template<class T>
|
||||
void other_input(T& inputter, int from, int n_bits = 1);
|
||||
template<class T>
|
||||
void finalize_input(T& inputter, int from, int n_bits);
|
||||
};
|
||||
|
||||
template<class U>
|
||||
@@ -169,6 +177,8 @@ public:
|
||||
typedef SemiHonestRepSecret small_type;
|
||||
typedef SemiHonestRepSecret whole_type;
|
||||
|
||||
static const bool expensive_triples = false;
|
||||
|
||||
static MC* new_mc(mac_key_type) { return new MC; }
|
||||
|
||||
SemiHonestRepSecret() {}
|
||||
|
||||
@@ -98,6 +98,28 @@ void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
|
||||
mem[access.address] = access.value;
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
void GC::ShareSecret<U>::my_input(T& inputter, BitVec value, int n_bits)
|
||||
{
|
||||
inputter.add_mine(value, n_bits);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
void GC::ShareSecret<U>::other_input(T& inputter, int from, int)
|
||||
{
|
||||
inputter.add_other(from);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
template<class T>
|
||||
void GC::ShareSecret<U>::finalize_input(T& inputter, int from,
|
||||
int n_bits)
|
||||
{
|
||||
static_cast<U&>(*this) = inputter.finalize(from, n_bits).mask(n_bits);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
ProcessorBase& input_processor,
|
||||
@@ -106,25 +128,46 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
auto& party = ShareThread<U>::s();
|
||||
typename U::Input input(*party.MC, party.DataF, *party.P);
|
||||
input.reset_all(*party.P);
|
||||
processor.inputb(input, input_processor, args, party.P->my_num());
|
||||
}
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::inputbvec(Processor<U>& processor,
|
||||
ProcessorBase& input_processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
auto& party = ShareThread<U>::s();
|
||||
typename U::Input input(*party.MC, party.DataF, *party.P);
|
||||
input.reset_all(*party.P);
|
||||
processor.inputbvec(input, input_processor, args, party.P->my_num());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::inputb(typename T::Input& input, ProcessorBase& input_processor,
|
||||
const vector<int>& args, int my_num)
|
||||
{
|
||||
InputArgList a(args);
|
||||
bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0;
|
||||
int dl = U::default_length;
|
||||
complexity += a.n_input_bits();
|
||||
bool interactive = a.n_interactive_inputs_from_me(my_num) > 0;
|
||||
int dl = T::default_length;
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
if (x.from == party.P->my_num())
|
||||
if (x.from == my_num)
|
||||
{
|
||||
bigint whole_input = processor.template
|
||||
get_long_input<bigint>(x.params,
|
||||
bigint whole_input = get_long_input<bigint>(x.params,
|
||||
input_processor, interactive);
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
input.add_mine(bigint(whole_input >> (i * dl)).get_si(),
|
||||
{
|
||||
auto& res = S[x.dest + i];
|
||||
res.my_input(input, bigint(whole_input >> (i * dl)).get_si(),
|
||||
min(dl, x.n_bits - i * dl));
|
||||
}
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
input.add_other(x.from);
|
||||
S[x.dest + i].other_input(input, x.from,
|
||||
min(dl, x.n_bits - i * dl));
|
||||
}
|
||||
|
||||
if (interactive)
|
||||
@@ -138,9 +181,51 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
int n_bits = x.n_bits;
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
{
|
||||
auto& res = processor.S[x.dest + i];
|
||||
auto& res = S[x.dest + i];
|
||||
int n = min(dl, n_bits - i * dl);
|
||||
res = input.finalize(from, n).mask(n);
|
||||
res.finalize_input(input, from, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::inputbvec(typename T::Input& input, ProcessorBase& input_processor,
|
||||
const vector<int>& args, int my_num)
|
||||
{
|
||||
InputVecArgList a(args);
|
||||
complexity += a.n_input_bits();
|
||||
bool interactive = a.n_interactive_inputs_from_me(my_num) > 0;
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
if (x.from == my_num)
|
||||
{
|
||||
bigint whole_input = get_long_input<bigint>(x.params,
|
||||
input_processor, interactive);
|
||||
for (int i = 0; i < x.n_bits; i++)
|
||||
{
|
||||
auto& res = S[x.dest[i]];
|
||||
res.my_input(input, bigint(whole_input >> (i)).get_si() & 1, 1);
|
||||
}
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < x.n_bits; i++)
|
||||
S[x.dest[i]].other_input(input, x.from, 1);
|
||||
}
|
||||
|
||||
if (interactive)
|
||||
cout << "Thank you" << endl;
|
||||
|
||||
input.exchange();
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
int from = x.from;
|
||||
int n_bits = x.n_bits;
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
auto& res = S[x.dest[i]];
|
||||
res.finalize_input(input, from, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,9 @@ public:
|
||||
|
||||
void pre_run();
|
||||
void post_run() { ShareThread<T>::post_run(); }
|
||||
|
||||
size_t data_sent()
|
||||
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -56,7 +56,7 @@ public:
|
||||
void join_tape();
|
||||
void finish();
|
||||
|
||||
int n_interactive_inputs_from_me(InputArgList& args);
|
||||
virtual size_t data_sent();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -91,15 +91,17 @@ void Thread<T>::finish()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
|
||||
size_t GC::Thread<T>::data_sent()
|
||||
{
|
||||
return args.n_interactive_inputs_from_me(P->my_num());
|
||||
assert(P);
|
||||
return P->comm_stats.total_data();
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
|
||||
inline int InputArgList::n_interactive_inputs_from_me(int my_num)
|
||||
template<class T>
|
||||
inline int InputArgListBase<T>::n_interactive_inputs_from_me(int my_num)
|
||||
{
|
||||
int res = 0;
|
||||
if (ArithmeticProcessor().use_stdin())
|
||||
|
||||
@@ -87,10 +87,12 @@ void ThreadMaster<T>::run()
|
||||
|
||||
NamedCommStats stats = P->comm_stats;
|
||||
ExecutionStats exe_stats;
|
||||
size_t data_sent = P->comm_stats.total_data();
|
||||
for (auto thread : threads)
|
||||
{
|
||||
stats += thread->P->comm_stats;
|
||||
exe_stats += thread->processor.stats;
|
||||
data_sent += thread->data_sent();
|
||||
delete thread;
|
||||
}
|
||||
|
||||
@@ -108,7 +110,8 @@ void ThreadMaster<T>::run()
|
||||
if (it->second.data > 0)
|
||||
cerr << it->first << " " << 1e-6 * it->second.data << " MB" << endl;
|
||||
|
||||
cerr << "Time = " << timer.elapsed() << endl;
|
||||
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
|
||||
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -151,6 +151,18 @@ public:
|
||||
for (auto& reg : this->get_regs())
|
||||
reg.output(s, human);
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void my_input(U& inputter, BitVec value, int n_bits)
|
||||
{
|
||||
inputter.add_mine(value, n_bits);
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void finalize_input(U& inputter, int from, int n_bits)
|
||||
{
|
||||
*this = inputter.finalize(from, n_bits).mask(n_bits);
|
||||
}
|
||||
};
|
||||
|
||||
template<int S>
|
||||
|
||||
@@ -71,6 +71,7 @@
|
||||
|
||||
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
|
||||
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
|
||||
X(ANDM, processor.andm(instruction)) \
|
||||
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
|
||||
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
|
||||
@@ -85,6 +86,7 @@
|
||||
|
||||
#define GC_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, EXTRA)) \
|
||||
X(INPUTBVEC, T::inputbvec(PROC, PROC, EXTRA)) \
|
||||
X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \
|
||||
X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \
|
||||
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \
|
||||
@@ -130,7 +132,7 @@
|
||||
X(PRINTINT, S0.out << I0) \
|
||||
X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \
|
||||
X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \
|
||||
X(RUN_TAPE, MACH->run_tape(R0, IMM, REG1)) \
|
||||
X(RUN_TAPE, MACH->run_tapes(EXTRA)) \
|
||||
X(JOIN_TAPE, MACH->join_tape(R0)) \
|
||||
X(USE, ) \
|
||||
|
||||
|
||||
59
Machines/emulate.cpp
Normal file
59
Machines/emulate.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* emulate.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/FakeShare.h"
|
||||
#include "Processor/Machine.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Processor/RingOptions.h"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ShuffleSacrifice.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/FakeShare.hpp"
|
||||
|
||||
SwitchableOutput GC::NoShare::out;
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
assert(argc > 1);
|
||||
OnlineOptions online_opts;
|
||||
Names N(0, 9999, vector<string>({"localhost"}));
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
opt.parse(argc, argv);
|
||||
string progname;
|
||||
if (opt.firstArgs.size() > 1)
|
||||
progname = *opt.firstArgs.at(1);
|
||||
else if (not opt.lastArgs.empty())
|
||||
progname = *opt.lastArgs.at(0);
|
||||
else if (not opt.unknownArgs.empty())
|
||||
progname = *opt.unknownArgs.at(0);
|
||||
else
|
||||
{
|
||||
string usage;
|
||||
opt.getUsage(usage);
|
||||
cerr << usage << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
switch (ring_opts.R)
|
||||
{
|
||||
case 64:
|
||||
Machine<FakeShare<SignedZ2<64>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
case 128:
|
||||
Machine<FakeShare<SignedZ2<128>>, FakeShare<gf2n>>(0, N, progname,
|
||||
online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true,
|
||||
online_opts.live_prep, online_opts).run();
|
||||
break;
|
||||
default:
|
||||
cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl;
|
||||
}
|
||||
}
|
||||
@@ -5,9 +5,8 @@
|
||||
|
||||
#include "Protocols/MaliciousRep3Share.h"
|
||||
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Machines/SPDZ.cpp"
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/SPDZ.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "GC/Thread.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
*/
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
#include "Protocols/ReplicatedMachine.hpp"
|
||||
#include "Protocols/ReplicatedFieldMachine.hpp"
|
||||
#include "Machines/Rep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
ReplicatedMachine<Rep3Share<gfp>, Rep3Share<gf2n>>(argc, argv,
|
||||
"replicated-field", opt);
|
||||
ReplicatedFieldMachine<Rep3Share>(argc, argv);
|
||||
}
|
||||
|
||||
12
Machines/semi-bmr-party.cpp
Normal file
12
Machines/semi-bmr-party.cpp
Normal file
@@ -0,0 +1,12 @@
|
||||
/*
|
||||
* semi-bmr-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/Semi.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
RealProgramParty<SemiShare<gf2n_long>>(argc, argv);
|
||||
}
|
||||
@@ -3,9 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
|
||||
4
Makefile
4
Makefile
@@ -57,7 +57,7 @@ include $(wildcard *.d static/*.d)
|
||||
%.o: %.cpp
|
||||
$(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c
|
||||
|
||||
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x
|
||||
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x
|
||||
|
||||
offline: $(OT_EXE) Check-Offline.x
|
||||
|
||||
@@ -193,6 +193,8 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o
|
||||
mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
|
||||
fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
|
||||
emulate.x: GC/FakeSecret.o
|
||||
semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o
|
||||
|
||||
$(LIBSIMPLEOT): SimpleOT/Makefile
|
||||
$(MAKE) -C SimpleOT
|
||||
|
||||
@@ -138,6 +138,8 @@ void write_online_setup(string dirname, const bigint& p)
|
||||
ofstream outf;
|
||||
outf.open(ss.str().c_str());
|
||||
outf << p << endl;
|
||||
if (!outf.good())
|
||||
throw file_error("cannot write to " + ss.str());
|
||||
}
|
||||
|
||||
void init_gf2n(int lg2)
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
class OnlineOptions;
|
||||
class bigint;
|
||||
class PRNG;
|
||||
|
||||
class ValueInterface
|
||||
{
|
||||
@@ -30,6 +31,8 @@ public:
|
||||
static int power_of_two(bool, int) { throw not_implemented(); }
|
||||
|
||||
void normalize() {}
|
||||
|
||||
void randomize_part(PRNG&, int) { throw not_implemented(); }
|
||||
};
|
||||
|
||||
#endif /* MATH_VALUEINTERFACE_H_ */
|
||||
|
||||
15
Math/Z2k.h
15
Math/Z2k.h
@@ -117,7 +117,7 @@ public:
|
||||
Z2<K> operator-(const Z2<K>& other) const;
|
||||
|
||||
template <int L>
|
||||
Z2<K+L> operator*(const Z2<L>& other) const;
|
||||
Z2<(K > L) ? K : L> operator*(const Z2<L>& other) const;
|
||||
|
||||
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(); }
|
||||
Z2<K> operator*(int other) const { return *this * Z2<K>(other); }
|
||||
@@ -171,6 +171,7 @@ public:
|
||||
void XOR(const Z2& a, const Z2& b);
|
||||
|
||||
void randomize(PRNG& G, int n = -1);
|
||||
void randomize_part(PRNG& G, int n);
|
||||
void almost_randomize(PRNG& G) { randomize(G); }
|
||||
|
||||
void force_to_bit() { throw runtime_error("impossible"); }
|
||||
@@ -342,9 +343,9 @@ inline Z2<K> Z2<K>::Mul(const Z2<L>& x, const Z2<M>& y)
|
||||
|
||||
template <int K>
|
||||
template <int L>
|
||||
inline Z2<K+L> Z2<K>::operator*(const Z2<L>& other) const
|
||||
inline Z2<(K > L) ? K : L> Z2<K>::operator*(const Z2<L>& other) const
|
||||
{
|
||||
return Z2<K+L>::Mul(*this, other);
|
||||
return Z2<(K > L) ? K : L>::Mul(*this, other);
|
||||
}
|
||||
|
||||
template <int K>
|
||||
@@ -387,6 +388,14 @@ void Z2<K>::randomize(PRNG& G, int n)
|
||||
normalize();
|
||||
}
|
||||
|
||||
template<int K>
|
||||
void Z2<K>::randomize_part(PRNG& G, int n)
|
||||
{
|
||||
*this = {};
|
||||
G.get_octets((octet*)a, DIV_CEIL(n, 8));
|
||||
a[DIV_CEIL(n, 64) - 1] &= uint64_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS);
|
||||
}
|
||||
|
||||
template<int K>
|
||||
void Z2<K>::pack(octetStream& o, int n) const
|
||||
{
|
||||
|
||||
@@ -41,7 +41,7 @@ template<class T> void generate_prime_setup(string, int, int);
|
||||
#endif
|
||||
|
||||
template<int X, int L>
|
||||
class gfp_
|
||||
class gfp_ : public ValueInterface
|
||||
{
|
||||
typedef modp_<L> modp_type;
|
||||
|
||||
@@ -247,7 +247,7 @@ class gfp_
|
||||
gfp_ operator^(const gfp_& x) { gfp_ res; res.XOR(*this, x); return res; }
|
||||
gfp_ operator|(const gfp_& x) { gfp_ res; res.OR(*this, x); return res; }
|
||||
gfp_ operator<<(int i) const { gfp_ res; res.SHL(*this, i); return res; }
|
||||
gfp_ operator>>(int i) { gfp_ res; res.SHR(*this, i); return res; }
|
||||
gfp_ operator>>(int i) const { gfp_ res; res.SHR(*this, i); return res; }
|
||||
|
||||
gfp_& operator&=(const gfp_& x) { AND(*this, x); return *this; }
|
||||
gfp_& operator<<=(int i) { SHL(*this, i); return *this; }
|
||||
|
||||
@@ -193,7 +193,8 @@ void gfp_<X, L>::reqbl(int n)
|
||||
{
|
||||
if ((int)n > 0 && pr() < bigint(1) << (n-1))
|
||||
{
|
||||
cout << "Tape requires prime of bit length " << n << endl;
|
||||
cerr << "Tape requires prime of bit length " << n << endl;
|
||||
cerr << "Run with '-lgp " << n << "'" << endl;
|
||||
throw invalid_params();
|
||||
}
|
||||
else if ((int)n < 0)
|
||||
|
||||
@@ -450,17 +450,23 @@ void Player::pass_around(octetStream& o, octetStream& to_receive, int offset) co
|
||||
* size getting in the way
|
||||
*/
|
||||
template<class T>
|
||||
void MultiPlayer<T>::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
|
||||
void MultiPlayer<T>::Broadcast_Receive_no_stats(vector<octetStream>& o) const
|
||||
{
|
||||
if (o.size() != sockets.size())
|
||||
throw runtime_error("player numbers don't match");
|
||||
TimeScope ts(comm_stats["Broadcasting"].add(o[player_no]));
|
||||
for (int i=1; i<nplayers; i++)
|
||||
{
|
||||
int send_to = (my_num() + i) % num_players();
|
||||
int receive_from = (my_num() + num_players() - i) % num_players();
|
||||
o[my_num()].exchange(sockets[send_to], sockets[receive_from], o[receive_from]);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void MultiPlayer<T>::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Broadcasting"].add(o[player_no]));
|
||||
Broadcast_Receive_no_stats(o);
|
||||
if (!donthash)
|
||||
{ for (int i=0; i<nplayers; i++)
|
||||
{ hash_update(&ctx,o[i].get_data(),o[i].get_length()); }
|
||||
|
||||
@@ -247,6 +247,7 @@ public:
|
||||
* - Assumes o[player_no] contains the thing broadcast by me
|
||||
*/
|
||||
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
|
||||
void Broadcast_Receive_no_stats(vector<octetStream>& o) const;
|
||||
|
||||
// wait for available inputs
|
||||
void wait_for_available(vector<int>& players, vector<int>& result) const;
|
||||
|
||||
@@ -3,23 +3,7 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include <OT/TripleMachine.h>
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
#include "MascotParams.h"
|
||||
|
||||
MascotParams::MascotParams()
|
||||
{
|
||||
|
||||
14
OT/OTVole.h
14
OT/OTVole.h
@@ -51,18 +51,20 @@ public:
|
||||
// Both
|
||||
PRNG local_prng;
|
||||
|
||||
Row<T> tmp;
|
||||
|
||||
octetStream os;
|
||||
vector<octetStream> oss;
|
||||
|
||||
virtual void consistency_check (vector<octetStream>& os);
|
||||
|
||||
void set_coeffs(__m128i* coefficients, PRNG& G, int num_elements) const;
|
||||
|
||||
void hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients);
|
||||
|
||||
void hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients);
|
||||
template<class U>
|
||||
void hash_row(octetStream& os, const U& row, const __m128i* coefficients);
|
||||
template<class U>
|
||||
void hash_row(octet* hash, const U& row, const __m128i* coefficients);
|
||||
template<class U>
|
||||
void hash_row(__m128i res[2], const U& row, const __m128i* coefficients);
|
||||
|
||||
static void add_mul(__m128i res[2], __m128i a, __m128i b);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
|
||||
147
OT/OTVole.hpp
147
OT/OTVole.hpp
@@ -10,30 +10,28 @@ template <class T>
|
||||
void OTVoleBase<T>::evaluate(vector<T>& output, const vector<T>& newReceiverInput) {
|
||||
const int N1 = newReceiverInput.size() + 1;
|
||||
output.resize(newReceiverInput.size());
|
||||
vector<octetStream> os(2);
|
||||
auto& os = oss;
|
||||
os.resize(2);
|
||||
os[0].reset_write_head();
|
||||
os[1].reset_write_head();
|
||||
|
||||
if (this->ot_role & SENDER) {
|
||||
T extra;
|
||||
extra.randomize(local_prng);
|
||||
vector<T> _corr(newReceiverInput);
|
||||
_corr.push_back(extra);
|
||||
corr_prime = Row<T>(_corr);
|
||||
corr_prime.rows = newReceiverInput;
|
||||
corr_prime.rows.push_back(extra);
|
||||
for (int i = 0; i < S; ++i)
|
||||
{
|
||||
t0[i] = Row<T>(N1);
|
||||
t0[i].randomize(this->G_sender[i][0]);
|
||||
t1[i] = Row<T>(N1);
|
||||
t1[i].randomize(this->G_sender[i][1]);
|
||||
Row<T> u = corr_prime + t1[i] + t0[i];
|
||||
u.pack(os[0]);
|
||||
t0[i].randomize(this->G_sender[i][0], N1);
|
||||
t1[i].randomize(this->G_sender[i][1], N1);
|
||||
(corr_prime + t1[i] + t0[i]).pack(os[0]);
|
||||
}
|
||||
}
|
||||
send_if_ot_sender(this->player, os, this->ot_role);
|
||||
if (this->ot_role & RECEIVER) {
|
||||
for (int i = 0; i < S; ++i)
|
||||
{
|
||||
t[i] = Row<T>(N1);
|
||||
t[i].randomize(this->G_receiver[i]);
|
||||
t[i].randomize(this->G_receiver[i], N1);
|
||||
|
||||
int choice_bit = this->baseReceiverInput.get_bit(i);
|
||||
if (choice_bit == 1) {
|
||||
@@ -84,38 +82,104 @@ void OTVoleBase<T>::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) c
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTVoleBase<T>::hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients) {
|
||||
template<class U>
|
||||
void OTVoleBase<T>::hash_row(octetStream& os, const U& row,
|
||||
const __m128i* coefficients)
|
||||
{
|
||||
octet hash[VOLE_HASH_SIZE] = {0};
|
||||
this->hash_row(hash, row, coefficients);
|
||||
os.append(hash, VOLE_HASH_SIZE);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTVoleBase<T>::hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients) {
|
||||
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
|
||||
|
||||
os.clear();
|
||||
for(auto& x : row.rows)
|
||||
x.pack(os);
|
||||
os.serialize(int128());
|
||||
|
||||
__m128i prods[2];
|
||||
avx_memzero(prods, sizeof(prods));
|
||||
template <class U>
|
||||
void OTVoleBase<T>::hash_row(octet* hash, const U& row,
|
||||
const __m128i* coefficients)
|
||||
{
|
||||
__m128i res[2];
|
||||
avx_memzero(res, sizeof(res));
|
||||
|
||||
for (int i = 0; i < num_blocks; ++i) {
|
||||
__m128i block;
|
||||
os.unserialize(block);
|
||||
mul128(block, coefficients[i], &prods[0], &prods[1]);
|
||||
res[0] ^= prods[0];
|
||||
res[1] ^= prods[1];
|
||||
}
|
||||
for (int i = 0; i < 2; i++)
|
||||
res[i] = _mm_setzero_si128();
|
||||
hash_row(res, row, coefficients);
|
||||
|
||||
crypto_generichash(hash, crypto_generichash_BYTES,
|
||||
(octet*) res, crypto_generichash_BYTES, NULL, 0);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void OTVoleBase<T>::hash_row(__m128i res[2], const U& row,
|
||||
const __m128i* coefficients)
|
||||
{
|
||||
auto coeff_base = coefficients;
|
||||
int num_blocks = DIV_CEIL(row.size() * T::size(), 16);
|
||||
__m128i buffer[T::size()];
|
||||
size_t next = 0;
|
||||
while (next + 16 < row.size())
|
||||
{
|
||||
for (int j = 0; j < 16; j++)
|
||||
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
|
||||
for (int j = 0; j < T::size(); j++)
|
||||
add_mul(res, buffer[j], *coefficients++);
|
||||
}
|
||||
for (int j = 0; j < 16; j++)
|
||||
if (next < row.size())
|
||||
memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size());
|
||||
for (int j = 0; j < num_blocks % T::size(); j++)
|
||||
add_mul(res, buffer[j], *coefficients++);
|
||||
assert(coefficients == coeff_base + num_blocks);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTVoleBase<T>::add_mul(__m128i res[2], __m128i a, __m128i b)
|
||||
{
|
||||
__m128i prods[2];
|
||||
mul128(a, b, &prods[0], &prods[1]);
|
||||
res[0] ^= prods[0];
|
||||
res[1] ^= prods[1];
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class U>
|
||||
inline
|
||||
void OTVoleBase<Z2<128>>::hash_row(__m128i res[2],
|
||||
const U& row, const __m128i* coefficients)
|
||||
{
|
||||
for (size_t i = 0; i < row.size(); i++)
|
||||
{
|
||||
__m128i block = int128(row[i].get_limb(1), row[i].get_limb(0)).a;
|
||||
add_mul(res, block, coefficients[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class U>
|
||||
inline
|
||||
void OTVoleBase<Z2<192>>::hash_row(__m128i res[2], const U& row,
|
||||
const __m128i* coefficients)
|
||||
{
|
||||
__m128i block;
|
||||
size_t j;
|
||||
for (j = 0; j < row.size() - 1; j += 2)
|
||||
{
|
||||
auto x = row[j];
|
||||
auto y = row[j + 1];
|
||||
block = int128(x.get_limb(1), x.get_limb(0)).a;
|
||||
add_mul(res, block, *coefficients++);
|
||||
block = int128(y.get_limb(0), x.get_limb(2)).a;
|
||||
add_mul(res, block, *coefficients++);
|
||||
block = int128(y.get_limb(2), y.get_limb(1)).a;
|
||||
add_mul(res, block, *coefficients++);
|
||||
}
|
||||
if (j < row.size())
|
||||
{
|
||||
auto x = row[j];
|
||||
block = int128(x.get_limb(1), x.get_limb(0)).a;
|
||||
add_mul(res, block, *coefficients++);
|
||||
block = int128(x.get_limb(2)).a;
|
||||
add_mul(res, block, *coefficients++);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
|
||||
PRNG coef_prng_sender;
|
||||
@@ -144,17 +208,16 @@ void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
|
||||
__m128i coefficients[num_blocks];
|
||||
this->set_coeffs(coefficients, coef_prng_sender, num_blocks);
|
||||
|
||||
Row<T> t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size());
|
||||
for (int alpha = 0; alpha < S; ++alpha)
|
||||
{
|
||||
for (int i = 0; i < n_challenges(); i++)
|
||||
{
|
||||
int beta = get_challenge(coef_prng_sender, i);
|
||||
|
||||
t00 = t0[alpha] - t0[beta];
|
||||
t01 = t0[alpha] - t1[beta];
|
||||
t10 = t1[alpha] - t0[beta];
|
||||
t11 = t1[alpha] - t1[beta];
|
||||
auto t00 = t0[alpha] - t0[beta];
|
||||
auto t01 = t0[alpha] - t1[beta];
|
||||
auto t10 = t1[alpha] - t0[beta];
|
||||
auto t11 = t1[alpha] - t1[beta];
|
||||
|
||||
this->hash_row(os[0], t00, coefficients);
|
||||
this->hash_row(os[0], t01, coefficients);
|
||||
@@ -202,16 +265,16 @@ void OTVoleBase<T>::consistency_check(vector<octetStream>& os) {
|
||||
int choice_alpha = this->baseReceiverInput.get_bit(alpha);
|
||||
int choice_beta = this->baseReceiverInput.get_bit(beta);
|
||||
|
||||
tmp = t[alpha] - t[beta];
|
||||
auto diff = t[alpha] - t[beta];
|
||||
octet* choice_hash = hashes[choice_alpha][choice_beta];
|
||||
octet diff_t[VOLE_HASH_SIZE] = {0};
|
||||
this->hash_row(diff_t, tmp, coefficients);
|
||||
this->hash_row(diff_t, diff, coefficients);
|
||||
|
||||
octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta];
|
||||
octet other_diff[VOLE_HASH_SIZE] = {0};
|
||||
tmp = u[alpha] - u[beta];
|
||||
tmp -= t[alpha];
|
||||
tmp += t[beta];
|
||||
auto a = u[alpha] - u[beta];
|
||||
auto b = a - t[alpha];
|
||||
auto tmp = b + t[beta];
|
||||
this->hash_row(other_diff, tmp, coefficients);
|
||||
|
||||
if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) {
|
||||
|
||||
77
OT/Row.h
77
OT/Row.h
@@ -5,7 +5,8 @@
|
||||
#include "Math/gf2nlong.h"
|
||||
#define VOLE_HASH_SIZE crypto_generichash_BYTES
|
||||
|
||||
template <class T> class DeferredMinus;
|
||||
template <class T, class U> class DeferredMinus;
|
||||
template <class T, class U> class DeferredPlus;
|
||||
|
||||
template <class T>
|
||||
class Row
|
||||
@@ -20,7 +21,10 @@ public:
|
||||
|
||||
Row(const vector<T>& _rows) : rows(_rows) {}
|
||||
|
||||
Row(DeferredMinus<T> d) { *this = d; }
|
||||
template<class U>
|
||||
Row(DeferredMinus<T, U> d) { *this = d; }
|
||||
template<class U>
|
||||
Row(DeferredPlus<T, U> d) { *this = d; }
|
||||
|
||||
bool operator==(const Row<T>& other) const;
|
||||
bool operator!=(const Row<T>& other) const { return not (*this == other); }
|
||||
@@ -31,23 +35,28 @@ public:
|
||||
Row<T>& operator*=(const T& other);
|
||||
|
||||
Row<T> operator*(const T& other);
|
||||
Row<T> operator+(const Row<T> & other);
|
||||
DeferredMinus<T> operator-(const Row<T> & other);
|
||||
DeferredPlus<T, Row<T>> operator+(const Row<T> & other);
|
||||
DeferredMinus<T, Row<T>> operator-(const Row<T>& other);
|
||||
|
||||
Row<T>& operator=(DeferredMinus<T> d);
|
||||
template<class U>
|
||||
Row<T>& operator=(const DeferredMinus<T, U>& d);
|
||||
template<class U>
|
||||
Row<T>& operator=(const DeferredPlus<T, U>& d);
|
||||
|
||||
Row<T> operator<<(int i) const;
|
||||
|
||||
// fine, since elements in vector are allocated contiguously
|
||||
const void* get_ptr() const { return rows[0].get_ptr(); }
|
||||
|
||||
void randomize(PRNG& G);
|
||||
void randomize(PRNG& G, size_t size);
|
||||
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
size_t size() const { return rows.size(); }
|
||||
|
||||
const T& operator[](size_t i) const { return rows[i]; }
|
||||
|
||||
template <class V>
|
||||
friend ostream& operator<<(ostream& o, const Row<V>& x);
|
||||
};
|
||||
@@ -55,17 +64,67 @@ public:
|
||||
template <int K>
|
||||
using Z2kRow = Row<Z2<K>>;
|
||||
|
||||
template <class T>
|
||||
template <class T, class U>
|
||||
class DeferredMinus
|
||||
{
|
||||
public:
|
||||
const Row<T>& x;
|
||||
const U& x;
|
||||
const Row<T>& y;
|
||||
|
||||
DeferredMinus(const Row<T>& x, const Row<T>& y) : x(x), y(y)
|
||||
DeferredMinus(const U& x, const Row<T>& y) : x(x), y(y)
|
||||
{
|
||||
assert(x.size() == y.size());
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return x.size();
|
||||
}
|
||||
|
||||
T operator[](size_t i) const
|
||||
{
|
||||
return x[i] - y[i];
|
||||
}
|
||||
|
||||
DeferredPlus<T, DeferredMinus> operator+(const Row<T>& other)
|
||||
{
|
||||
return {*this, other};
|
||||
}
|
||||
|
||||
DeferredMinus<T, DeferredMinus> operator-(const Row<T>& other)
|
||||
{
|
||||
return {*this, other};
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, class U>
|
||||
class DeferredPlus
|
||||
{
|
||||
public:
|
||||
const U& x;
|
||||
const Row<T>& y;
|
||||
|
||||
DeferredPlus(const U& x, const Row<T>& y) : x(x), y(y)
|
||||
{
|
||||
assert(x.size() == y.size());
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return x.size();
|
||||
}
|
||||
|
||||
T operator[](size_t i) const
|
||||
{
|
||||
return x[i] + y[i];
|
||||
}
|
||||
|
||||
DeferredPlus<T, DeferredPlus> operator+(const Row<T>& other)
|
||||
{
|
||||
return {*this, other};
|
||||
}
|
||||
|
||||
void pack(octetStream& o) const;
|
||||
};
|
||||
|
||||
#endif /* OT_ROW_H_ */
|
||||
|
||||
49
OT/Row.hpp
49
OT/Row.hpp
@@ -46,34 +46,46 @@ Row<T> Row<T>::operator *(const T& other)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Row<T> Row<T>::operator +(const Row<T>& other)
|
||||
DeferredPlus<T, Row<T>> Row<T>::operator +(const Row<T>& other)
|
||||
{
|
||||
Row<T> res = other;
|
||||
res += *this;
|
||||
return res;
|
||||
return {*this, other};
|
||||
}
|
||||
|
||||
template<class T>
|
||||
DeferredMinus<T> Row<T>::operator -(const Row<T>& other)
|
||||
DeferredMinus<T, Row<T>> Row<T>::operator -(const Row<T>& other)
|
||||
{
|
||||
return DeferredMinus<T>(*this, other);
|
||||
return {*this, other};
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Row<T>& Row<T>::operator=(DeferredMinus<T> d)
|
||||
template<class U>
|
||||
Row<T>& Row<T>::operator=(const DeferredMinus<T, U>& d)
|
||||
{
|
||||
size_t size = d.x.size();
|
||||
rows.resize(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rows[i] = d.x.rows[i] - d.y.rows[i];
|
||||
rows[i] = d[i];
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Row<T>::randomize(PRNG& G)
|
||||
template<class U>
|
||||
Row<T>& Row<T>::operator=(const DeferredPlus<T, U>& d)
|
||||
{
|
||||
for (size_t i = 0; i < this->size(); i++)
|
||||
rows[i].randomize(G);
|
||||
size_t size = d.x.size();
|
||||
rows.resize(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rows[i] = d[i];
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Row<T>::randomize(PRNG& G, size_t size)
|
||||
{
|
||||
rows.clear();
|
||||
rows.reserve(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rows.push_back(G.get<T>());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -87,6 +99,14 @@ Row<T> Row<T>::operator<<(int i) const {
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
void DeferredPlus<T, U>::pack(octetStream& o) const
|
||||
{
|
||||
o.store(this->size());
|
||||
for (size_t i = 0; i < this->size(); i++)
|
||||
(*this)[i].pack(o);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Row<T>::pack(octetStream& o) const
|
||||
{
|
||||
@@ -100,9 +120,10 @@ void Row<T>::unpack(octetStream& o)
|
||||
{
|
||||
size_t size;
|
||||
o.get(size);
|
||||
this->rows.resize(size);
|
||||
for (size_t i = 0; i < this->size(); i++)
|
||||
rows[i].unpack(o);
|
||||
rows.clear();
|
||||
rows.reserve(size);
|
||||
for (size_t i = 0; i < size; i++)
|
||||
rows.push_back(o.get<T>());
|
||||
}
|
||||
|
||||
template <class V>
|
||||
|
||||
@@ -11,6 +11,9 @@ using namespace std;
|
||||
|
||||
#include "Math/BitVec.h"
|
||||
#include "Data_Files.h"
|
||||
#include "Protocols/Replicated.h"
|
||||
#include "Protocols/MAC_Check_Base.h"
|
||||
#include "Processor/Input.h"
|
||||
|
||||
class Player;
|
||||
class DataPositions;
|
||||
@@ -26,17 +29,21 @@ template<class T> class ShareThread;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
class DummyMC
|
||||
class DummyMC : public MAC_Check_Base<T>
|
||||
{
|
||||
public:
|
||||
void POpen(vector<typename T::open_type>&, vector<T>&, Player&)
|
||||
DummyMC()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void Check(Player& P)
|
||||
template<class U>
|
||||
DummyMC(U, int = 0, int = 0)
|
||||
{
|
||||
(void) P;
|
||||
}
|
||||
|
||||
void exchange(const Player&)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
DummyMC<typename T::part_type>& get_part_MC()
|
||||
@@ -51,7 +58,8 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class DummyProtocol
|
||||
template<class T>
|
||||
class DummyProtocol : public ProtocolBase<T>
|
||||
{
|
||||
public:
|
||||
Player& P;
|
||||
@@ -66,13 +74,11 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void init_mul(SubProcessor<T>* = 0)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
template<class T>
|
||||
void prepare_mul(const T&, const T&, int = 0)
|
||||
typename T::clear prepare_mul(const T&, const T&, int = 0)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
@@ -80,10 +86,10 @@ public:
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
int finalize_mul(int = 0)
|
||||
T finalize_mul(int = 0)
|
||||
{
|
||||
throw not_implemented();
|
||||
return 0;
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -91,6 +97,13 @@ template<class T>
|
||||
class DummyLivePrep : public Preprocessing<T>
|
||||
{
|
||||
public:
|
||||
static void basic_setup(Player&)
|
||||
{
|
||||
}
|
||||
static void teardown()
|
||||
{
|
||||
}
|
||||
|
||||
static void fail()
|
||||
{
|
||||
throw runtime_error(
|
||||
@@ -106,6 +119,11 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
Preprocessing<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
void set_protocol(typename T::Protocol&)
|
||||
{
|
||||
}
|
||||
@@ -177,7 +195,7 @@ public:
|
||||
throw not_implemented();
|
||||
}
|
||||
template<class T>
|
||||
static void input(SubProcessor<T>& proc, vector<int> regs)
|
||||
static void input(SubProcessor<V>& proc, vector<int> regs, int)
|
||||
{
|
||||
(void) proc, (void) regs;
|
||||
throw not_implemented();
|
||||
@@ -206,6 +224,14 @@ public:
|
||||
(void) a, (void) b;
|
||||
throw not_implemented();
|
||||
}
|
||||
static void raw_input(SubProcessor<V>&, vector<int>, int)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
static void input_mixed(SubProcessor<V>&, vector<int>, int, bool)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
};
|
||||
|
||||
class NotImplementedOutput
|
||||
|
||||
@@ -106,10 +106,12 @@ enum
|
||||
SQUARE = 0x52,
|
||||
INV = 0x53,
|
||||
INPUTMASK = 0x56,
|
||||
INPUTMASKREG = 0x5C,
|
||||
PREP = 0x57,
|
||||
DABIT = 0x58,
|
||||
EDABIT = 0x59,
|
||||
SEDABIT = 0x5A,
|
||||
RANDOMS = 0x5B,
|
||||
// Input
|
||||
INPUT = 0x60,
|
||||
INPUTFIX = 0xF0,
|
||||
@@ -143,6 +145,7 @@ enum
|
||||
SHRC = 0x81,
|
||||
SHLCI = 0x82,
|
||||
SHRCI = 0x83,
|
||||
SHRSI = 0x84,
|
||||
// Branching and comparison
|
||||
JMP = 0x90,
|
||||
JMPNZ = 0x91,
|
||||
@@ -283,13 +286,16 @@ enum
|
||||
// Register types
|
||||
enum RegType {
|
||||
MODP,
|
||||
GF2N,
|
||||
INT,
|
||||
SBIT,
|
||||
CBIT,
|
||||
DYN_SBIT,
|
||||
SINT,
|
||||
CINT,
|
||||
SGF2N,
|
||||
CGF2N,
|
||||
NONE,
|
||||
MAX_REG_TYPE,
|
||||
NONE
|
||||
};
|
||||
|
||||
enum SecrecyType {
|
||||
@@ -323,6 +329,8 @@ struct TempVars {
|
||||
|
||||
class BaseInstruction
|
||||
{
|
||||
friend class Program;
|
||||
|
||||
protected:
|
||||
int opcode; // The code
|
||||
int size; // Vector size
|
||||
@@ -346,10 +354,10 @@ public:
|
||||
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
|
||||
virtual int get_reg_type() const;
|
||||
|
||||
bool is_direct_memory_access(SecrecyType sec_type) const;
|
||||
bool is_direct_memory_access() const;
|
||||
|
||||
// Returns the memory size used if applicable and known
|
||||
unsigned get_mem(RegType reg_type, SecrecyType sec_type) const;
|
||||
unsigned get_mem(RegType reg_type) const;
|
||||
|
||||
// Returns the maximal register used
|
||||
unsigned get_max_reg(int reg_type) const;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "Processor/IntInput.h"
|
||||
#include "Processor/FixInput.h"
|
||||
#include "Processor/FloatInput.h"
|
||||
#include "Processor/instructions.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/parse.h"
|
||||
@@ -102,6 +103,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case MULINT:
|
||||
case DIVINT:
|
||||
case CONDPRINTPLAIN:
|
||||
case INPUTMASKREG:
|
||||
r[0]=get_int(s);
|
||||
r[1]=get_int(s);
|
||||
r[2]=get_int(s);
|
||||
@@ -199,6 +201,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case ORCI:
|
||||
case SHLCI:
|
||||
case SHRCI:
|
||||
case SHRSI:
|
||||
case SHLCBI:
|
||||
case SHRCBI:
|
||||
case NOTC:
|
||||
@@ -220,7 +223,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
case RUN_TAPE:
|
||||
case STARTPRIVATEOUTPUT:
|
||||
case GSTARTPRIVATEOUTPUT:
|
||||
case STOPPRIVATEOUTPUT:
|
||||
@@ -261,6 +263,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case INV2M:
|
||||
case CONDPRINTSTR:
|
||||
case CONDPRINTSTRB:
|
||||
case RANDOMS:
|
||||
r[0]=get_int(s);
|
||||
n = get_int(s);
|
||||
break;
|
||||
@@ -310,6 +313,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case RAWINPUT:
|
||||
case GRAWINPUT:
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
@@ -444,6 +448,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case ANDRS:
|
||||
case ANDS:
|
||||
case INPUTB:
|
||||
case INPUTBVEC:
|
||||
case REVEAL:
|
||||
get_vector(get_int(s), start, s);
|
||||
break;
|
||||
@@ -544,15 +549,57 @@ int BaseInstruction::get_reg_type() const
|
||||
case USE_PREP:
|
||||
case GUSE_PREP:
|
||||
case USE_EDABIT:
|
||||
case RUN_TAPE:
|
||||
// those use r[] not for registers
|
||||
return NONE;
|
||||
case LDI:
|
||||
case LDMC:
|
||||
case STMC:
|
||||
case LDMCI:
|
||||
case STMCI:
|
||||
case MOVC:
|
||||
case ADDC:
|
||||
case ADDCI:
|
||||
case SUBC:
|
||||
case SUBCI:
|
||||
case SUBCFI:
|
||||
case MULC:
|
||||
case MULCI:
|
||||
case DIVC:
|
||||
case DIVCI:
|
||||
case MODC:
|
||||
case MODCI:
|
||||
case LEGENDREC:
|
||||
case DIGESTC:
|
||||
case INV2M:
|
||||
case OPEN:
|
||||
case ANDC:
|
||||
case XORC:
|
||||
case ORC:
|
||||
case ANDCI:
|
||||
case XORCI:
|
||||
case ORCI:
|
||||
case NOTC:
|
||||
case SHLC:
|
||||
case SHRC:
|
||||
case SHLCI:
|
||||
case SHRCI:
|
||||
case CONVINT:
|
||||
return CINT;
|
||||
default:
|
||||
if (is_gf2n_instruction())
|
||||
return GF2N;
|
||||
{
|
||||
Instruction tmp;
|
||||
tmp.opcode = opcode - 0x100;
|
||||
if (tmp.get_reg_type() == CINT)
|
||||
return CGF2N;
|
||||
else
|
||||
return SGF2N;
|
||||
}
|
||||
else if (opcode >> 4 == 0x9)
|
||||
return INT;
|
||||
else
|
||||
return MODP;
|
||||
return SINT;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,27 +611,35 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
int size_offset = 0;
|
||||
int size = this->size;
|
||||
|
||||
if (opcode == DABIT)
|
||||
// special treatment for instructions writing to different types
|
||||
switch (opcode)
|
||||
{
|
||||
case DABIT:
|
||||
if (reg_type == SBIT)
|
||||
return r[1] + size;
|
||||
else if (reg_type == MODP)
|
||||
else if (reg_type == SINT)
|
||||
return r[0] + size;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
else if (opcode == EDABIT or opcode == SEDABIT)
|
||||
{
|
||||
case EDABIT:
|
||||
case SEDABIT:
|
||||
if (reg_type == SBIT)
|
||||
skip = 1;
|
||||
else if (reg_type == MODP)
|
||||
else if (reg_type == SINT)
|
||||
return r[0] + size;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
else if (get_reg_type() != reg_type)
|
||||
{
|
||||
return 0;
|
||||
break;
|
||||
case INPUTMASKREG:
|
||||
if (reg_type == SINT)
|
||||
return r[0] + size;
|
||||
else if (reg_type == CINT)
|
||||
return r[1] + size;
|
||||
else
|
||||
return 0;
|
||||
default:
|
||||
if (get_reg_type() != reg_type)
|
||||
return 0;
|
||||
}
|
||||
|
||||
switch (opcode)
|
||||
@@ -607,6 +662,9 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return r[0] + start[0] * start[2];
|
||||
case CONV2DS:
|
||||
return r[0] + start[0] * start[1];
|
||||
case OPEN:
|
||||
skip = 2;
|
||||
break;
|
||||
case LDMSD:
|
||||
case LDMSDI:
|
||||
skip = 3;
|
||||
@@ -627,6 +685,20 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
offset = 3;
|
||||
size_offset = -2;
|
||||
break;
|
||||
case INPUTBVEC:
|
||||
{
|
||||
int res = 0;
|
||||
auto it = start.begin();
|
||||
while (it < start.end())
|
||||
{
|
||||
int n = *it - 3;
|
||||
it += 3;
|
||||
assert(it + n <= start.end());
|
||||
for (int i = 0; i < n; i++)
|
||||
res = max(res, *it++);
|
||||
}
|
||||
return res + 1;
|
||||
}
|
||||
case ANDM:
|
||||
case NOTS:
|
||||
size = DIV_CEIL(n, 64);
|
||||
@@ -673,16 +745,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
}
|
||||
|
||||
inline
|
||||
unsigned BaseInstruction::get_mem(RegType reg_type, SecrecyType sec_type) const
|
||||
unsigned BaseInstruction::get_mem(RegType reg_type) const
|
||||
{
|
||||
if (get_reg_type() == reg_type and is_direct_memory_access(sec_type))
|
||||
if (get_reg_type() == reg_type and is_direct_memory_access())
|
||||
return n + size;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline
|
||||
bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
|
||||
bool BaseInstruction::is_direct_memory_access() const
|
||||
{
|
||||
switch (opcode)
|
||||
{
|
||||
@@ -690,12 +762,10 @@ bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
|
||||
case STMS:
|
||||
case GLDMS:
|
||||
case GSTMS:
|
||||
return sec_type == SECRET;
|
||||
case LDMC:
|
||||
case STMC:
|
||||
case GLDMC:
|
||||
case GSTMC:
|
||||
return sec_type == CLEAR;
|
||||
case LDMINT:
|
||||
case STMINT:
|
||||
case LDMSB:
|
||||
@@ -731,16 +801,9 @@ __attribute__((always_inline))
|
||||
#endif
|
||||
inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
{
|
||||
Proc.PC+=1;
|
||||
auto& Procp = Proc.Procp;
|
||||
auto& Proc2 = Proc.Proc2;
|
||||
|
||||
// binary instructions
|
||||
typedef typename sint::bit_type T;
|
||||
auto& processor = Proc.Procb;
|
||||
auto& instruction = *this;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
// optimize some instructions
|
||||
switch (opcode)
|
||||
{
|
||||
@@ -772,102 +835,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_S2_ref(r[0] + i).mul(Proc.read_S2(r[1] + i),Proc.read_C2(r[2] + i));
|
||||
return;
|
||||
case LDI:
|
||||
Proc.temp.assign_ansp(n);
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Cp(r[0] + i,Proc.temp.ansp);
|
||||
return;
|
||||
case LDMSI:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(Proc.read_Ci(r[1] + i)));
|
||||
return;
|
||||
case LDMS:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(n + i));
|
||||
return;
|
||||
case STMSI:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.machine.Mp.write_S(Proc.read_Ci(r[1] + i), Proc.read_Sp(r[0] + i), Proc.PC);
|
||||
return;
|
||||
case STMS:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.machine.Mp.write_S(n + i, Proc.read_Sp(r[0] + i), Proc.PC);
|
||||
return;
|
||||
case ADDC:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Cp_ref(r[0] + i).add(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i));
|
||||
return;
|
||||
case ADDS:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i));
|
||||
return;
|
||||
case ADDM:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i),Proc.P.my_num(),Proc.MCp.get_alphai());
|
||||
return;
|
||||
case ADDCI:
|
||||
Proc.temp.assign_ansp(n);
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Cp_ref(r[0] + i).add(Proc.temp.ansp,Proc.read_Cp(r[1] + i));
|
||||
return;
|
||||
case SUBS:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i));
|
||||
return;
|
||||
case SUBSFI:
|
||||
Proc.temp.assign_ansp(n);
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).sub(Proc.temp.ansp,Proc.read_Sp(r[1] + i),Proc.P.my_num(),Proc.MCp.get_alphai());
|
||||
return;
|
||||
case SUBML:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i), Proc.read_Cp(r[2] + i),
|
||||
Proc.P.my_num(), Proc.MCp.get_alphai());
|
||||
return;
|
||||
case SUBMR:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Cp(r[1] + i), Proc.read_Sp(r[2] + i),
|
||||
Proc.P.my_num(), Proc.MCp.get_alphai());
|
||||
return;
|
||||
case MULM:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Sp_ref(r[0] + i).mul(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i));
|
||||
return;
|
||||
case MULC:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Cp_ref(r[0] + i).mul(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i));
|
||||
return;
|
||||
case MULCI:
|
||||
Proc.temp.assign_ansp(n);
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Cp_ref(r[0] + i).mul(Proc.temp.ansp,Proc.read_Cp(r[1] + i));
|
||||
return;
|
||||
case SHRCI:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Cp_ref(r[0] + i).SHR(Proc.read_Cp(r[1] + i), n);
|
||||
return;
|
||||
case TRIPLE:
|
||||
for (int i = 0; i < size; i++)
|
||||
Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i),
|
||||
Proc.get_Sp_ref(r[1] + i), Proc.get_Sp_ref(r[2] + i));
|
||||
return;
|
||||
case BIT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0] + i));
|
||||
return;
|
||||
case LDINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Ci(r[0] + i, int(n));
|
||||
return;
|
||||
case INCINT:
|
||||
{
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
int inc = (i / start[0]) % start[1];
|
||||
Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1]) + inc * int(n));
|
||||
}
|
||||
}
|
||||
return;
|
||||
case SHUFFLE:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1] + i));
|
||||
@@ -877,29 +844,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
swap(Proc.get_Ci_ref(r[0] + i), Proc.get_Ci_ref(r[0] + i + j));
|
||||
}
|
||||
return;
|
||||
case ADDINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) + Proc.read_Ci(r[2] + i);
|
||||
return;
|
||||
case SUBINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) - Proc.read_Ci(r[2] + i);
|
||||
return;
|
||||
case MULINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) * Proc.read_Ci(r[2] + i);
|
||||
return;
|
||||
case DIVINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) / Proc.read_Ci(r[2] + i);
|
||||
return;
|
||||
case CONVINT:
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
Proc.temp.assign_ansp(Proc.read_Ci(r[1] + i));
|
||||
Proc.get_Cp_ref(r[0] + i) = Proc.temp.ansp;
|
||||
}
|
||||
return;
|
||||
case CONVMODP:
|
||||
if (n == 0)
|
||||
{
|
||||
@@ -914,9 +858,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
throw Processor_Error(to_string(n) + "-bit conversion impossible; "
|
||||
"integer registers only have 64 bits");
|
||||
return;
|
||||
#define X(NAME, CODE) case NAME: CODE; return;
|
||||
COMBI_INSTRUCTIONS
|
||||
#undef X
|
||||
}
|
||||
|
||||
int r[3] = {this->r[0], this->r[1], this->r[2]};
|
||||
@@ -1275,6 +1216,13 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case GINV:
|
||||
Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]));
|
||||
break;
|
||||
case RANDOMS:
|
||||
Procp.protocol.randoms_inst(Procp, *this);
|
||||
return;
|
||||
case INPUTMASKREG:
|
||||
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2]));
|
||||
Proc.write_Cp(r[1], Proc.temp.rrp);
|
||||
break;
|
||||
case INPUTMASK:
|
||||
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n);
|
||||
if (n == Proc.P.my_num())
|
||||
@@ -1388,6 +1336,9 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case GSHRCI:
|
||||
Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n);
|
||||
break;
|
||||
case SHRSI:
|
||||
Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n;
|
||||
break;
|
||||
case GBITDEC:
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
@@ -1649,7 +1600,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
break;
|
||||
case RUN_TAPE:
|
||||
Proc.DataF.skip(
|
||||
Proc.machine.run_tape(r[0], n, r[1], -1, &Proc.DataF.DataFp,
|
||||
Proc.machine.run_tapes(start, &Proc.DataF.DataFp,
|
||||
&Proc.share_thread.DataF));
|
||||
break;
|
||||
case JOIN_TAPE:
|
||||
@@ -1788,8 +1739,41 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
|
||||
{
|
||||
unsigned int size = p.size();
|
||||
Proc.PC=0;
|
||||
|
||||
auto& Procp = Proc.Procp;
|
||||
|
||||
// binary instructions
|
||||
typedef typename sint::bit_type T;
|
||||
auto& processor = Proc.Procb;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
while (Proc.PC<size)
|
||||
{ p[Proc.PC].execute(Proc); }
|
||||
{
|
||||
auto& instruction = p[Proc.PC];
|
||||
auto& r = instruction.r;
|
||||
auto& n = instruction.n;
|
||||
auto& start = instruction.start;
|
||||
auto& size = instruction.size;
|
||||
|
||||
#ifdef COUNT_INSTRUCTIONS
|
||||
Proc.stats[p[Proc.PC].get_opcode()]++;
|
||||
#endif
|
||||
|
||||
Proc.PC++;
|
||||
|
||||
switch(instruction.get_opcode())
|
||||
{
|
||||
#define X(NAME, PRE, CODE) \
|
||||
case NAME: { PRE; for (int i = 0; i < size; i++) { CODE; } } break;
|
||||
ARITHMETIC_INSTRUCTIONS
|
||||
#undef X
|
||||
#define X(NAME, CODE) case NAME: CODE; break;
|
||||
COMBI_INSTRUCTIONS
|
||||
#undef X
|
||||
default:
|
||||
instruction.execute(Proc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "GC/Machine.h"
|
||||
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/ExecutionStats.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
@@ -44,9 +45,6 @@ class Machine : public BaseMachine
|
||||
// Keep record of used offline data
|
||||
DataPositions pos;
|
||||
|
||||
int tn,numt;
|
||||
bool usage_unknown;
|
||||
|
||||
void load_program(string threadname, string filename);
|
||||
|
||||
public:
|
||||
@@ -71,6 +69,7 @@ class Machine : public BaseMachine
|
||||
OnlineOptions opts;
|
||||
|
||||
atomic<size_t> data_sent;
|
||||
ExecutionStats stats;
|
||||
|
||||
Machine(int my_number, Names& playerNames, string progname,
|
||||
string memtype, int lg2, bool direct, int opening_sum,
|
||||
@@ -79,9 +78,12 @@ class Machine : public BaseMachine
|
||||
|
||||
const Names& get_N() { return N; }
|
||||
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg,
|
||||
int line_number, Preprocessing<sint>* prep = 0,
|
||||
Preprocessing<typename sint::bit_type>* bit_prep = 0);
|
||||
DataPositions run_tapes(const vector<int> &args, Preprocessing<sint> *prep,
|
||||
Preprocessing<typename sint::bit_type> *bit_prep);
|
||||
void fill_buffers(int thread_number, int tape_number,
|
||||
Preprocessing<sint> *prep,
|
||||
Preprocessing<typename sint::bit_type> *bit_prep);
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg);
|
||||
DataPositions join_tape(int thread_number);
|
||||
void run();
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/mkpath.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@@ -22,7 +23,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
string progname_str, string memtype, int lg2, bool direct,
|
||||
int opening_sum, bool receive_threads, int max_broadcast,
|
||||
bool use_encryption, bool live_prep, OnlineOptions opts)
|
||||
: my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false),
|
||||
: my_number(my_number), N(playerNames),
|
||||
direct(direct), opening_sum(opening_sum),
|
||||
receive_threads(receive_threads), max_broadcast(max_broadcast),
|
||||
use_encryption(use_encryption), live_prep(live_prep), opts(opts),
|
||||
@@ -40,9 +41,12 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
// make directory for outputs if necessary
|
||||
mkdir_p(PREP_DIR);
|
||||
|
||||
auto P = new PlainPlayer(N, 0xF00);
|
||||
sint::LivePrep::basic_setup(*P);
|
||||
delete P;
|
||||
if (opts.live_prep)
|
||||
{
|
||||
auto P = new PlainPlayer(N, 0xF00);
|
||||
sint::LivePrep::basic_setup(*P);
|
||||
delete P;
|
||||
}
|
||||
|
||||
sint::read_or_generate_mac_key(prep_dir_prefix<sint>(), N, alphapi);
|
||||
sgf2n::read_or_generate_mac_key(prep_dir_prefix<sgf2n>(), N, alpha2i);
|
||||
@@ -131,21 +135,29 @@ void Machine<sint, sgf2n>::load_program(string threadname, string filename)
|
||||
int i = progs.size() - 1;
|
||||
progs[i].parse(pinp);
|
||||
pinp.close();
|
||||
M2.minimum_size(GF2N, progs[i], threadname);
|
||||
Mp.minimum_size(MODP, progs[i], threadname);
|
||||
Mi.minimum_size(INT, progs[i], threadname);
|
||||
M2.minimum_size(SGF2N, CGF2N, progs[i], threadname);
|
||||
Mp.minimum_size(SINT, CINT, progs[i], threadname);
|
||||
Mi.minimum_size(NONE, INT, progs[i], threadname);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
int arg, int line_number, Preprocessing<sint>* prep,
|
||||
DataPositions Machine<sint, sgf2n>::run_tapes(const vector<int>& args,
|
||||
Preprocessing<sint>* prep, Preprocessing<typename sint::bit_type>* bit_prep)
|
||||
{
|
||||
assert(args.size() % 3 == 0);
|
||||
for (unsigned i = 0; i < args.size(); i += 3)
|
||||
fill_buffers(args[i], args[i + 1], prep, bit_prep);
|
||||
DataPositions res(N.num_players());
|
||||
for (unsigned i = 0; i < args.size(); i += 3)
|
||||
res.increase(run_tape(args[i], args[i + 1], args[i + 2]));
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
|
||||
Preprocessing<sint>* prep,
|
||||
Preprocessing<typename sint::bit_type>* bit_prep)
|
||||
{
|
||||
if (size_t(thread_number) >= tinfo.size())
|
||||
throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size()));
|
||||
if (size_t(tape_number) >= progs.size())
|
||||
throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size()));
|
||||
|
||||
// central preprocessing
|
||||
auto usage = progs[tape_number].get_offline_data_used();
|
||||
if (sint::expensive and prep != 0 and OnlineOptions::singleton.bucket_size == 3)
|
||||
@@ -203,6 +215,16 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
int arg)
|
||||
{
|
||||
if (size_t(thread_number) >= tinfo.size())
|
||||
throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size()));
|
||||
if (size_t(tape_number) >= progs.size())
|
||||
throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size()));
|
||||
|
||||
queues[thread_number]->schedule({tape_number, arg, pos});
|
||||
//printf("Send signal to run program %d in thread %d\n",tape_number,thread_number);
|
||||
@@ -211,21 +233,10 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
{
|
||||
if (not opts.live_prep)
|
||||
{
|
||||
// only one thread allowed
|
||||
if (numt>1)
|
||||
{ cerr << "Line " << line_number << " has " <<
|
||||
numt << " threads but tape " << tape_number <<
|
||||
cerr << "Internally called tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
throw invalid_program();
|
||||
}
|
||||
else if (line_number == -1)
|
||||
{
|
||||
cerr << "Internally called tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
throw invalid_program();
|
||||
}
|
||||
throw invalid_program();
|
||||
}
|
||||
usage_unknown = true;
|
||||
return DataPositions(N.num_players());
|
||||
}
|
||||
else
|
||||
@@ -252,43 +263,12 @@ void Machine<sint, sgf2n>::run()
|
||||
proc_timer.start();
|
||||
timer[0].start();
|
||||
|
||||
bool flag=true;
|
||||
usage_unknown=false;
|
||||
int exec=0;
|
||||
while (flag)
|
||||
{ inpf >> numt;
|
||||
if (numt==0)
|
||||
{ flag=false; }
|
||||
else
|
||||
{ for (int i=0; i<numt; i++)
|
||||
{
|
||||
// Now load up data
|
||||
inpf >> tn;
|
||||
|
||||
// Cope with passing an integer parameter to a tape
|
||||
int arg;
|
||||
if (inpf.get() == ':')
|
||||
inpf >> arg;
|
||||
else
|
||||
arg = 0;
|
||||
|
||||
//cerr << "Run scheduled tape " << tn << " in thread " << i << endl;
|
||||
pos.increase(run_tape(i, tn, arg, exec));
|
||||
}
|
||||
// Make sure all terminate before we continue
|
||||
auto new_pos = join_tape(0);
|
||||
for (int i=1; i<numt; i++)
|
||||
{ join_tape(i);
|
||||
}
|
||||
if (usage_unknown)
|
||||
{ // synchronize files
|
||||
pos = new_pos;
|
||||
usage_unknown = false;
|
||||
}
|
||||
//printf("Finished running line %d\n",exec);
|
||||
exec++;
|
||||
}
|
||||
}
|
||||
// legacy
|
||||
int _;
|
||||
inpf >> _ >> _ >> _;
|
||||
// run main tape
|
||||
pos.increase(run_tape(0, 0, 0));
|
||||
join_tape(0);
|
||||
|
||||
print_compiler();
|
||||
|
||||
@@ -317,6 +297,16 @@ void Machine<sint, sgf2n>::run()
|
||||
finish_timer.stop();
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "Memory usage: ";
|
||||
tinfo[0].print_usage(cerr, Mp.MS, "sint");
|
||||
tinfo[0].print_usage(cerr, Mp.MC, "cint");
|
||||
tinfo[0].print_usage(cerr, M2.MS, "sgf2n");
|
||||
tinfo[0].print_usage(cerr, M2.MS, "cgf2n");
|
||||
tinfo[0].print_usage(cerr, bit_memories.MS, "sbits");
|
||||
tinfo[0].print_usage(cerr, bit_memories.MC, "cbits");
|
||||
tinfo[0].print_usage(cerr, Mi.MC, "regint");
|
||||
cerr << endl;
|
||||
|
||||
for (unsigned int i = 0; i < join_timer.size(); i++)
|
||||
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
|
||||
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
|
||||
@@ -326,6 +316,15 @@ void Machine<sint, sgf2n>::run()
|
||||
print_timers();
|
||||
cerr << "Data sent = " << data_sent / 1e6 << " MB" << endl;
|
||||
|
||||
PlainPlayer P(N, 0xFFF0);
|
||||
Bundle<octetStream> bundle(P);
|
||||
bundle.mine.store(data_sent.load());
|
||||
P.Broadcast_Receive_no_stats(bundle);
|
||||
size_t global = 0;
|
||||
for (auto& os : bundle)
|
||||
global += os.get_int(8);
|
||||
cerr << "Global data sent = " << global / 1e6 << " MB" << endl;
|
||||
|
||||
#ifdef VERBOSE
|
||||
if (opening_sum < N.num_players() && !direct)
|
||||
cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl;
|
||||
@@ -374,6 +373,39 @@ void Machine<sint, sgf2n>::run()
|
||||
pos.print_cost();
|
||||
#endif
|
||||
|
||||
if (not stats.empty())
|
||||
{
|
||||
cerr << "Instruction statistics:" << endl;
|
||||
set<pair<size_t, int>> sorted_stats;
|
||||
for (auto& x : stats)
|
||||
{
|
||||
sorted_stats.insert({x.second, x.first});
|
||||
}
|
||||
for (auto& x : sorted_stats)
|
||||
{
|
||||
auto opcode = x.second;
|
||||
auto calls = x.first;
|
||||
cerr << "\t";
|
||||
int n_fill = 15;
|
||||
switch (opcode)
|
||||
{
|
||||
#define X(NAME, PRE, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break;
|
||||
ARITHMETIC_INSTRUCTIONS
|
||||
#undef X
|
||||
#define X(NAME, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break;
|
||||
COMBI_INSTRUCTIONS
|
||||
#undef X
|
||||
default:
|
||||
cerr << hex << setw(5) << showbase << left << opcode;
|
||||
n_fill -= 5;
|
||||
cerr << setw(0);
|
||||
}
|
||||
for (int i = 0; i < n_fill; i++)
|
||||
cerr << " ";
|
||||
cerr << dec << calls << endl;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef INSECURE
|
||||
Data_Files<sint, sgf2n> df(*this);
|
||||
df.seekg(pos);
|
||||
|
||||
@@ -26,7 +26,7 @@ class Memory
|
||||
public:
|
||||
|
||||
CheckVector<T> MS;
|
||||
vector<typename T::clear> MC;
|
||||
CheckVector<typename T::clear> MC;
|
||||
|
||||
void resize_s(int sz)
|
||||
{ MS.resize(sz); }
|
||||
@@ -73,7 +73,8 @@ class Memory
|
||||
{ (void)start, (void)end; cerr << "Memory protection not activated" << endl; }
|
||||
#endif
|
||||
|
||||
void minimum_size(RegType reg_type, const Program& program, string threadname);
|
||||
void minimum_size(RegType secret_type, RegType clear_type,
|
||||
const Program& program, string threadname);
|
||||
|
||||
friend ostream& operator<< <>(ostream& s,const Memory<T>& M);
|
||||
friend istream& operator>> <>(istream& s,Memory<T>& M);
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
#include <fstream>
|
||||
|
||||
template<class T>
|
||||
void Memory<T>::minimum_size(RegType reg_type, const Program& program, string threadname)
|
||||
void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
|
||||
const Program &program, string threadname)
|
||||
{
|
||||
(void) threadname;
|
||||
const unsigned* sizes = program.direct_mem(reg_type);
|
||||
unsigned sizes[MAX_SECRECY_TYPE];
|
||||
sizes[SECRET]= program.direct_mem(secret_type);
|
||||
sizes[CLEAR] = program.direct_mem(clear_type);
|
||||
if (sizes[SECRET] > size_s())
|
||||
{
|
||||
#ifdef DEBUG_MEMORY
|
||||
|
||||
@@ -28,6 +28,9 @@ class thread_info
|
||||
|
||||
static void purge_preprocessing(Machine<sint, sgf2n>& machine);
|
||||
|
||||
template<class T>
|
||||
static void print_usage(ostream& o, const vector<T>& regs, string name);
|
||||
|
||||
void Sub_Main_Func();
|
||||
};
|
||||
|
||||
|
||||
@@ -20,6 +20,15 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
template<class T>
|
||||
void thread_info<sint, sgf2n>::print_usage(ostream &o,
|
||||
const vector<T>& regs, string name)
|
||||
{
|
||||
if (regs.capacity())
|
||||
o << name << "=" << regs.capacity() << " ";
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
{
|
||||
@@ -105,7 +114,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
while (flag)
|
||||
{ // Wait until I have a program to run
|
||||
wait_timer.start();
|
||||
auto job = queues->next();
|
||||
ThreadJob job = queues->next();
|
||||
program = job.prognum;
|
||||
wait_timer.stop();
|
||||
#ifdef DEBUG_THREADS
|
||||
@@ -279,6 +288,16 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl;
|
||||
cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl;
|
||||
|
||||
cerr << "Register usage: ";
|
||||
print_usage(cerr, Proc.Procp.get_S(), "sint");
|
||||
print_usage(cerr, Proc.Procp.get_C(), "cint");
|
||||
print_usage(cerr, Proc.Proc2.get_S(), "sgf2n");
|
||||
print_usage(cerr, Proc.Proc2.get_C(), "cgf2n");
|
||||
print_usage(cerr, Proc.Procb.S, "sbits");
|
||||
print_usage(cerr, Proc.Procb.C, "cbits");
|
||||
print_usage(cerr, Proc.get_Ci(), "regint");
|
||||
cerr << endl;
|
||||
#endif
|
||||
|
||||
// wind down thread by thread
|
||||
@@ -287,6 +306,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
prep_sent += Proc.Procp.bit_prep.data_sent();
|
||||
for (auto& x : Proc.Procp.personal_bit_preps)
|
||||
prep_sent += x->data_sent();
|
||||
machine.stats += Proc.stats;
|
||||
delete processor;
|
||||
|
||||
machine.data_sent += P.sent + prep_sent;
|
||||
|
||||
@@ -163,6 +163,7 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
|
||||
opt.parse(argc, argv);
|
||||
|
||||
vector<string*> allArgs(opt.firstArgs);
|
||||
allArgs.insert(allArgs.end(), opt.unknownArgs.begin(), opt.unknownArgs.end());
|
||||
allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end());
|
||||
string usage;
|
||||
vector<string> badOptions;
|
||||
|
||||
@@ -26,7 +26,7 @@ class Program;
|
||||
template <class T>
|
||||
class SubProcessor
|
||||
{
|
||||
vector<typename T::clear> C;
|
||||
CheckVector<typename T::clear> C;
|
||||
CheckVector<T> S;
|
||||
|
||||
DataPositions bit_usage;
|
||||
@@ -70,11 +70,16 @@ public:
|
||||
int b);
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
vector<T>& get_S()
|
||||
CheckVector<T>& get_S()
|
||||
{
|
||||
return S;
|
||||
}
|
||||
|
||||
CheckVector<typename T::clear>& get_C()
|
||||
{
|
||||
return C;
|
||||
}
|
||||
|
||||
T& get_S_ref(int i)
|
||||
{
|
||||
return S[i];
|
||||
@@ -132,7 +137,7 @@ public:
|
||||
template<class sint, class sgf2n>
|
||||
class Processor : public ArithmeticProcessor
|
||||
{
|
||||
int reg_max2,reg_maxp,reg_maxi;
|
||||
int reg_max2, reg_maxi;
|
||||
|
||||
// Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ)
|
||||
octetStream socket_stream;
|
||||
|
||||
@@ -117,11 +117,11 @@ string Processor<sint, sgf2n>::get_filename(const char* prefix, bool use_number)
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::reset(const Program& program,int arg)
|
||||
{
|
||||
reg_max2 = program.num_reg(GF2N);
|
||||
reg_maxp = program.num_reg(MODP);
|
||||
reg_maxi = program.num_reg(INT);
|
||||
Proc2.resize(reg_max2);
|
||||
Procp.resize(reg_maxp);
|
||||
Proc2.get_S().resize(program.num_reg(SGF2N));
|
||||
Proc2.get_C().resize(program.num_reg(CGF2N));
|
||||
Procp.get_S().resize(program.num_reg(SINT));
|
||||
Procp.get_C().resize(program.num_reg(CINT));
|
||||
Ci.resize(reg_maxi);
|
||||
this->arg = arg;
|
||||
Procb.reset(program);
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/ExecutionStats.h"
|
||||
|
||||
class ProcessorBase
|
||||
{
|
||||
// Stack
|
||||
@@ -24,6 +26,8 @@ protected:
|
||||
int arg;
|
||||
|
||||
public:
|
||||
ExecutionStats stats;
|
||||
|
||||
void pushi(long x) { stacki.push(x); }
|
||||
void popi(long& x) { x = stacki.top(); stacki.pop(); }
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ void Program::compute_constants()
|
||||
for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++)
|
||||
{
|
||||
max_reg[reg_type] = 0;
|
||||
for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++)
|
||||
max_mem[reg_type][sec_type] = 0;
|
||||
max_mem[reg_type] = 0;
|
||||
}
|
||||
for (unsigned int i=0; i<p.size(); i++)
|
||||
{
|
||||
@@ -21,13 +20,10 @@ void Program::compute_constants()
|
||||
{
|
||||
max_reg[reg_type] = max(max_reg[reg_type],
|
||||
p[i].get_max_reg(reg_type));
|
||||
for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++)
|
||||
max_mem[reg_type][sec_type] = max(max_mem[reg_type][sec_type],
|
||||
p[i].get_mem(RegType(reg_type), SecrecyType(sec_type)));
|
||||
max_mem[reg_type] = max(max_mem[reg_type],
|
||||
p[i].get_mem(RegType(reg_type)));
|
||||
}
|
||||
}
|
||||
|
||||
max_mem[INT][SECRET] = 0;
|
||||
}
|
||||
|
||||
void Program::parse(istream& s)
|
||||
|
||||
@@ -21,7 +21,7 @@ class Program
|
||||
unsigned max_reg[MAX_REG_TYPE];
|
||||
|
||||
// Memory size used directly
|
||||
unsigned max_mem[MAX_REG_TYPE][MAX_SECRECY_TYPE];
|
||||
unsigned max_mem[MAX_REG_TYPE];
|
||||
|
||||
// True if program contains variable-sized loop
|
||||
bool unknown_usage;
|
||||
@@ -45,7 +45,7 @@ class Program
|
||||
int num_reg(RegType reg_type) const
|
||||
{ return max_reg[reg_type]; }
|
||||
|
||||
const unsigned* direct_mem(RegType reg_type) const
|
||||
unsigned direct_mem(RegType reg_type) const
|
||||
{ return max_mem[reg_type]; }
|
||||
|
||||
friend ostream& operator<<(ostream& s,const Program& P);
|
||||
|
||||
85
Processor/instructions.h
Normal file
85
Processor/instructions.h
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* instructions.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_INSTRUCTIONS_H_
|
||||
#define PROCESSOR_INSTRUCTIONS_H_
|
||||
|
||||
#define ARITHMETIC_INSTRUCTIONS \
|
||||
X(LDI, auto dest = &Procp.get_C()[r[0]]; typename sint::clear tmp = int(n), \
|
||||
*dest++ = tmp) \
|
||||
X(LDMS, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.machine.Mp.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(STMS, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.machine.Mp.MS[n], \
|
||||
*dest++ = *source++) \
|
||||
X(LDMSI, auto dest = &Procp.get_S()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = Proc.machine.Mp.read_S(*source++)) \
|
||||
X(STMSI, auto source = &Procp.get_S()[r[0]]; auto dest = &Proc.get_Ci()[r[1]], \
|
||||
Proc.machine.Mp.write_S(*dest++, *source++)) \
|
||||
X(MOVS, auto dest = &Procp.get_S()[r[0]]; auto source = &Procp.get_S()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
X(ADDS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_S()[r[2]], \
|
||||
*dest++ = *op1++ + *op2++) \
|
||||
X(ADDM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ + sint::constant(*op2++, Proc.P.my_num(), Procp.MC.get_alphai())) \
|
||||
X(ADDC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ + *op2++) \
|
||||
X(ADDCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
typename sint::clear op2 = int(n), \
|
||||
*dest++ = *op1++ + op2) \
|
||||
X(SUBS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_S()[r[2]], \
|
||||
*dest++ = *op1++ - *op2++) \
|
||||
X(SUBSFI, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = sint::constant(int(n), Proc.P.my_num(), Procp.MC.get_alphai()), \
|
||||
*dest++ = op2 - *op1++) \
|
||||
X(SUBML, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ - sint::constant(*op2++, Proc.P.my_num(), Procp.MC.get_alphai())) \
|
||||
X(SUBMR, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
auto op2 = &Procp.get_S()[r[2]], \
|
||||
*dest++ = sint::constant(*op1++, Proc.P.my_num(), Procp.MC.get_alphai()) - *op2++) \
|
||||
X(SUBC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ - *op2++) \
|
||||
X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(MULC, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
auto op2 = &Procp.get_C()[r[2]], \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(MULCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \
|
||||
typename sint::clear op2 = int(n), \
|
||||
*dest++ = *op1++ * op2) \
|
||||
X(SHRCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]], \
|
||||
*dest++ = *op1++ >> n) \
|
||||
X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \
|
||||
auto c = &Procp.get_S()[r[2]], \
|
||||
Procp.DataF.get_three(DATA_TRIPLE, *a++, *b++, *c++)) \
|
||||
X(BIT, auto dest = &Procp.get_S()[r[0]], \
|
||||
Procp.DataF.get_one(DATA_BIT, *dest++)) \
|
||||
X(LDINT, auto dest = &Proc.get_Ci()[r[0]], \
|
||||
*dest++ = int(n)) \
|
||||
X(ADDINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
|
||||
auto op2 = &Proc.get_Ci()[r[2]], \
|
||||
*dest++ = *op1++ + *op2++) \
|
||||
X(SUBINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
|
||||
auto op2 = &Proc.get_Ci()[r[2]], \
|
||||
*dest++ = *op1++ - *op2++) \
|
||||
X(MULINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
|
||||
auto op2 = &Proc.get_Ci()[r[2]], \
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(DIVINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
|
||||
auto op2 = &Proc.get_Ci()[r[2]], \
|
||||
*dest++ = *op1++ / *op2++) \
|
||||
X(INCINT, auto dest = &Proc.get_Ci()[r[0]]; auto base = Proc.get_Ci()[r[1]], \
|
||||
int inc = (i / start[0]) % start[1]; *dest++ = base + inc * int(n)) \
|
||||
X(CONVINT, auto dest = &Procp.get_C()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \
|
||||
*dest++ = *source++) \
|
||||
|
||||
|
||||
#endif /* PROCESSOR_INSTRUCTIONS_H_ */
|
||||
Submodule Programs/Circuits updated: 82dfda9d12...908452826c
@@ -2,11 +2,7 @@ import ml
|
||||
import util
|
||||
import math
|
||||
|
||||
if 'trunc_pr' in program.args:
|
||||
program.use_trunc_pr = True
|
||||
if 'split' in program.args:
|
||||
program.use_split(3)
|
||||
|
||||
program.options_from_args()
|
||||
program.options.cisc = True
|
||||
|
||||
try:
|
||||
@@ -72,7 +68,7 @@ else:
|
||||
opt = ml.Optimizer()
|
||||
opt.layers = layers
|
||||
for layer in layers:
|
||||
layer.input_from(0, raw='raw' in program.args)
|
||||
layer.input_from(0)
|
||||
layers[0].X.input_from(1)
|
||||
start_timer(1)
|
||||
opt.forward(1)
|
||||
|
||||
@@ -3,13 +3,7 @@ from Compiler import ml
|
||||
debug = False
|
||||
|
||||
program.use_edabit(True)
|
||||
program.use_trunc_pr = True
|
||||
|
||||
if 'split' in program.args:
|
||||
program.use_split(3)
|
||||
|
||||
if 'split2' in program.args:
|
||||
program.use_split(2)
|
||||
program.options_from_args()
|
||||
|
||||
sfix.set_precision(16, 31)
|
||||
cfix.set_precision(16, 31)
|
||||
@@ -29,9 +23,19 @@ dense = ml.Dense(12800, dim, 1)
|
||||
layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)]
|
||||
|
||||
sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
|
||||
sgd.reset([X_normal, X_pos])
|
||||
sgd.run(batch_size=batch)
|
||||
|
||||
# @for_range(1000)
|
||||
# def _(i):
|
||||
# sgd.backward()
|
||||
if not ('forward' in program.args or 'backward' in program.args):
|
||||
sgd.reset([X_normal, X_pos])
|
||||
sgd.run(batch_size=batch)
|
||||
|
||||
if 'forward' in program.args:
|
||||
@for_range(1000)
|
||||
def _(i):
|
||||
sgd.forward(N=batch)
|
||||
|
||||
if 'backward' in program.args:
|
||||
b = regint.Array(batch)
|
||||
b.assign(regint.inc(batch))
|
||||
@for_range(1000)
|
||||
def _(i):
|
||||
sgd.backward(batch=b)
|
||||
|
||||
@@ -97,3 +97,6 @@ test(c[0], 0)
|
||||
test(c[1], 1)
|
||||
test(c[2], 1)
|
||||
test(c[3], 0)
|
||||
|
||||
test(sbit(sbits(3)), 1)
|
||||
test(sbits(sbit(1)), 1)
|
||||
|
||||
51
Protocols/FakeInput.h
Normal file
51
Protocols/FakeInput.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* FakeProtocol.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_FAKEINPUT_H_
|
||||
#define PROTOCOLS_FAKEINPUT_H_
|
||||
|
||||
#include "Replicated.h"
|
||||
#include "Processor/Input.h"
|
||||
|
||||
template<class T>
|
||||
class FakeInput : public InputBase<T>
|
||||
{
|
||||
PointerVector<T> results;
|
||||
|
||||
public:
|
||||
FakeInput(SubProcessor<T>&, typename T::MAC_Check&)
|
||||
{
|
||||
}
|
||||
|
||||
void reset(int)
|
||||
{
|
||||
results.clear();
|
||||
}
|
||||
|
||||
void add_mine(const typename T::open_type& x, int = -1)
|
||||
{
|
||||
results.push_back(x);
|
||||
}
|
||||
|
||||
void add_other(int)
|
||||
{
|
||||
}
|
||||
|
||||
void send_mine()
|
||||
{
|
||||
}
|
||||
|
||||
T finalize_mine()
|
||||
{
|
||||
return results.next();
|
||||
}
|
||||
|
||||
void finalize_other(int, T&, octetStream&, int = -1)
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */
|
||||
31
Protocols/FakeMC.h
Normal file
31
Protocols/FakeMC.h
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* FakeMC.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_FAKEMC_H_
|
||||
#define PROTOCOLS_FAKEMC_H_
|
||||
|
||||
#include "MAC_Check_Base.h"
|
||||
|
||||
template<class T>
|
||||
class FakeMC : public MAC_Check_Base<T>
|
||||
{
|
||||
public:
|
||||
FakeMC(T, int = 0, int = 0)
|
||||
{
|
||||
}
|
||||
|
||||
void exchange(const Player&)
|
||||
{
|
||||
for (auto& x : this->secrets)
|
||||
this->values.push_back(x);
|
||||
}
|
||||
|
||||
FakeMC& get_part_MC()
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKEMC_H_ */
|
||||
82
Protocols/FakePrep.h
Normal file
82
Protocols/FakePrep.h
Normal file
@@ -0,0 +1,82 @@
|
||||
/*
|
||||
* FakePrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_FAKEPREP_H_
|
||||
#define PROTOCOLS_FAKEPREP_H_
|
||||
|
||||
#include "ReplicatedPrep.h"
|
||||
|
||||
template<class T>
|
||||
class FakePrep : public BufferPrep<T>
|
||||
{
|
||||
SeededPRNG G;
|
||||
|
||||
public:
|
||||
FakePrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
FakePrep(DataPositions& usage, GC::ShareThread<T>&) :
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
FakePrep(DataPositions& usage, int = 0) :
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
void set_protocol(typename T::Protocol&)
|
||||
{
|
||||
}
|
||||
|
||||
void buffer_triples()
|
||||
{
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
auto a = G.get<T>();
|
||||
auto b = G.get<T>();
|
||||
this->triples.push_back({{a, b, a * b}});
|
||||
}
|
||||
}
|
||||
|
||||
void buffer_squares()
|
||||
{
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
auto a = G.get<T>();
|
||||
this->squares.push_back({{a, a * a}});
|
||||
}
|
||||
}
|
||||
|
||||
void buffer_inverses()
|
||||
{
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
auto a = G.get<T>();
|
||||
T aa;
|
||||
aa.invert(a);
|
||||
this->inverses.push_back({{a, aa}});
|
||||
}
|
||||
}
|
||||
|
||||
void buffer_bits()
|
||||
{
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
this->bits.push_back(G.get_bit());
|
||||
}
|
||||
}
|
||||
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b)
|
||||
{
|
||||
auto bit = G.get_bit();
|
||||
a = bit;
|
||||
b = bit;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKEPREP_H_ */
|
||||
63
Protocols/FakeProtocol.h
Normal file
63
Protocols/FakeProtocol.h
Normal file
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
* FakeProtocol.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_FAKEPROTOCOL_H_
|
||||
#define PROTOCOLS_FAKEPROTOCOL_H_
|
||||
|
||||
#include "Replicated.h"
|
||||
|
||||
template<class T>
|
||||
class FakeProtocol : public ProtocolBase<T>
|
||||
{
|
||||
PointerVector<T> results;
|
||||
SeededPRNG G;
|
||||
|
||||
public:
|
||||
Player& P;
|
||||
|
||||
FakeProtocol(Player& P) : P(P)
|
||||
{
|
||||
}
|
||||
|
||||
void init_mul(SubProcessor<T>*)
|
||||
{
|
||||
results.clear();
|
||||
}
|
||||
|
||||
typename T::clear prepare_mul(const T& x, const T& y, int = -1)
|
||||
{
|
||||
results.push_back(x * y);
|
||||
return {};
|
||||
}
|
||||
|
||||
void exchange()
|
||||
{
|
||||
}
|
||||
|
||||
T finalize_mul(int = -1)
|
||||
{
|
||||
return results.next();
|
||||
}
|
||||
|
||||
void randoms(T& res, int n_bits)
|
||||
{
|
||||
res.randomize_part(G, n_bits);
|
||||
}
|
||||
|
||||
int get_n_relevant_players()
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc)
|
||||
{
|
||||
for (size_t i = 0; i < regs.size(); i += 4)
|
||||
for (int l = 0; l < size; l++)
|
||||
proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l)
|
||||
>> regs[i + 3];
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */
|
||||
80
Protocols/FakeShare.h
Normal file
80
Protocols/FakeShare.h
Normal file
@@ -0,0 +1,80 @@
|
||||
/*
|
||||
* FakeShare.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_FAKESHARE_H_
|
||||
#define PROTOCOLS_FAKESHARE_H_
|
||||
|
||||
#include "GC/FakeSecret.h"
|
||||
#include "ShareInterface.h"
|
||||
#include "Processor/NoLivePrep.h"
|
||||
#include "FakeMC.h"
|
||||
#include "FakeProtocol.h"
|
||||
#include "FakePrep.h"
|
||||
#include "FakeInput.h"
|
||||
|
||||
template<class T>
|
||||
class FakeShare : public T, public ShareInterface
|
||||
{
|
||||
typedef FakeShare This;
|
||||
|
||||
public:
|
||||
typedef T mac_key_type;
|
||||
typedef T open_type;
|
||||
typedef T clear;
|
||||
|
||||
typedef FakePrep<This> LivePrep;
|
||||
typedef FakeMC<This> MAC_Check;
|
||||
typedef MAC_Check Direct_MC;
|
||||
typedef FakeInput<This> Input;
|
||||
typedef ::PrivateOutput<This> PrivateOutput;
|
||||
typedef FakeProtocol<This> Protocol;
|
||||
|
||||
typedef GC::FakeSecret bit_type;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "emul";
|
||||
}
|
||||
|
||||
static int threshold(int)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static T constant(T value, int = 0, T = 0)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
FakeShare()
|
||||
{
|
||||
}
|
||||
|
||||
template<class U>
|
||||
FakeShare(U other) :
|
||||
T(other)
|
||||
{
|
||||
}
|
||||
|
||||
void assign(T value, int = 0, T = 0)
|
||||
{
|
||||
*this = value;
|
||||
}
|
||||
|
||||
void add(T a, T b, int = 0, T = {})
|
||||
{
|
||||
*this = a + b;
|
||||
}
|
||||
|
||||
void sub(T a, T b, int = 0, T = {})
|
||||
{
|
||||
*this = a - b;
|
||||
}
|
||||
|
||||
static void split(vector<bit_type>& dest, const vector<int>& regs, int n_bits,
|
||||
const This* source, int n_inputs, Player& P);
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_FAKESHARE_H_ */
|
||||
50
Protocols/FakeShare.hpp
Normal file
50
Protocols/FakeShare.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* FakeShare.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "FakeShare.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "GC/square64.h"
|
||||
|
||||
template<class T>
|
||||
void FakeShare<T>::split(vector<bit_type>& dest,
|
||||
const vector<int>& regs, int n_bits, const This* source, int n_inputs,
|
||||
Player&)
|
||||
{
|
||||
assert(n_bits <= 64);
|
||||
int unit = GC::Clear::N_BITS;
|
||||
for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
|
||||
{
|
||||
int start = k * unit;
|
||||
int m = min(unit, n_inputs - start);
|
||||
|
||||
switch (regs.size() / n_bits)
|
||||
{
|
||||
case 3:
|
||||
{
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
for (int j = 1; j < 3; j++)
|
||||
dest.at(regs.at(3 * i + j) + k) = {};
|
||||
|
||||
square64 square;
|
||||
|
||||
for (int j = 0; j < m; j++)
|
||||
{
|
||||
square.rows[j] = (source[j + start]).get_limb(0);
|
||||
}
|
||||
|
||||
square.transpose(m, n_bits);
|
||||
|
||||
for (int j = 0; j < n_bits; j++)
|
||||
{
|
||||
auto& dest_reg = dest.at(regs.at(3 * j) + k);
|
||||
dest_reg = square.rows[j];
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
not_implemented();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,10 +27,20 @@ public:
|
||||
void buffer_inputs(int player);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class RingOnlyBitsFromSquaresPrep : public virtual BufferPrep<T>
|
||||
{
|
||||
public:
|
||||
RingOnlyBitsFromSquaresPrep(SubProcessor<T>* proc, DataPositions& usage);
|
||||
|
||||
void buffer_bits();
|
||||
};
|
||||
|
||||
// extra class to avoid recursion
|
||||
template<class T>
|
||||
class MalRepRingPrepWithBits: public virtual MaliciousRingPrep<T>,
|
||||
public virtual MalRepRingPrep<T>
|
||||
public virtual MalRepRingPrep<T>,
|
||||
public virtual RingOnlyBitsFromSquaresPrep<T>
|
||||
{
|
||||
public:
|
||||
MalRepRingPrepWithBits(SubProcessor<T>* proc, DataPositions& usage);
|
||||
@@ -40,12 +50,20 @@ public:
|
||||
MaliciousRingPrep<T>::set_protocol(protocol);
|
||||
}
|
||||
|
||||
void buffer_triples()
|
||||
{
|
||||
MalRepRingPrep<T>::buffer_triples();
|
||||
}
|
||||
|
||||
void buffer_squares()
|
||||
{
|
||||
MalRepRingPrep<T>::buffer_squares();
|
||||
}
|
||||
|
||||
void buffer_bits();
|
||||
void buffer_bits()
|
||||
{
|
||||
RingOnlyBitsFromSquaresPrep<T>::buffer_bits();
|
||||
}
|
||||
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b)
|
||||
{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user