diff --git a/BMR/CommonParty.cpp b/BMR/CommonParty.cpp index f4bd2e31..33cc9a20 100644 --- a/BMR/CommonParty.cpp +++ b/BMR/CommonParty.cpp @@ -42,7 +42,6 @@ CommonFakeParty::~CommonFakeParty() CommonParty::~CommonParty() { - cerr << "Total time: " << timer.elapsed() << endl; #ifdef VERBOSE cerr << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl; cerr << "CPU time: " << cpu_timer.elapsed() << endl; @@ -50,6 +49,7 @@ CommonParty::~CommonParty() cerr << "Second phase time: " << timers[1].elapsed() << endl; cerr << "Number of gates: " << gate_counter << endl; #endif + cerr << "Time = " << timer.elapsed() << " seconds" << endl; } void CommonParty::check(int n_parties) diff --git a/BMR/ProgramParty.hpp b/BMR/ProgramParty.hpp index 1490dc7b..088d8cae 100644 --- a/BMR/ProgramParty.hpp +++ b/BMR/ProgramParty.hpp @@ -8,6 +8,8 @@ #include "Party.h" +#include "GC/ShareSecret.hpp" + template ProgramPartySpec* ProgramPartySpec::singleton = 0; diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h index 842c732e..9fa2dc52 100644 --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -9,15 +9,32 @@ #include "Register.h" template class RealProgramParty; +template class RealGarbleWire; + +template +class GarbleInputter +{ +public: + RealProgramParty& party; + + Bundle oss; + PointerVector*, int>> tuples; + + GarbleInputter(); + void exchange(); +}; template class RealGarbleWire : public PRFRegister { friend class RealProgramParty; + friend class GarbleInputter; T mask; public: + typedef GarbleInputter Input; + static void store(NoMemory& dest, const vector>>& accesses); static void load(vector>>& accesses, @@ -26,6 +43,11 @@ public: static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& processor); + static void inputb(GC::Processor>& processor, + const vector& args); + static void inputbvec(GC::Processor>& processor, + ProcessorBase& input_processor, const vector& args); + RealGarbleWire(const Register& reg) : PRFRegister(reg) {} void 
garble(PRFOutputs& prf_output, const RealGarbleWire& left, @@ -37,6 +59,10 @@ public: void public_input(bool value); void random(); void output(); + + void my_input(Input& Inputter, bool value, int n_bits); + void other_input(Input& Inputter, int from); + void finalize_input(Input&, int from, int); }; template diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 1b25dba7..55adcbfb 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -94,36 +94,81 @@ void RealGarbleWire::XOR(const RealGarbleWire& left, const RealGarbleWire< } template -void RealGarbleWire::input(party_id_t from, char input) +void RealGarbleWire::inputb( + GC::Processor>& processor, + const vector& args) +{ + GarbleInputter inputter; + processor.inputb(inputter, processor, args, + inputter.party.P->my_num()); +} + +template +void RealGarbleWire::inputbvec( + GC::Processor>& processor, + ProcessorBase& input_processor, const vector& args) +{ + GarbleInputter inputter; + processor.inputbvec(inputter, input_processor, args, + inputter.party.P->my_num()); +} + +template +GarbleInputter::GarbleInputter() : + party(RealProgramParty::s()), oss(*party.P) +{ +} + +template +void RealGarbleWire::my_input(Input& inputter, bool, int n_bits) +{ + assert(n_bits == 1); + inputter.tuples.push_back({this, inputter.party.P->my_num()}); +} + +template +void RealGarbleWire::other_input(Input& inputter, int from) +{ + inputter.tuples.push_back({this, from}); +} + +template +void GarbleInputter::exchange() { - PRFRegister::input(from, input); - auto& party = RealProgramParty::s(); assert(party.shared_proc != 0); auto& inputter = party.shared_proc->input; - inputter.reset(from - 1); - if (from == party.get_id()) + inputter.reset_all(*party.P); + for (auto& tuple : tuples) { - char my_mask; - my_mask = party.prng.get_bit(); - party.input_masks.serialize(my_mask); - inputter.add_mine(my_mask); - inputter.send_mine(); - mask = inputter.finalize_mine(); + int from = tuple.second; + party_id_t 
from_id = from + 1; + tuple.first->PRFRegister::input(from_id, -1); + if (from_id == party.get_id()) + { + char my_mask; + my_mask = party.prng.get_bit(); + party.garble_input_masks.serialize(my_mask); + inputter.add_mine(my_mask); #ifdef DEBUG_MASK - cout << "my mask: " << (int)my_mask << endl; + cout << "my mask: " << (int)my_mask << endl; #endif + } + else + { + inputter.add_other(from); + } } - else - { - inputter.add_other(from - 1); - octetStream os; - party.P->receive_player(from - 1, os, true); - inputter.finalize_other(from - 1, mask, os); - } + + inputter.exchange(); + + for (auto& tuple : tuples) + tuple.first->mask = (inputter.finalize(tuple.second)); + // important to make sure that mask is a bit try { - mask.force_to_bit(); + for (auto& tuple : tuples) + tuple.first->mask.force_to_bit(); } catch (not_implemented& e) { @@ -131,16 +176,36 @@ void RealGarbleWire::input(party_id_t from, char input) assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; protocol.init_mul(party.shared_proc); - protocol.prepare_mul(mask, T::constant(1, party.P->my_num(), party.mac_key) - mask); + for (auto& tuple : tuples) + protocol.prepare_mul(tuple.first->mask, + T::constant(1, party.P->my_num(), party.mac_key) + - tuple.first->mask); protocol.exchange(); - if (party.MC->open(protocol.finalize_mul(), *party.P) != 0) + vector to_check; + to_check.reserve(tuples.size()); + for (size_t i = 0; i < tuples.size(); i++) + { + to_check.push_back(protocol.finalize_mul()); + } + try + { + party.MC->CheckFor(0, to_check, *party.P); + } + catch (mac_fail&) + { throw runtime_error("input mask not a bit"); + } } #ifdef DEBUG_MASK cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl; #endif } +template +void RealGarbleWire::finalize_input(GarbleInputter&, int, int) +{ +} + template void RealGarbleWire::public_input(bool value) { @@ -169,7 +234,7 @@ void RealGarbleWire::output() assert(party.MC != 0); assert(party.P != 0); auto m = party.MC->open(mask, 
*party.P); - party.output_masks.push_back(m.get_bit(0)); + party.garble_output_masks.push_back(m.get_bit(0)); party.taint(); #ifdef DEBUG_MASK cout << "output mask: " << m << endl; diff --git a/BMR/RealProgramParty.h b/BMR/RealProgramParty.h index d3ef1d0b..9f274bd8 100644 --- a/BMR/RealProgramParty.h +++ b/BMR/RealProgramParty.h @@ -19,6 +19,7 @@ class RealProgramParty : public ProgramPartySpec typedef typename T::Input Inputter; friend class RealGarbleWire; + friend class GarbleInputter; friend class GarbleJob; static RealProgramParty* singleton; @@ -40,9 +41,15 @@ class RealProgramParty : public ProgramPartySpec GC::BreakType next; + bool one_shot; + + size_t data_sent; + public: static RealProgramParty& s(); + LocalBuffer garble_input_masks, garble_output_masks; + RealProgramParty(int argc, const char** argv); ~RealProgramParty(); diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 630d2ec2..93ef4fb3 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -44,10 +44,20 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : "-N", // Flag token. "--nparties" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Evaluate only after garbling.", // Help description. + "-O", // Flag token. + "--oneshot" // Flag token. 
+ ); opt.parse(argc, argv); int nparties; opt.get("-N")->getInt(nparties); this->check(nparties); + one_shot = opt.isSet("-O"); NetworkOptions network_opts(opt, argc, argv); OnlineOptions& online_opts = OnlineOptions::singleton; @@ -90,7 +100,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : mac_key.randomize(prng); if (T::needs_ot) BaseMachine::s().ot_setups.push_back({*P, true}); - prep = Preprocessing::get_live_prep(0, usage); + prep = new typename T::TriplePrep(0, usage); } else { @@ -122,6 +132,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : for (int i = 0; i < SPDZ_OP_N; i++) this->spdz_wires[i].push_back({}); + this->timer.reset(); do { next = GC::TIME_BREAK; @@ -129,7 +140,14 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : try { this->online_timer.start(); - this->start_online_round(); + if (one_shot) + this->start_online_round(); + else + { + this->load_garbled_circuit(); + next = this->second_phase(program, this->processor, + this->machine, this->dynamic_memory); + } this->online_timer.stop(); } catch (needs_cleaning& e) @@ -139,6 +157,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); + data_sent = P->comm_stats.total_data() + prep->data_sent(); if (server) delete server; @@ -152,7 +171,7 @@ void RealProgramParty::garble() auto& program = this->program; auto& MC = this->MC; - while (next == GC::TIME_BREAK) + do { garble_jobs.clear(); garble_inputter->reset_all(*P); @@ -178,13 +197,15 @@ void RealProgramParty::garble() vector opened; MC->POpen(opened, wires, *P); + LocalBuffer garbled_circuit; for (auto& x : opened) - this->garbled_circuit.serialize(x); + garbled_circuit.serialize(x); - this->garbled_circuits.push_and_clear(this->garbled_circuit); - this->input_masks_store.push_and_clear(this->input_masks); - this->output_masks_store.push_and_clear(this->output_masks); + this->garbled_circuits.push_and_clear(garbled_circuit); 
+ this->input_masks_store.push_and_clear(garble_input_masks); + this->output_masks_store.push_and_clear(garble_output_masks); } + while (one_shot and next == GC::TIME_BREAK); } template @@ -194,6 +215,7 @@ RealProgramParty::~RealProgramParty() delete prep; delete garble_inputter; delete garble_protocol; + cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; } template diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 0be981fb..99d1bda7 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -601,37 +601,64 @@ void EvalRegister::input_helper(char value, octetStream& os) os.serialize(get_external()); } -void EvalRegister::input(party_id_t from, char value) +EvalInputter::EvalInputter() : + party(ProgramParty::s()), oss(*party.P) { - auto& party = ProgramParty::s(); +} + +void EvalRegister::my_input(EvalInputter& inputter, bool input, int n_bits) +{ + assert(n_bits == 1); + auto& party = inputter.party; party.load_wire(*this); - octetStream os; - if (from == party.get_id()) - { - if (value and (1 - value)) - throw runtime_error("invalid input"); - input_helper(value, os); - party.P->send_all(os, true); - } - else - { - party.P->receive_player(from - 1, os); - char ext; - os.unserialize(ext); - set_external(ext); - } - vector oss(party.get_n_parties()); - size_t id = party.get_id() - 1; - oss[id].serialize(garbled_entry[id]); -#ifdef DEBUG_COMM - cout << "send " << garbled_entry[id] << ", " - << oss[id].get_length() << " bytes from " << id << endl; -#endif + input_helper(input, inputter.oss.mine); + inputter.tuples.push_back({this, party.P->my_num()}); +} + +void EvalRegister::other_input(EvalInputter& inputter, int from) +{ + auto& party = inputter.party; + party.load_wire(*this); + inputter.tuples.push_back({this, from}); +} + +void EvalInputter::exchange() +{ party.P->Broadcast_Receive(oss, true); + for (auto& tuple : tuples) + { + if (tuple.from != party.P->my_num()) + { + char ext; + oss[tuple.from].unserialize(ext); + tuple.reg->set_external(ext); + 
} + } + + size_t id = party.get_id() - 1; + for (auto& os : oss) + os.reset_write_head(); + + for (auto& tuple : tuples) + { + oss[id].serialize(tuple.reg->get_garbled_entry()[id]); +#ifdef DEBUG_COMM + cout << "send " << garbled_entry[id] << ", " + << oss[id].get_length() << " bytes from " << id << endl; +#endif + } + + party.P->Broadcast_Receive(oss, true); +} + +void EvalRegister::finalize_input(EvalInputter& inputter, int, int) +{ + auto& party = inputter.party; + size_t id = party.get_id() - 1; for (size_t i = 0; i < (size_t)party.get_n_parties(); i++) { if (i != id) - oss[i].unserialize(garbled_entry[i]); + inputter.oss[i].unserialize(garbled_entry[i]); } keys[external] = garbled_entry; #ifdef DEBUG @@ -934,4 +961,4 @@ void KeyTuple::print(int wire_id, party_id_t pid) } template class KeyTuple<2>; -template class KeyTuple<4>; +template class KeyTuple<4> ; diff --git a/BMR/Register.h b/BMR/Register.h index 192c5adb..09b29d18 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -20,6 +20,8 @@ using namespace std; #include "GC/ArgTuples.h" #include "Math/gf2n.h" #include "Tools/FlexBuffer.h" +#include "Tools/PointerVector.h" +#include "Tools/Bundle.h" //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -201,6 +203,8 @@ public: inline BlackHole& endl(BlackHole& b) { return b; } inline BlackHole& flush(BlackHole& b) { return b; } +class ProcessorBase; + class Phase { public: @@ -209,6 +213,8 @@ public: typedef BlackHole out_type; static BlackHole out; + static const bool actual_inputs = true; + template static void store_clear_in_dynamic(T& mem, const vector& accesses) { (void)mem; (void)accesses; } @@ -231,6 +237,9 @@ public: template static void inputb(T& processor, const vector& args) { processor.input(args); } template + static void inputbvec(T&, ProcessorBase&, const vector&) + { throw not_implemented(); } + template static T get_input(int from, GC::Processor& processor, int n_bits) { return T::input(from, processor.get_input(n_bits), n_bits); } @@ 
-244,9 +253,24 @@ public: void output() {} }; +class NoOpInputter +{ +public: + PointerVector inputs; + + void exchange() + { + } +}; + class ProgramRegister : public Phase, public Register { public: + typedef NoOpInputter Input; + + // only true for evaluation + static const bool actual_inputs = false; + static Register new_reg(); static Register tmp_reg() { return new_reg(); } static Register and_reg() { return new_reg(); } @@ -255,11 +279,18 @@ public: static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } + template + static void inputbvec(T& processor, ProcessorBase& input_processor, + const vector& args); + // most BMR phases don't need actual input template static T get_input(GC::Processor& processor, const InputArgs& args) { (void)processor; return T::input(args.from + 1, 0, args.n_bits); } + void my_input(Input&, bool, int) {} + void other_input(Input&, int) {} + char get_output() { return 0; } ProgramRegister(const Register& reg) : Register(reg) {} @@ -282,6 +313,37 @@ public: void public_input(bool value); void random(); void output(); + + void finalize_input(NoOpInputter&, int from, int) + { input(from + 1, -1); } +}; + +class ProgramParty; +class EvalRegister; + +class EvalInputter +{ + class Tuple + { + public: + EvalRegister* reg; + int from; + + Tuple(EvalRegister* reg, int from) : + reg(reg), from(from) + { + } + }; + +public: + ProgramParty& party; + + Bundle oss; + vector tuples; + + EvalInputter(); + void add_other(int from); + void exchange(); }; class EvalRegister : public ProgramRegister @@ -289,9 +351,13 @@ class EvalRegister : public ProgramRegister public: static string name() { return "Evaluation"; } + typedef EvalInputter Input; + typedef ostream& out_type; static ostream& out; + static const bool actual_inputs = true; + template static void store(GC::Memory& dest, const vector >& accesses); @@ -303,6 +369,9 @@ public: static void andrs(T& processor, const vector& args); template static void inputb(T& 
processor, const vector& args); + template + static void inputbvec(T& processor, ProcessorBase& input_processor, + const vector& args); template static T get_input(GC::Processor& processor, const InputArgs& args) @@ -330,6 +399,10 @@ public: static void check_input(long long input, int n_bits); void input(party_id_t from, char value = -1); void input_helper(char value, octetStream& os); + + void my_input(EvalInputter& inputter, bool input, int); + void other_input(EvalInputter& inputter, int from); + void finalize_input(EvalInputter& inputter, int from, int); }; class GarbleRegister : public ProgramRegister @@ -349,6 +422,9 @@ public: void public_input(bool value); void random(); void output() {} + + void finalize_input(NoOpInputter&, int from, int) + { input(from + 1, -1); } }; class RandomRegister : public ProgramRegister @@ -374,6 +450,9 @@ public: void public_input(bool value); void random(); void output(); + + void finalize_input(NoOpInputter&, int from, int) + { input(from + 1, -1); } }; diff --git a/BMR/Register.hpp b/BMR/Register.hpp index 05e96301..8c8c9a97 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -9,6 +9,31 @@ #include "Register.h" #include "Party.h" +template +void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, + const vector& args) +{ + NoOpInputter inputter; + int my_num = -1; + try + { + my_num = ProgramParty::s().P->my_num(); + } + catch (exception&) + { + } + processor.inputbvec(inputter, input_processor, args, my_num); +} + +template +void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, + const vector& args) +{ + EvalInputter inputter; + processor.inputbvec(inputter, input_processor, args, + ProgramParty::s().P->my_num()); +} + template void PRFRegister::load(vector >& accesses, const NoMemory& source) diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 9a7bfee2..ae701549 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -29,6 +29,7 @@ #include 
"GC/ThreadMaster.hpp" #include "GC/Program.hpp" #include "GC/Instruction.hpp" +#include "GC/ShareSecret.hpp" #include "Processor/Instruction.hpp" #include "Protocols/Share.hpp" diff --git a/CHANGELOG.md b/CHANGELOG.md index 51db3bb1..4fd7a9f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.9 (Aug 24, 2020) + +- Streamline inputs to binary circuits +- Improved private output +- Emulator for arithmetic circuits +- Efficient dot product with Shamir's secret sharing +- Lower memory usage for TensorFlow inference +- This version breaks bytecode compatibility. + ## 0.1.8 (June 15, 2020) - Half-gate garbling diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 88543760..3ea9d4a0 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -37,6 +37,7 @@ opcodes = dict( STMSBI = 0x243, MOVSB = 0x244, INPUTB = 0x246, + INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, XORCBI = 0x210, @@ -269,6 +270,24 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): arg_format = tools.cycle(['p','int','int','sbw']) is_vec = lambda self: True +class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, + base.Mergeable): + __slots__ = [] + code = opcodes['INPUTBVEC'] + + def __init__(self, *args, **kwargs): + self.arg_format = [] + i = 0 + while i < len(args): + self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (args[i] - 3) + i += args[i] + assert i == len(args) + super(inputbvec, self).__init__(*args, **kwargs) + + def merge(self, other): + self.args += other.args + self.arg_format += other.arg_format + class print_regb(base.VectorInstruction, base.IOInstruction): code = opcodes['PRINTREGB'] arg_format = ['cb','i'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 13e0c689..b61a80d4 100644 --- 
a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -88,7 +88,8 @@ class bits(Tape.Register, _structure, _bit): if mem_type == 'sd': return cls.load_dynamic_mem(address) else: - cls.load_inst[util.is_constant(address)](res, address) + for i in range(res.size): + cls.load_inst[util.is_constant(address)](res[i], address + i) return res def store_in_mem(self, address): self.store_inst[isinstance(address, int)](self, address) @@ -120,17 +121,19 @@ class bits(Tape.Register, _structure, _bit): assert(other.size == math.ceil(self.n / self.unit)) for i, (x, y) in enumerate(zip(self, other)): self.conv_regint(min(self.unit, self.n - i * self.unit), x, y) - elif isinstance(self, type(other)) or isinstance(other, type(self)): - assert(self.n == other.n) + elif (isinstance(self, type(other)) or isinstance(other, type(self))) \ + and self.n == other.n: for i in range(math.ceil(self.n / self.unit)): self.mov(self[i], other[i]) else: try: - other = self.bit_compose(other.bit_decompose()) + bits = other.bit_decompose() + bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits)) + other = self.bit_compose(bits) self.load_other(other) except: - raise CompilerError('cannot convert from %s to %s' % \ - (type(other), type(self))) + raise CompilerError('cannot convert %s/%s from %s to %s' % \ + (str(other), repr(other), type(other), type(self))) def long_one(self): return 2**self.n - 1 if self.n != None else None def __repr__(self): @@ -160,6 +163,9 @@ class cbits(bits): conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) conv_cint_vec = inst.convcintvec @classmethod + def bit_compose(cls, bits): + return sum(bit << i for i, bit in enumerate(bits)) + @classmethod def conv_regint_by_bit(cls, n, res, other): assert n == res.n assert n == other.size @@ -459,46 +465,68 @@ class sbits(bits): class sbitvec(_vec): @classmethod def get_type(cls, n): - class sbitvecn(cls): + class sbitvecn(cls, _structure): @staticmethod def malloc(size): - return sbits.malloc(size * n) + return 
sbit.malloc(size * n) @staticmethod def n_elements(): return n @classmethod def get_input_from(cls, player): - return cls.from_vec( - sbits.get_input_from(player, n).bit_decompose(n)) + res = cls.from_vec(sbit() for i in range(n)) + inst.inputbvec(n + 3, 0, player, *res.v) + return res get_raw_input_from = get_input_from def __init__(self, other=None): if other is not None: - self.v = sbits(other, n=n).bit_decompose(n) + if util.is_constant(other): + self.v = [sbit((other >> i) & 1) for i in range(n)] + else: + self.v = sbits(other, n=n).bit_decompose(n) @classmethod def load_mem(cls, address): - try: - assert len(address) == n + if not isinstance(address, int) and len(address) == n: return cls.from_vec(sbit.load_mem(x) for x in address) - except: + else: return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): assert self.v[0].n == 1 - try: - assert len(address) == n + if not isinstance(address, int) and len(address) == n: for x, y in zip(self.v, address): x.store_in_mem(y) - except: + else: for i in range(n): self.v[i].store_in_mem(address + i) def reveal(self): - return self.elements()[0].reveal() + revealed = [cbit() for i in range(len(self))] + for i in range(len(self)): + inst.reveal(1, revealed[i], self.v[i]) + return cbits.get_type(len(self)).bit_compose(revealed) + @classmethod + def two_power(cls, nn): + return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def coerce(self, other): + if util.is_constant(other): + return self.from_vec(util.bit_decompose(other, n)) + else: + return super(sbitvecn, self).coerce(other) + @classmethod + def bit_compose(cls, bits): + if len(bits) < n: + bits += [0] * (n - len(bits)) + assert len(bits) == n + return cls.from_vec(bits) + def __str__(self): + return 'sbitvec(%d)' % n return sbitvecn @classmethod def from_vec(cls, vector): res = cls() res.v = list(vector) return res + compose = from_vec @classmethod def combine(cls, vectors): res = cls() @@ -512,11 +540,7 @@ class 
sbitvec(_vec): if length: assert isinstance(elements, sint) if Program.prog.use_split(): - n = Program.prog.use_split() - columns = [[sbits.get_type(elements.size)() - for i in range(n)] for i in range(length)] - inst.split(n, elements, *sum(columns, [])) - x = sbitint.wallace_tree_without_finish(columns, False) + x = elements.split_to_two_summands(length) v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True) else: assert Program.prog.options.ring @@ -569,11 +593,14 @@ class sbitvec(_vec): @property def size(self): return self.v[0].n + @property + def n_bits(self): + return len(self.v) def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self): - return self.v + def bit_decompose(self, n_bits=None): + return self.v[:n_bits] bit_compose = from_vec def reveal(self): assert len(self) == 1 @@ -673,7 +700,41 @@ cbits.dynamic_array = Array def _complement_two_extend(bits, k): return bits + [bits[-1]] * (k - len(bits)) -class sbitint(_bitint, _number, sbits): +class _sbitintbase: + def extend(self, n): + bits = self.bit_decompose() + bits += [bits[-1]] * (n - len(bits)) + return self.get_type(n).bit_compose(bits) + def cast(self, n): + bits = self.bit_decompose()[:n] + bits += [bits[-1]] * (n - len(bits)) + return self.get_type(n).bit_compose(bits) + def round(self, k, m, kappa=None, nearest=None, signed=None): + bits = self.bit_decompose() + res_bits = self.bit_adder(bits[m:k], [bits[m-1]]) + return self.get_type(k - m).compose(res_bits) + def int_div(self, other, bit_length=None): + k = bit_length or max(self.n, other.n) + return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k) + def Norm(self, k, f, kappa=None, simplex_flag=False): + absolute_val = abs(self) + #next 2 lines actually compute the SufOR for little endian encoding + bits = absolute_val.bit_decompose(k)[::-1] + suffixes = floatingpoint.PreOR(bits)[::-1] + z = [0] * k + for i in range(k - 1): + z[i] = suffixes[i] - 
suffixes[i+1] + z[k - 1] = suffixes[k-1] + z.reverse() + t2k = self.get_type(2 * k) + acc = t2k.bit_compose(z) + sign = self.bit_decompose()[-1] + signed_acc = util.if_else(sign, -acc, acc) + absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose()) + part_reciprocal = absolute_val_2k * acc + return part_reciprocal, signed_acc + +class sbitint(_bitint, _number, sbits, _sbitintbase): n_bits = None bin_type = None types = {} @@ -728,43 +789,11 @@ class sbitint(_bitint, _number, sbits): res_bits = product.bit_decompose()[m:k] t = self.combo_type(other) return t.bit_compose(res_bits) - def Norm(self, k, f, kappa=None, simplex_flag=False): - absolute_val = abs(self) - #next 2 lines actually compute the SufOR for little indian encoding - bits = absolute_val.bit_decompose(k)[::-1] - suffixes = floatingpoint.PreOR(bits)[::-1] - z = [0] * k - for i in range(k - 1): - z[i] = suffixes[i] - suffixes[i+1] - z[k - 1] = suffixes[k-1] - z.reverse() - t2k = self.get_type(2 * k) - acc = t2k.bit_compose(z) - sign = self.bit_decompose()[-1] - signed_acc = util.if_else(sign, -acc, acc) - absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose()) - part_reciprocal = absolute_val_2k * acc - return part_reciprocal, signed_acc - def extend(self, n): - bits = self.bit_decompose() - bits += [bits[-1]] * (n - len(bits)) - return self.get_type(n).bit_compose(bits) def __mul__(self, other): if isinstance(other, sbitintvec): return other * self else: return super(sbitint, self).__mul__(other) - def cast(self, n): - bits = self.bit_decompose()[:n] - bits += [bits[-1]] * (n - len(bits)) - return self.get_type(n).bit_compose(bits) - def int_div(self, other, bit_length=None): - k = bit_length or max(self.n, other.n) - return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k) - def round(self, k, m, kappa=None, nearest=None, signed=None): - bits = self.bit_decompose() - res_bits = self.bit_adder(bits[m:k], [bits[m-1]]) - return self.get_type(k - m).compose(res_bits) 
@classmethod def get_bit_matrix(cls, self_bits, other): n = len(self_bits) @@ -782,10 +811,11 @@ class sbitint(_bitint, _number, sbits): res.append([(x & bit) for x in other.bit_decompose(n - i)]) return res -class sbitintvec(sbitvec, _number): +class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): def __add__(self, other): if util.is_zero(other): return self + other = self.coerce(other) assert(len(self.v) == len(other.v)) v = sbitint.bit_adder(self.v, other.v) return self.from_vec(v) @@ -797,7 +827,7 @@ class sbitintvec(sbitvec, _number): if isinstance(other, sbits): return self.from_vec(other * x for x in self.v) matrix = [] - for i, b in enumerate(other.bit_decompose()): + for i, b in enumerate(util.bit_decompose(other)): matrix.append([x * b for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 77575c23..5b647c3b 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -12,6 +12,51 @@ import operator import sys from functools import reduce +class BlockAllocator: + """ Manages freed memory blocks. 
""" + def __init__(self): + self.by_logsize = [defaultdict(set) for i in range(32)] + self.by_address = {} + + def by_size(self, size): + return self.by_logsize[int(math.log(size, 2))][size] + + def push(self, address, size): + end = address + size + if end in self.by_address: + next_size = self.by_address.pop(end) + self.by_size(next_size).remove(end) + size += next_size + self.by_size(size).add(address) + self.by_address[address] = size + + def pop(self, size): + if len(self.by_size(size)) > 0: + block_size = size + else: + logsize = int(math.log(size, 2)) + for block_size, addresses in self.by_logsize[logsize].items(): + if block_size >= size and len(addresses) > 0: + break + else: + done = False + for x in self.by_logsize[logsize + 1:]: + for block_size, addresses in x.items(): + if len(addresses) > 0: + done = True + break + if done: + break + else: + block_size = 0 + if block_size >= size: + addr = self.by_size(block_size).pop() + del self.by_address[addr] + diff = block_size - size + if diff: + self.by_size(diff).add(addr + size) + self.by_address[addr + size] = diff + return addr class StraightlineAllocator: """Allocate variables in a straightline program using n registers. 
@@ -509,3 +554,47 @@ class Merger: for i in range(self.G.n): print('%d: %s' % (self.depths[i], self.instructions[i]), file=f) f.close() + +class RegintOptimizer: + def __init__(self): + self.cache = util.dict_by_id() + + def run(self, instructions): + for i, inst in enumerate(instructions): + if isinstance(inst, ldint_class): + self.cache[inst.args[0]] = inst.args[1] + elif isinstance(inst, IntegerInstruction): + if inst.args[1] in self.cache and inst.args[2] in self.cache: + res = inst.op(self.cache[inst.args[1]], + self.cache[inst.args[2]]) + if abs(res) < 2 ** 31: + self.cache[inst.args[0]] = res + instructions[i] = ldint(inst.args[0], res, + add_to_prog=False) + elif isinstance(inst, addint_class): + if inst.args[1] in self.cache and \ + self.cache[inst.args[1]] == 0: + instructions[i] = inst.args[0].link(inst.args[2]) + elif inst.args[2] in self.cache and \ + self.cache[inst.args[2]] == 0: + instructions[i] = inst.args[0].link(inst.args[1]) + elif isinstance(inst, IndirectMemoryInstruction): + if inst.args[1] in self.cache: + instructions[i] = inst.get_direct(self.cache[inst.args[1]]) + elif isinstance(inst, convint_class): + if inst.args[1] in self.cache: + res = self.cache[inst.args[1]] + self.cache[inst.args[0]] = res + if abs(res) < 2 ** 31: + instructions[i] = ldi(inst.args[0], res, + add_to_prog=False) + elif isinstance(inst, mulm_class): + if inst.args[2] in self.cache: + op = self.cache[inst.args[2]] + if op == 0: + instructions[i] = ldsi(inst.args[0], 0, + add_to_prog=False) + elif op == 1: + instructions[i] = None + inst.args[0].link(inst.args[1]) + instructions[:] = filter(lambda x: x is not None, instructions) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 1869c08b..5e542d79 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -79,7 +79,10 @@ def LTZ(s, a, k, kappa): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): - movs(s, sint.conv(sbitvec(a, k).v[-1])) + summands = 
a.split_to_two_summands(k) + carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) + msb = carry ^ summands[0][-1] ^ summands[1][-1] + movs(s, sint.conv(msb)) return elif program.options.ring: from . import floatingpoint @@ -132,9 +135,31 @@ def Trunc(d, a, k, m, kappa, signed): mulm(d, t, c[2]) def TruncRing(d, a, k, m, signed): - a_prime = Mod2mRing(None, a, k, m, signed) - a -= a_prime - res = TruncLeakyInRing(a, k, m, signed) + if program.use_split() == 3: + from Compiler.types import sint + from .GC.types import sbitint + length = int(program.options.ring) + summands = a.split_to_n_summands(length, 3) + x = sbitint.wallace_tree_without_finish(summands, True) + if m == 1: + low = x[1][1] + high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \ + sint.conv(x[0][-1]) + else: + mid_carry = CarryOutRawLE(x[1][:m], x[0][:m]) + low = sint.conv(mid_carry) + sint.conv(x[0][m]) + tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy) + for xx, yy in zip(x[1][m:-1], + x[0][m:-1]))) + top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1]) + high = top_carry + sint.conv(x[0][-1]) + shifted = sint() + shrsi(shifted, a, m) + res = shifted + sint.conv(low) - (high << (length - m)) + else: + a_prime = Mod2mRing(None, a, k, m, signed) + a -= a_prime + res = TruncLeakyInRing(a, k, m, signed) if d is not None: movs(d, res) return res @@ -281,7 +306,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ program.curr_tape.require_bit_length(k + kappa) from .types import sint - if program.use_edabit() and m > 1: + if program.use_edabit() and m > 1 and not const_rounds: movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) tmp, b[:] = sint.get_edabit(m, True) movs(r_prime, tmp) @@ -290,7 +315,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): t[0][1] = b[-1] PRandInt(r_dprime, k + kappa - m) # r_dprime is always multiplied by 2^m - if use_dabit and program.use_dabit and m > 1: + if use_dabit and program.use_dabit and m > 1 and 
not const_rounds: r, b[:] = zip(*(sint.get_dabit() for i in range(m))) r = sint.bit_compose(r) movs(r_prime, r) @@ -383,7 +408,7 @@ def BitLTC1(u, a, b, kappa): Mod2(u, t[4][k-1], k, kappa, False) return p, a_bits, d, s, t, c, b, pre_input -def carry(b, a, compute_p): +def carry(b, a, compute_p=True): """ Carry propogation: return (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1)) """ diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 9247abff..1c91998f 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -9,14 +9,14 @@ import time import sys -def run(args, options, param=-1, merge_opens=True, +def run(args, options, merge_opens=True, reallocate=True, debug=False): """ Compile a file and output a Program object. If merge_opens is set to True, will attempt to merge any parallelisable open instructions. """ - prog = Program(args, options, param) + prog = Program(args, options) instructions.program = prog instructions_base.program = prog types.program = prog @@ -24,7 +24,7 @@ def run(args, options, param=-1, merge_opens=True, prog.DEBUG = debug VARS['program'] = prog if options.binary: - VARS['sint'] = GC_types.sbitint.get_type(int(options.binary)) + VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) VARS['sfix'] = GC_types.sbitfix comparison.set_variant(options) diff --git a/Compiler/config.py b/Compiler/config.py index 267c6476..9297a7e7 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -6,6 +6,7 @@ USER_MEM = 8192 P_VALUES = { 32: 2147565569, \ 64: 9223372036855103489, \ 128: 170141183460469231731687303715885907969, \ + 192: 3138550867693340381917894711603833208051177722232017256453, \ 256: 57896044618658097711785492504343953926634992332820282019728792003956566065153, \ 512: 6703903964971298549787012499102923063739682910296196688861780721860882015036773488400937149083451713845015929093243025426876941405973284973216824503566337 } @@ -19,7 +20,7 @@ BIT_LENGTHS = { -1: 64, 512: 64 } -COST = 
defaultdict(lambda: defaultdict(lambda: 0), +COST = defaultdict(lambda: defaultdict(lambda: 0), { 'modp': defaultdict(lambda: 0, { 'triple': 0.00020652622883106154, 'square': 0.00020652622883106154, diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index dad6a493..80b04652 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -104,11 +104,15 @@ def PreORC(a, kappa=None, m=None, raw=False): k = len(a) if k == 1: return [a[0]] + prog = program.Program.prog + kappa = kappa or prog.security m = m or k if isinstance(a[0], types.sgf2n): max_k = program.Program.prog.galois_length - 1 else: - max_k = int(log(program.Program.prog.P) / log(2)) - kappa + # assume prime length is power of two + prime_length = 2 ** int(ceil(log(prog.bit_length + kappa, 2))) + max_k = prime_length - kappa - 2 assert(max_k > 0) if k <= max_k: p = [None] * m @@ -132,7 +136,8 @@ def PreORC(a, kappa=None, m=None, raw=False): # not constant-round anymore s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)] t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw) - return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0]) + return sum(([or_op(x, y) for x in si] + for si,y in zip(s[1:],t)), s[0])[-m:] def PreOpL(op, items): """ @@ -549,9 +554,10 @@ def TruncPrRing(a, k, m, signed=True): trunc_pr(res, a, k, m) else: # extra bit to mask overflow - if program.Program.prog.use_edabit(): - lower = sint.get_edabit(m, True)[0] - upper = sint.get_edabit(k - m, True)[0] + prog = program.Program.prog + if prog.use_edabit() or prog.use_split() == 3: + lower = sint.get_random_int(m) + upper = sint.get_random_int(k - m) msb = sint.get_random_bit() r = (msb << k) + (upper << m) + lower else: diff --git a/Compiler/instructions.py b/Compiler/instructions.py index ac4f8550..99662dbf 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -91,64 +91,74 @@ class stmint(base.DirectMemoryWriteInstruction): # must have seperate instructions 
because address is always modp @base.vectorize -class ldmci(base.ReadMemoryInstruction): +class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """ code = base.opcodes['LDMCI'] arg_format = ['cw','ci'] + direct = staticmethod(ldmc) @base.vectorize -class ldmsi(base.ReadMemoryInstruction): +class ldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """ code = base.opcodes['LDMSI'] arg_format = ['sw','ci'] + direct = staticmethod(ldms) @base.vectorize -class stmci(base.WriteMemoryInstruction): +class stmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction): r""" Sets \verb+C[cj]+ to be the value $c_i$. """ code = base.opcodes['STMCI'] arg_format = ['c','ci'] + direct = staticmethod(stmc) @base.vectorize -class stmsi(base.WriteMemoryInstruction): +class stmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction): r""" Sets \verb+S[cj]+ to be the value $s_i$. """ code = base.opcodes['STMSI'] arg_format = ['s','ci'] + direct = staticmethod(stms) @base.vectorize -class ldminti(base.ReadMemoryInstruction): +class ldminti(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """ code = base.opcodes['LDMINTI'] arg_format = ['ciw','ci'] + direct = staticmethod(ldmint) @base.vectorize -class stminti(base.WriteMemoryInstruction): +class stminti(base.WriteMemoryInstruction, base.IndirectMemoryInstruction): r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """ code = base.opcodes['STMINTI'] arg_format = ['ci','ci'] + direct = staticmethod(stmint) @base.vectorize -class gldmci(base.ReadMemoryInstruction): +class gldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. 
""" code = base.opcodes['LDMCI'] + 0x100 arg_format = ['cgw','ci'] + direct = staticmethod(gldmc) @base.vectorize -class gldmsi(base.ReadMemoryInstruction): +class gldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """ code = base.opcodes['LDMSI'] + 0x100 arg_format = ['sgw','ci'] + direct = staticmethod(gldms) @base.vectorize -class gstmci(base.WriteMemoryInstruction): +class gstmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction): r""" Sets \verb+C[cj]+ to be the value $c_i$. """ code = base.opcodes['STMCI'] + 0x100 arg_format = ['cg','ci'] + direct = staticmethod(gstmc) @base.vectorize -class gstmsi(base.WriteMemoryInstruction): +class gstmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction): r""" Sets \verb+S[cj]+ to be the value $s_i$. """ code = base.opcodes['STMSI'] + 0x100 arg_format = ['sg','ci'] + direct = staticmethod(gstms) @base.gf2n @base.vectorize @@ -268,7 +278,7 @@ class use_edabit(base.Instruction): class run_tape(base.Instruction): r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """ code = base.opcodes['RUN_TAPE'] - arg_format = ['int','int','int'] + arg_format = tools.cycle(['int','int','int']) class join_tape(base.Instruction): r""" Join thread $c_i$. """ @@ -661,6 +671,12 @@ class shrci(base.ClearShiftInstruction): code = base.opcodes['SHRCI'] op = '__rshift__' +@base.vectorize +class shrsi(base.ClearShiftInstruction): + r""" Secret bitwise shift right by immediate value. 
""" + __slots__ = [] + code = base.opcodes['SHRSI'] + arg_format = ['sw','s','i'] ### ### Data access instructions @@ -742,6 +758,14 @@ class sedabit(base.Instruction): def add_usage(self, req_node): req_node.increment(('sedabit', len(self.args) - 1), self.get_size()) +@base.vectorize +class randoms(base.Instruction): + """ Random share """ + __slots__ = [] + code = base.opcodes['RANDOMS'] + arg_format = ['sw','int'] + field_type = 'modp' + @base.gf2n @base.vectorize class square(base.DataInstruction): @@ -781,6 +805,17 @@ class inputmask(base.Instruction): req_node.increment((self.field_type, 'input', self.args[1]), \ self.get_size()) +@base.vectorize +class inputmaskreg(base.Instruction): + __slots__ = [] + code = base.opcodes['INPUTMASKREG'] + arg_format = ['sw', 'cw', 'ci'] + field_type = 'modp' + + def add_usage(self, req_node): + # player 0 as proxy + req_node.increment((self.field_type, 'input', 0), float('inf')) + @base.gf2n @base.vectorize class prep(base.Instruction): @@ -1165,21 +1200,25 @@ class ldint(base.Instruction): class addint(base.IntegerInstruction): __slots__ = [] code = base.opcodes['ADDINT'] + op = operator.add @base.vectorize class subint(base.IntegerInstruction): __slots__ = [] code = base.opcodes['SUBINT'] + op = operator.sub @base.vectorize class mulint(base.IntegerInstruction): __slots__ = [] code = base.opcodes['MULINT'] + op = operator.mul @base.vectorize class divint(base.IntegerInstruction): __slots__ = [] code = base.opcodes['DIVINT'] + op = operator.floordiv @base.vectorize class bitdecint(base.Instruction): @@ -1228,18 +1267,21 @@ class ltc(base.IntegerInstruction): r""" Clear comparison $c_i = (c_j \stackrel{?}{<} c_k)$. """ __slots__ = [] code = base.opcodes['LTC'] + op = operator.lt @base.vectorize class gtc(base.IntegerInstruction): r""" Clear comparison $c_i = (c_j \stackrel{?}{>} c_k)$. 
""" __slots__ = [] code = base.opcodes['GTC'] + op = operator.gt @base.vectorize class eqc(base.IntegerInstruction): r""" Clear comparison $c_i = (c_j \stackrel{?}{==} c_k)$. """ __slots__ = [] code = base.opcodes['EQC'] + op = operator.eq ### diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 0b7f4ef9..0f4c448a 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -106,10 +106,12 @@ opcodes = dict( GBITTRIPLE = 0x154, GBITGF2NTRIPLE = 0x155, INPUTMASK = 0x56, + INPUTMASKREG = 0x5C, PREP = 0x57, DABIT = 0x58, EDABIT = 0x59, SEDABIT = 0x5A, + RANDOMS = 0x5B, # Input INPUT = 0x60, INPUTFIX = 0xF0, @@ -143,6 +145,7 @@ opcodes = dict( SHRC = 0x81, SHLCI = 0x82, SHRCI = 0x83, + SHRSI = 0x84, # Branching and comparison JMP = 0x90, JMPNZ = 0x91, @@ -446,10 +449,11 @@ def cisc(function): assert int(program.options.max_parallel_open) == 0, \ 'merging restriction not compatible with ' \ 'mergeable CISC instructions' - merger.longest_paths_merge() + n_rounds = merger.longest_paths_merge() filtered = filter(lambda x: x is not None, block.instructions) - self.instructions[self.merge_id()] = list(filtered), args - template, args = self.instructions[self.merge_id()] + self.instructions[self.merge_id()] = list(filtered), args, \ + n_rounds + template, args, self.n_rounds = self.instructions[self.merge_id()] subs = util.dict_by_id() for arg, reg in zip(args, regs): subs[arg] = reg @@ -486,6 +490,9 @@ def cisc(function): base += reg.size return block.instructions + def expanded_rounds(self): + return self.n_rounds - 1 + MergeCISC.__name__ = function.__name__ def wrapper(*args, **kwargs): @@ -649,7 +656,7 @@ class String(ArgFormat): @classmethod def encode(cls, arg): - return arg + '\0' * (cls.length - len(arg)) + return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg)) ArgFormats = { 'c': ClearModpAF, @@ -683,6 +690,7 @@ class Instruction(object): """ __slots__ = ['args', 'arg_format', 'code', 'caller'] count = 0 + 
code_length = 10 def __init__(self, *args, **kwargs): """ Create an instruction and append it to the program list. """ @@ -701,7 +709,7 @@ class Instruction(object): print("Compiled %d lines at" % self.__class__.count, time.asctime()) def get_code(self, prefix=0): - return (prefix << 10) + self.code + return (prefix << self.code_length) + self.code def get_encoding(self): enc = int_to_bytes(self.get_code()) @@ -713,7 +721,10 @@ class Instruction(object): return enc def get_bytes(self): - return bytearray(self.get_encoding()) + try: + return bytearray(self.get_encoding()) + except TypeError: + raise CompilerError('cannot encode %s/%s' % (self, self.get_encoding())) def check_args(self): """ Check the args match up with that specified in arg_format """ @@ -787,6 +798,9 @@ class Instruction(object): def expand_merged(self): return [self] + def expanded_rounds(self): + return 0 + def get_new_args(self, size, subs): new_args = [] for arg, f in zip(self.args, self.arg_format): @@ -864,6 +878,12 @@ class DirectMemoryInstruction(Instruction): def __init__(self, *args, **kwargs): super(DirectMemoryInstruction, self).__init__(*args, **kwargs) +class IndirectMemoryInstruction(Instruction): + __slots__ = [] + + def get_direct(self, address): + return self.direct(self.args[0], address, add_to_prog=False) + class ReadMemoryInstruction(Instruction): __slots__ = [] @@ -875,7 +895,7 @@ class DirectMemoryWriteInstruction(DirectMemoryInstruction, \ __slots__ = [] def __init__(self, *args, **kwargs): if program.curr_tape.prevent_direct_memory_write: - raise CompilerError('Direct memory writing prevented') + raise CompilerError('Direct memory writing prevented in threads') super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs) ### diff --git a/Compiler/library.py b/Compiler/library.py index 7f055091..02174979 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -6,6 +6,7 @@ in particularly providing flow control and output. 
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc from Compiler.instructions import * from Compiler.util import tuplify,untuplify,is_zero +from Compiler.allocator import RegintOptimizer from Compiler import instructions,instructions_base,comparison,program,util import inspect,math import random @@ -54,6 +55,9 @@ def set_instruction_type(function): return instruction_typed_function +def _expand_to_print(val): + return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val) + def print_str(s, *args): """ Print a string, with optional args for adding variables/registers with ``%s``. """ @@ -90,7 +94,7 @@ def print_str(s, *args): elif isinstance(val, cfloat): val.print_float_plain() elif isinstance(val, (list, tuple, Array)): - print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val) + print_str(*_expand_to_print(val)) else: try: val.output() @@ -127,6 +131,9 @@ def print_ln_if(cond, ss, *args): print_ln_if(get_player_id() == 0, 'Player 0 here') """ + print_str_if(cond, ss + '\n', *args) + +def print_str_if(cond, ss, *args): if util.is_constant(cond): if cond: print_ln(ss, *args) @@ -138,9 +145,14 @@ def print_ln_if(cond, ss, *args): cond = cint.conv(cond) for i, s in enumerate(subs): if i != 0: - args[i - 1].output_if(cond) - if i == len(args): - s += '\n' + val = args[i - 1] + try: + val.output_if(cond) + except: + if isinstance(val, (list, tuple, Array)): + print_str_if(cond, *_expand_to_print(val)) + else: + print_str_if(cond, str(val)) s += '\0' * ((-len(s)) % 4) while s: cond.print_if(s[:4]) @@ -162,7 +174,15 @@ def print_ln_to(player, ss, *args): new_args = [] for arg in args: if isinstance(arg, personal): - assert arg.player == player + if util.is_constant(arg.player) ^ util.is_constant(player): + match = False + else: + if util.is_constant(player): + match = arg.player == player + else: + match = 
id(arg.player) == id(player) + if not match: + raise CompilerError('player mismatch in personal printing') new_args.append(arg._v) else: new_args.append(arg) @@ -968,6 +988,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], del blocks[-n_to_merge + 1:] del get_tape().req_node.children[-1] merged.children = [] + RegintOptimizer().run(merged.instructions) get_tape().active_basicblock = merged else: req_node = get_tape().req_node.children[-1].nodes[0] @@ -1117,7 +1138,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ else: return loop_body(base + i) prog = get_program() - threads = [] + thread_args = [] if not util.is_zero(thread_rounds): tape = prog.new_tape(f, (0,), 'multithread') for i in range(n_threads - remainder): @@ -1125,7 +1146,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ args[remainder + i][0] = i * thread_rounds if len(mem_state): args[remainder + i][1] = mem_state.address - threads.append(prog.run_tape(tape, remainder + i)) + thread_args.append((tape, remainder + i)) if remainder: tape1 = prog.new_tape(f, (1,), 'multithread1') for i in range(remainder): @@ -1133,7 +1154,8 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ args[i][0] = (n_threads - remainder + i) * thread_rounds + i if len(mem_state): args[i][1] = mem_state.address - threads.append(prog.run_tape(tape1, i)) + thread_args.append((tape1, i)) + threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) if state: @@ -1625,6 +1647,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): y = y.round(2*k, 2*f, kappa, nearest, signed=True) x = x.round(2*k, 2*f, kappa, nearest, signed=True) + x = x.extend(2 * k) y = y.extend(2 * k) * (alpha + x).extend(2 * k) y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) return y @@ -1659,7 +1682,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): #next 2 lines actually compute the SufOR for little 
indian encoding bits = absolute_val.bit_decompose(k, kappa)[::-1] - suffixes = PreOR(bits)[::-1] + suffixes = PreOR(bits, kappa)[::-1] z = [0] * k for i in range(k - 1): diff --git a/Compiler/ml.py b/Compiler/ml.py index 561e613d..2060d22e 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -112,6 +112,12 @@ def lse_0_from_e_x(x, e_x): def lse_0(x): return lse_0_from_e_x(x, exp(x)) +def approx_lse_0(x, n=3): + assert n != 5 + a = x < -0.5 + b = x > 0.5 + return a.if_else(0, b.if_else(x, 0.5 * (x + 0.5) ** 2)) - x + def relu_prime(x): """ ReLU derivative. """ return (0 <= x) @@ -145,10 +151,19 @@ class Tensor(MultiArray): kwargs['alloc'] = False super(Tensor, self).__init__(*args, **kwargs) + def input_from(self, *args, **kwargs): + self.alloc() + super(Tensor, self).input_from(*args, **kwargs) + + def __getitem__(self, *args): + self.alloc() + return super(Tensor, self).__getitem__(*args) + class Layer: n_threads = 1 inputs = [] input_bias = True + thetas = lambda self: () @property def shape(self): @@ -193,14 +208,13 @@ class Output(Layer): self.approx = approx nablas = lambda self: () - thetas = lambda self: () reset = lambda self: None def divisor(self, divisor, size): return cfix(1.0 / divisor, size=size) def forward(self, batch): - if self.approx: + if self.approx == 5: self.l.write(999) return N = len(batch) @@ -209,10 +223,12 @@ class Output(Layer): def _(base, size): x = self.X.get_vector(base, size) y = self.Y.get(batch.get_vector(base, size)) + if self.approx: + lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base) + return e_x = exp(-x) self.e_x.assign(e_x, base) lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base) - e_x = self.e_x.get_vector(0, N) self.l.write(sum(lse) * \ self.divisor(N, 1)) @@ -323,7 +339,7 @@ class Dense(DenseBase): self.X = MultiArray([N, d, d_in], sfix) self.Y = MultiArray([N, d, d_out], sfix) - self.W = sfix.Matrix(d_in, d_out) + self.W = Tensor([d_in, d_out], sfix) self.b = sfix.Array(d_out) self.nabla_Y = 
MultiArray([N, d, d_out], sfix) @@ -546,12 +562,12 @@ class MaxPool(NoVariableLayer): for x in strides, ksize: for i in 0, 3: assert x[i] == 1 - self.X = MultiArray(shape, sfix) + self.X = Tensor(shape, sfix) if padding == 'SAME': output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)] else: output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)] - self.Y = MultiArray(output_shape, sfix) + self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize @@ -741,6 +757,7 @@ class ConvBase(BaseLayer): use_conv2ds = False temp_weights = None temp_inputs = None + thetas = lambda self: (self.weights, self.bias) @classmethod def init_temp(cls, layers): @@ -780,7 +797,7 @@ class ConvBase(BaseLayer): self.weight_squant = self.new_squant() self.bias_squant = self.new_squant() - self.weights = MultiArray(weight_shape, self.weight_squant) + self.weights = Tensor(weight_shape, self.weight_squant) self.bias = Array(output_shape[-1], self.bias_squant) self.unreduced = Tensor(self.output_shape, sint) @@ -1158,7 +1175,8 @@ class Optimizer: layer.last_used = list(filter(lambda x: x not in used, layer.inputs)) used.update(layer.inputs) - def forward(self, N=None, batch=None, keep_intermediate=True): + def forward(self, N=None, batch=None, keep_intermediate=True, + model_from=None): """ Compute graph. :param N: batch size (used if batch not given) @@ -1172,12 +1190,16 @@ class Optimizer: if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None: layer._X.address = layer.inputs[0].Y.address layer.Y.alloc() + if model_from is not None: + layer.input_from(model_from) break_point() layer.forward(batch=batch) break_point() if not keep_intermediate: for l in layer.last_used: l.Y.delete() + for theta in layer.thetas(): + theta.delete() def eval(self, data): """ Compute evaluation after training. 
""" @@ -1207,6 +1229,7 @@ class Optimizer: else: N = self.layers[0].N i = MemValue(0) + n_iterations = MemValue(0) @do_while def _(): if self.X_by_label is None: @@ -1216,6 +1239,7 @@ class Optimizer: n = N // len(self.X_by_label) n_per_epoch = int(math.ceil(1. * max(len(X) for X in self.X_by_label) / n)) + n_iterations.iadd(n_per_epoch) print('%d runs per epoch' % n_per_epoch) indices_by_label = [] for label, X in enumerate(self.X_by_label): @@ -1235,7 +1259,7 @@ class Optimizer: self.backward(batch=batch) self.update(i) loss = self.layers[-1].l - if self.report_loss and not self.layers[-1].approx: + if self.report_loss and self.layers[-1].approx != 5: print_ln('loss after epoch %s: %s', i, loss.reveal()) else: print_ln('done with epoch %s', i) @@ -1245,7 +1269,7 @@ class Optimizer: if self.tol > 0: res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() return res - print_ln('finished after %s epochs', i) + print_ln('finished after %s epochs and %s iterations', i, n_iterations) class Adam(Optimizer): def __init__(self, layers, n_epochs): diff --git a/Compiler/program.py b/Compiler/program.py index eec539af..326dc796 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -3,6 +3,7 @@ from Compiler.exceptions import * from Compiler.instructions_base import RegType import Compiler.instructions import Compiler.instructions_base +import Compiler.instructions_base as inst_base from . import compilerLib from . import allocator as al from . import util @@ -40,22 +41,20 @@ class Program(object): These are created by executing a file containing appropriate instructions and threads. 
""" - def __init__(self, args, options, param=-1): + def __init__(self, args, options): self.options = options self.verbose = options.verbose self.args = args self.init_names(args) - self.P = P_VALUES[param] - self.param = param - if (param != -1) + sum(x != 0 for x in(options.ring, options.field, + if sum(x != 0 for x in(options.ring, options.field, options.binary)) > 1: - raise CompilerError('can only use one out of -p, -B, -R, -F') + raise CompilerError('can only use one out of -B, -R, -F') if options.ring: self.bit_length = int(options.ring) - 1 else: self.bit_length = int(options.binary) or int(options.field) if not self.bit_length: - self.bit_length = BIT_LENGTHS[param] + self.bit_length = 64 print('Default bit length:', self.bit_length) self.security = 40 print('Default security parameter:', self.security) @@ -67,7 +66,7 @@ class Program(object): self._curr_tape = None self.DEBUG = False self.allocated_mem = RegType.create_dict(lambda: USER_MEM) - self.free_mem_blocks = defaultdict(lambda: defaultdict(set)) + self.free_mem_blocks = defaultdict(al.BlockAllocator) self.allocated_mem_blocks = {} self.saved = 0 self.req_num = None @@ -100,6 +99,7 @@ class Program(object): self._edabit = options.edabit self._split = False self._square = False + self._always_raw = False Program.prog = self def get_args(self): @@ -165,20 +165,28 @@ class Program(object): return tape_index def run_tape(self, tape_index, arg): + return self.run_tapes([[tape_index, arg]])[0] + + def run_tapes(self, args): if self.curr_tape is not self.tapes[0]: raise CompilerError('Compiler does not support ' \ 'recursive spawning of threads') - if self.free_threads: - thread_number = min(self.free_threads) - self.free_threads.remove(thread_number) - else: - thread_number = self.n_threads - self.n_threads += 1 + thread_numbers = [] + while len(thread_numbers) < len(args): + if self.free_threads: + thread_numbers.append(min(self.free_threads)) + self.free_threads.remove(thread_numbers[-1]) + else: + 
thread_numbers.append(self.n_threads) + self.n_threads += 1 self.curr_tape.start_new_basicblock(name='pre-run_tape') - Compiler.instructions.run_tape(thread_number, arg, tape_index) + Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in + zip(thread_numbers, args)), [])) self.curr_tape.start_new_basicblock(name='post-run_tape') - self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree) - return thread_number + for arg in args: + self.curr_tape.req_node.children.append( + self.tapes[arg[0]].req_tree) + return thread_numbers def join_tape(self, thread_number): self.curr_tape.start_new_basicblock(name='pre-join_tape') @@ -250,19 +258,9 @@ class Program(object): mem_type = mem_type.reg_type elif reg_type is not None: self.types[mem_type] = reg_type - block_size = 0 blocks = self.free_mem_blocks[mem_type] - if len(blocks[size]) > 0: - block_size = size - else: - for block_size, addresses in blocks.items(): - if block_size >= size and len(addresses) > 0: - break - else: - block_size = 0 - if block_size >= size: - addr = self.free_mem_blocks[mem_type][block_size].pop() - self.free_mem_blocks[mem_type][block_size - size].add(addr + size) + addr = blocks.pop(size) + if addr is not None: self.saved += size else: addr = self.allocated_mem[mem_type] @@ -278,7 +276,7 @@ class Program(object): is not self.curr_tape.basicblocks[0].alloc_pool: raise CompilerError('Cannot free memory within function block') size = self.allocated_mem_blocks.pop((addr,mem_type)) - self.free_mem_blocks[mem_type][size].add(addr) + self.free_mem_blocks[mem_type].push(addr, size) def finalize_memory(self): from . 
import library @@ -341,6 +339,20 @@ class Program(object): else: self._square = change + def always_raw(self, change=None): + if change is None: + return self._always_raw + else: + self._always_raw = change + + def options_from_args(self): + if 'trunc_pr' in self.args: + self.use_trunc_pr = True + if 'split' in self.args or 'split3' in self.args: + self.use_split(3) + if 'raw' in self.args: + self.always_raw(True) + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -445,13 +457,14 @@ class Tape: instructions = self.instructions for inst in instructions: inst.add_usage(req_node) - req_node.num['all', 'round'] = self.n_rounds - req_node.num['all', 'inv'] = self.n_to_merge + req_node.num['all', 'round'] += self.n_rounds + req_node.num['all', 'inv'] += self.n_to_merge def expand_cisc(self): new_instructions = [] for inst in self.instructions: new_instructions.extend(inst.expand_merged()) + self.n_rounds += inst.expanded_rounds() self.instructions = new_instructions def __str__(self): @@ -503,9 +516,6 @@ class Tape: def unpurged(function): def wrapper(self, *args, **kwargs): if self.purged: - if self.program.verbose: - print('%s called on purged block %s, ignoring' % \ - (function.__name__, self.name)) return return function(self, *args, **kwargs) return wrapper @@ -585,7 +595,8 @@ class Tape: reg_counts = self.count_regs() if not options.noreallocate: if self.program.verbose: - print('Tape register usage:', dict(reg_counts)) + print('Tape register usage before re-allocation:', + dict(reg_counts)) print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) print('Re-allocating...') @@ -613,6 +624,8 @@ class Tape: alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) allocator.finalize(options) + if 
self.program.verbose: + print('Tape register usage:', dict(allocator.usage)) # offline data requirements if self.program.verbose: @@ -847,10 +860,6 @@ class Tape: if t == 'p': self.req_bit_length[t] = max(bit_length + 1, \ self.req_bit_length[t]) - if self.program.param != -1 and bit_length >= self.program.param: - raise CompilerError('Inadequate bit length %d for prime, ' \ - 'program requires %d bits' % \ - (self.program.param, self.req_bit_length['p'])) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @@ -862,6 +871,7 @@ class Tape: __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ "size", "vector", "vectorbase", "caller", \ "can_eliminate", "duplicates"] + maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): """ Creates a new register. @@ -875,6 +885,8 @@ class Tape: self.program = program if size is None: size = Compiler.instructions_base.get_global_vector_size() + if size is not None and size > self.maximum_size: + raise CompilerError('vector too large') self.size = size self.vectorbase = self self.relative_i = 0 diff --git a/Compiler/types.py b/Compiler/types.py index 813836f0..2df15abd 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -147,7 +147,8 @@ def vectorize(operation): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise CompilerError('Different vector sizes of operands') + raise CompilerError('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) res = operation(self, *args, **kwargs) reset_global_vector_size() @@ -789,7 +790,7 @@ class cint(_clear, _int): bit_length = 1 + int(math.ceil(math.log(abs(val)))) if program.options.ring: assert(bit_length <= int(program.options.ring)) - elif program.param != -1 or program.options.field: + elif program.options.field: 
program.curr_tape.require_bit_length(bit_length) if self.in_immediate_range(val): ldi(self, val) @@ -1731,6 +1732,15 @@ class sint(_secret, _int): """ Secret random n-bit number according to security model. :param bits: compile-time integer (int) """ + if program.use_split() == 3: + tmp = sint() + randoms(tmp, bits) + x = tmp.split_to_two_summands(bits, True) + overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \ + sint.conv(x[0][-1]) + return tmp - (overflow << bits) + elif program.use_edabit(): + return sint.get_edabit(bits, True)[0] res = sint() comparison.PRandInt(res, bits) return res @@ -2068,6 +2078,37 @@ class sint(_secret, _int): def two_power(n): return floatingpoint.two_power(n) + def split_to_n_summands(self, length, n): + from .GC.types import sbits + from .GC.instructions import split + columns = [[sbits.get_type(self.size)() + for i in range(n)] for i in range(length)] + split(n, self, *sum(columns, [])) + return columns + + def split_to_two_summands(self, length, get_carry=False): + n = program.use_split() + assert n + columns = self.split_to_n_summands(length, n) + return _bitint.wallace_tree_without_finish(columns, get_carry) + + @vectorize + def reveal_to(self, player): + """ Reveal secret value to :py:obj:`player`. + Result potentially written to ``Player-Data/Private-Output-P.`` + + :param player: public integer (int/regint/cint): + :returns: value to be used with :py:func:`Compiler.library.print_ln_to` + """ + if not util.is_constant(player) or self.size > 1: + secret_mask = sint() + player_mask = cint() + inputmaskreg(secret_mask, player_mask, player) + return personal(player, + (self + secret_mask).reveal() - player_mask) + else: + return super(sint, self).reveal_to(player) + class sgf2n(_secret, _gf2n): """ Secret GF(2^n) value. 
""" __slots__ = [] @@ -2255,7 +2296,8 @@ class _bitint(object): def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, get_carry=False): lower = [] - for (ai,bi) in zip(a,b): + a, b = a[:], b[:] + for (ai, bi) in zip(a[:], b[:]): if is_zero(ai) or is_zero(bi): lower.append(ai + bi) a.pop(0) @@ -2404,6 +2446,7 @@ class _bitint(object): @classmethod def wallace_tree_without_finish(cls, columns, get_carry=True): self = cls + columns = [col[:] for col in columns] while max(len(c) for c in columns) > 2: new_columns = [[] for i in range(len(columns) + 1)] for i,col in enumerate(columns): @@ -2439,11 +2482,12 @@ class _bitint(object): raise CompilerError('Unclear subtraction') a = self.bit_decompose() b = util.bit_decompose(other, self.n_bits) - d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)] + d = [(reduce(util.bit_xor, (ai, bi, 1)), (1 - ai) * bi) + for (ai,bi) in zip(a,b)] borrow = lambda y,x,*args: \ (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1] - return self.compose(ai + bi + borrow \ + return self.compose(reduce(util.bit_xor, (ai, bi, borrow)) \ for (ai,bi,borrow) in zip(a,b,borrows)) def __rsub__(self, other): @@ -2715,10 +2759,15 @@ class cfix(_number, _structure): def set_precision(cls, f, k = None): """ Set the precision of the integer representation. Note that some operations are undefined when the precision of :py:class:`sfix` and - :py:class:`cfix` differs. + :py:class:`cfix` differs. The initial defaults are chosen to + allow the best optimization of probabilistic truncation in + computation modulo 2^64 (2*k < 64). Generally, 2*k must be at + most the integer length for rings and at most m-s-1 for + computation modulo an m-bit prime and statistical security s + (default 40). - :param f: bit length of decimal part - :param k: whole bit length of fixed point, defaults to twice :py:obj:`f`. 
+ :param f: bit length of decimal part (initial default 16) + :param k: whole bit length of fixed point, defaults to twice :py:obj:`f` if not given (initial default 31) """ cls.f = f @@ -2789,6 +2838,10 @@ class cfix(_number, _structure): else: raise CompilerError('cannot initialize cfix with %s' % v) + def __iter__(self): + for x in self.v: + yield type(self)(x, self.k, self.f) + @vectorize def load_int(self, v): self.v = cint(v) * (2 ** self.f) @@ -3141,7 +3194,6 @@ class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k', 'size'] - @classmethod def set_precision(cls, f, k = None): cls.f = f # default bitlength = 2*precision @@ -3152,6 +3204,7 @@ class _fix(_single): raise CompilerError('bit length cannot be less than precision') cls.k = k set_precision.__doc__ = cfix.set_precision.__doc__ + set_precision = classmethod(set_precision) @classmethod def coerce(cls, other): @@ -3377,9 +3430,10 @@ class sfix(_fix): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. 
- Raw representation written to ``Player-Data/Private-Output-P`` + Raw representation possibly written to + ``Player-Data/Private-Output-P.`` - :param player: int + :param player: public integer (int/regint/cint) :returns: value to be used with :py:func:`Compiler.library.print_ln_to` """ return personal(player, cfix(self.v.reveal_to(player)._v, @@ -3419,17 +3473,8 @@ class unreduced_sfix(_single): sfix.unreduced_type = unreduced_sfix -# this is for 20 bit decimal precision -# with 40 bitlength of entire number -# these constants have been chosen for multiplications to fit in 128 bit prime field -# (precision n1) 41 + (precision n2) 41 + (stat_sec) 40 = 82 + 40 = 122 <= 128 -# with statistical security of 40 - -fixed_lower = 20 -fixed_upper = 40 - -sfix.set_precision(fixed_lower, fixed_upper) -cfix.set_precision(fixed_lower, fixed_upper) +sfix.set_precision(16, 31) +cfix.set_precision(16, 31) class squant(_single): """ Quantization as in ArXiv:1712.05877v1 """ @@ -4222,8 +4267,10 @@ class Array(object): :param value: convertible to basic type """ if conv: value = self.value_type.conv(value) + if value.size != 1: + raise CompilerError('cannot assign vector to all elements') mem_value = MemValue(value) - self.address = MemValue(self.address) + self.address = MemValue.if_necessary(self.address) n_threads = 8 if use_threads and len(self) > 2**20 else 1 @library.for_range_multithread(n_threads, 1024, len(self)) def f(i): @@ -4242,7 +4289,7 @@ class Array(object): def get(self, indices): return self.value_type.load_mem( - regint(self.address, size=len(indices)) + indices, + regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) def expand_to_vector(self, index, size): @@ -4258,13 +4305,13 @@ class Array(object): """ Fill with inputs from player if supported by type. 
:param player: public (regint/cint/int) """ - if raw: + if raw or program.always_raw(): input_from = self.value_type.get_raw_input_from else: input_from = self.value_type.get_input_from try: self.assign(input_from(player, size=len(self))) - except: + except TypeError: @library.for_range_opt(len(self), budget=budget) def _(i): self[i] = input_from(player) @@ -4318,6 +4365,12 @@ class Array(object): :returns: Array of relevant clear type. """ return Array.create_from(x.reveal() for x in self) + def reveal_list(self): + """ Reveal as list. """ + return list(self.get_vector().reveal()) + + reveal_nested = reveal_list + sint.dynamic_array = Array sgf2n.dynamic_array = Array @@ -4454,16 +4507,17 @@ class SubMultiArray(object): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ - budget = budget or 2 ** 21 + budget = budget or Tape.Register.maximum_size if (self.total_size() < budget) and \ self.value_type.n_elements() == 1: - if raw: + if raw or program.always_raw(): input_from = self.value_type.get_raw_input_from else: input_from = self.value_type.get_input_from self.assign_vector(input_from(player, size=self.total_size())) else: - @library.for_range_opt(self.sizes[0], budget=budget) + @library.for_range_opt(self.sizes[0], + budget=budget / self[0].total_size()) def _(i): self[i].input_from(player, budget=budget, raw=raw) @@ -4687,6 +4741,21 @@ class SubMultiArray(object): library.break_point() return res + def reveal_list(self): + """ Reveal as list. """ + return list(self.get_vector().reveal()) + + def reveal_nested(self): + """ Reveal as nested list. """ + flat = iter(self.get_vector().reveal()) + res = [] + def f(sizes): + if len(sizes) == 1: + return [next(flat) for i in range(sizes[0])] + else: + return [f(sizes[1:]) for i in range(sizes[0])] + return f(self.sizes) + class MultiArray(SubMultiArray): """ Multidimensional array. 
""" def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): @@ -4836,10 +4905,12 @@ class MemValue(_mem): self.value_type = type(value) self.deleted = False if address is None: - self.address = self.value_type.malloc(1) + self.address = self.value_type.malloc(value.size) + self.size = value.size self.write(value) else: self.address = address + self.size = 1 def delete(self): self.value_type.free(self.address) @@ -4869,6 +4940,8 @@ class MemValue(_mem): elif isinstance(value, int): self.register = self.value_type(value) else: + if value.size != self.size: + raise CompilerError('size mismatch') self.register = value if not isinstance(self.register, self.value_type): raise CompilerError('Mismatch in register type, cannot write \ diff --git a/Compiler/util.py b/Compiler/util.py index 6bcce91c..d7109cd3 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -46,8 +46,10 @@ def right_shift(a, b, bits): else: return a.right_shift(b, bits) -def bit_decompose(a, bits): +def bit_decompose(a, bits=None): if isinstance(a, int): + if bits is None: + bits = int_len(a) return [int((a >> i) & 1) for i in range(bits)] else: return a.bit_decompose(bits) @@ -122,6 +124,15 @@ def or_op(a, b): OR = or_op +def bit_xor(a, b): + if is_constant(a): + if is_constant(b): + return a ^ b + else: + return b.bit_xor(a) + else: + return a.bit_xor(b) + def pow2(bits): powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)] return tree_reduce(operator.mul, powers) diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp index 11f01053..4fb9e7e3 100644 --- a/FHEOffline/Proof.cpp +++ b/FHEOffline/Proof.cpp @@ -92,8 +92,8 @@ bool Proof::check_bounds(T& z, X& t, int i) const unsigned int j,k; // Check Bound 1 and Bound 2 - AbsoluteBoundChecker> plain_checker(plain_check * n_proofs); - AbsoluteBoundChecker> rand_checker(rand_check * n_proofs); + AbsoluteBoundChecker> plain_checker(plain_check * n_proofs); + AbsoluteBoundChecker> rand_checker(rand_check * n_proofs); for (j=0; 
j::create_more() this->c.unpack(ciphertexts, this->pk); commitments.unpack(ciphertexts, this->pk); +#ifdef VERBOSE_HE cout << "Tree-wise sum of ciphertexts with " << 1e-9 * ciphertexts.get_length() << " GB" << endl; +#endif this->timers["Exchanging ciphertexts"].start(); tree_sum.run(this->c, P); tree_sum.run(commitments, P); diff --git a/GC/ArgTuples.h b/GC/ArgTuples.h index 59904bfb..5a4ec6a7 100644 --- a/GC/ArgTuples.h +++ b/GC/ArgTuples.h @@ -13,9 +13,12 @@ template class ArgIter { vector::const_iterator it; + vector::const_iterator end; public: - ArgIter(const vector::const_iterator it) : it(it) + ArgIter(const vector::const_iterator it, + const vector::const_iterator end) : + it(it), end(end) { } @@ -27,8 +30,10 @@ public: ArgIter operator++() { auto res = it; - it += T::n; - return res; + it += T(res).n; + if (it > end) + throw runtime_error("wrong number of args"); + return {res, end}; } bool operator!=(const ArgIter& other) @@ -46,18 +51,16 @@ public: ArgList(const vector& args) : args(args) { - if (args.size() % T::n != 0) - throw runtime_error("wrong number of args"); } ArgIter begin() { - return args.begin(); + return {args.begin(), args.end()}; } ArgIter end() { - return args.end(); + return {args.end(), args.end()}; } }; @@ -81,11 +84,12 @@ public: } }; -class InputArgList : public ArgList +template +class InputArgListBase : public ArgList { public: - InputArgList(const vector& args) : - ArgList(args) + InputArgListBase(const vector& args) : + ArgList(args) { } @@ -97,7 +101,55 @@ public: return res; } + int n_input_bits() + { + int res = 0; + for (auto x : *this) + res += x.n_bits; + return res; + } + int n_interactive_inputs_from_me(int my_num); }; +class InputArgList : public InputArgListBase +{ +public: + InputArgList(const vector& args) : + InputArgListBase(args) + { + } +}; + +class InputVecArgs +{ +public: + int from; + int n; + int& n_bits; + int& n_shift; + int params[2]; + vector dest; + + InputVecArgs(vector::const_iterator it) : 
n_bits(params[0]), n_shift(params[1]) + { + n = *it++; + n_bits = n - 3; + n_shift = *it++; + from = *it++; + dest.resize(n); + for (int i = 0; i < n_bits; i++) + dest[i] = *it++; + } +}; + +class InputVecArgList : public InputArgListBase +{ +public: + InputVecArgList(const vector& args) : + InputArgListBase(args) + { + } +}; + #endif /* GC_ARGTUPLES_H_ */ diff --git a/GC/FakeSecret.cpp b/GC/FakeSecret.cpp index 981ef22a..258324e1 100644 --- a/GC/FakeSecret.cpp +++ b/GC/FakeSecret.cpp @@ -12,9 +12,8 @@ namespace GC { -int FakeSecret::default_length = 128; - -ostream& FakeSecret::out = cout; +SwitchableOutput FakeSecret::out; +const int FakeSecret::default_length; void FakeSecret::load_clear(int n, const Integer& x) { @@ -60,12 +59,9 @@ void FakeSecret::store_clear_in_dynamic(Memory& mem, void FakeSecret::ands(Processor& processor, const vector& regs) { - processor.check_args(regs, 4); - for (size_t i = 0; i < regs.size(); i += 4) - processor.S[regs[i + 1]] = processor.S[regs[i + 2]].a & processor.S[regs[i + 3]].a; + processor.ands(regs); } - void FakeSecret::trans(Processor& processor, int n_outputs, const vector& args) { @@ -82,15 +78,13 @@ FakeSecret FakeSecret::input(GC::Processor& processor, const InputAr return input(args.from, processor.get_input(args.params), args.n_bits); } -FakeSecret FakeSecret::input(int from, const int128& input, int n_bits) +FakeSecret FakeSecret::input(int from, word input, int n_bits) { (void)from; (void)n_bits; - FakeSecret res; - res.a = ((__uint128_t)input.get_upper() << 64) + input.get_lower(); - if (res.a > ((__uint128_t)1 << n_bits)) + if (n_bits < 64 and input > word(1) << n_bits) throw out_of_range("input too large"); - return res; + return input; } void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y, @@ -99,7 +93,7 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y, if (repeat) return andrs(n, x, y); else - throw runtime_error("call static FakeSecret::ands()"); + *this = BitVec(x & 
y).mask(n); } } /* namespace GC */ diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index d7cb8872..19f3198e 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -12,8 +12,14 @@ #include "GC/ArgTuples.h" #include "Math/gf2nlong.h" +#include "Tools/SwitchableOutput.h" #include "Processor/DummyProtocol.h" +#include "Protocols/FakePrep.h" +#include "Protocols/FakeMC.h" +#include "Protocols/FakeProtocol.h" +#include "Protocols/FakeInput.h" +#include "Protocols/ShareInterface.h" #include #include @@ -26,21 +32,41 @@ class Processor; template class Machine; -class FakeSecret +class FakeSecret : public ShareInterface, public BitVec { - __uint128_t a; - public: typedef FakeSecret DynamicType; typedef Memory DynamicMemory; + typedef BitVec mac_key_type; + typedef BitVec clear; + typedef BitVec open_type; + + typedef FakeSecret part_type; + typedef FakeSecret small_type; + typedef NoShare bit_type; + + typedef FakePrep LivePrep; + typedef FakeMC MC; + typedef MC MAC_Check; + typedef MC Direct_MC; + typedef FakeProtocol Protocol; + typedef FakeInput Input; + static string type_string() { return "fake secret"; } static string phase_name() { return "Faking"; } - static int default_length; + static const int default_length = 64; - typedef ostream& out_type; - static ostream& out; + static const bool is_real = true; + + static const bool actual_inputs = true; + + static SwitchableOutput out; + + static DataFieldType field_type() { return DATA_GF2; } + + static MC* new_mc(mac_key_type key) { return new MC(key); } static void store_clear_in_dynamic(Memory& mem, const vector& accesses); @@ -59,6 +85,11 @@ public: static void inputb(T& processor, const vector& args) { processor.input(args); } template + static void inputb(T& processor, ArithmeticProcessor&, const vector& args) + { processor.input(args); } + template + static void inputbvec(T&, U&, const vector&) { throw not_implemented(); } + template static void reveal_inst(T& processor, const vector& args) { processor.reveal(args); 
} @@ -69,11 +100,13 @@ public: static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; } static FakeSecret input(GC::Processor& processor, const InputArgs& args); - static FakeSecret input(int from, const int128& input, int n_bits); + static FakeSecret input(int from, word input, int n_bits); - FakeSecret() : a(0) {} - FakeSecret(const Integer& x) : a(x.get()) {} - FakeSecret(__uint128_t x) : a(x) {} + static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; } + + FakeSecret() {} + template + FakeSecret(T other) : BitVec(other) {} __uint128_t operator>>(const FakeSecret& other) const { return a >> other.a; } __uint128_t operator<<(const FakeSecret& other) const { return a << other.a; } @@ -90,17 +123,19 @@ public: void bitdec(Memory& S, const vector& regs) const; template - void xor_(int n, const FakeSecret& x, const T& y) { (void)n; a = x.a ^ y.a; } + void xor_(int n, const FakeSecret& x, const T& y) + { *this = BitVec(x.a ^ y.a).mask(n); } void and_(int n, const FakeSecret& x, const FakeSecret& y, bool repeat); - void andrs(int n, const FakeSecret& x, const FakeSecret& y) { (void)n; a = x.a * y.a; } + void andrs(int n, const FakeSecret& x, const FakeSecret& y) + { *this = BitVec(x.a * (y.a & 1)).mask(n); } - void invert(int, const FakeSecret& x) { *this = ~x.a; } + void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); } void random_bit() { a = random() % 2; } void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; } - int size() { return -1; } + void invert(FakeSecret) { throw not_implemented(); } }; } /* namespace GC */ diff --git a/GC/Instruction.h b/GC/Instruction.h index 030fa2c8..a6b83129 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -65,6 +65,7 @@ enum STMSBI = 0x243, MOVSB = 0x244, INPUTB = 0x246, + INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, // write to clear diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index 4081546f..256eb4fa 100644 --- a/GC/Instruction.hpp 
+++ b/GC/Instruction.hpp @@ -67,7 +67,7 @@ unsigned Instruction::get_mem(RegType reg_type) const } break; default: - return BaseInstruction::get_mem(reg_type, MAX_SECRECY_TYPE); + return BaseInstruction::get_mem(reg_type); } return 0; diff --git a/GC/Machine.h b/GC/Machine.h index bdc36afc..d3f1b302 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -58,6 +58,7 @@ public: void stop_timer() { timer[0].stop(); } void reset_timer() { timer[0].reset(); } + void run_tapes(const vector& args); void run_tape(int thread_number, int tape_number, int arg); void join_tape(int thread_numer); }; diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 0a21387a..03560b16 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -61,8 +61,8 @@ template template void Memories::reset(const U& program) { - MS.resize_min(*program.direct_mem(SBIT), "memory"); - MC.resize_min(*program.direct_mem(CBIT), "memory"); + MS.resize_min(program.direct_mem(SBIT), "memory"); + MC.resize_min(program.direct_mem(CBIT), "memory"); } template @@ -70,7 +70,7 @@ template void Machine::reset(const U& program) { Memories::reset(program); - MI.resize_min(*program.direct_mem(INT), "memory"); + MI.resize_min(program.direct_mem(INT), "memory"); } template @@ -78,12 +78,20 @@ template void Machine::reset(const U& program, V& MD) { reset(program); - MD.resize_min(*program.direct_mem(DYN_SBIT), "dynamic memory"); + MD.resize_min(program.direct_mem(DYN_SBIT), "dynamic memory"); #ifdef DEBUG_MEMORY cerr << "reset dynamic mem to " << program.direct_mem(DYN_SBIT) << endl; #endif } +template +void Machine::run_tapes(const vector& args) +{ + assert(args.size() % 3 == 0); + for (unsigned i = 0; i < args.size(); i += 3) /* one (thread, tape, arg) triple per step */ + run_tape(args[i], args[i + 1], args[i + 2]); +} + template void Machine::run_tape(int thread_number, int tape_number, int arg) { diff --git a/GC/NoShare.h b/GC/NoShare.h index b34d4077..7cf1ec8d 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -6,15 +6,23 @@ #ifndef GC_NOSHARE_H_ #define GC_NOSHARE_H_ -#include
"BMR/Register.h" #include "Processor/DummyProtocol.h" +#include "Tools/SwitchableOutput.h" + +class InputArgs; +class ArithmeticProcessor; namespace GC { +template class Processor; +class Clear; + class NoValue : public ValueInterface { public: + typedef NoValue Scalar; + const static int n_bits = 0; const static int MAX_N_BITS = 0; @@ -28,6 +36,11 @@ public: return 0; } + static int length() + { + return 0; + } + static void fail() { throw runtime_error("VM does not support binary circuits"); @@ -43,9 +56,13 @@ public: int operator<<(int) const { fail(); return 0; } void operator+=(int) { fail(); } + bool operator!=(NoValue) const { fail(); return 0; } + bool get_bit(int) { fail(); return 0; } void randomize(PRNG&) { fail(); } + + void invert() { fail(); } }; inline ostream& operator<<(ostream& o, NoValue) @@ -53,18 +70,11 @@ inline ostream& operator<<(ostream& o, NoValue) return o; } -template -inline bool operator!=(const T&, NoValue&) -{ - NoValue::fail(); - return true; -} - -class NoShare : public Phase +class NoShare { public: typedef DummyMC MC; - typedef DummyProtocol Protocol; + typedef DummyProtocol Protocol; typedef NotImplementedInput Input; typedef DummyLivePrep LivePrep; typedef DummyMC MAC_Check; @@ -83,6 +93,8 @@ public: static const bool expensive_triples = false; static const bool is_real = false; + static SwitchableOutput out; + static MC* new_mc(mac_key_type) { return new MC; @@ -118,13 +130,16 @@ public: NoValue::fail(); } - static void inputb(Processor&, const vector&) { fail(); } + static void inputb(Processor&, ArithmeticProcessor&, const vector&) { fail(); } static void reveal_inst(Processor&, const vector&) { fail(); } + static void xors(Processor&, const vector&) { fail(); } + static void ands(Processor&, const vector&) { fail(); } + static void andrs(Processor&, const vector&) { fail(); } static void input(Processor&, InputArgs&) { fail(); } static void trans(Processor&, Integer, const vector&) { fail(); } - static NoShare 
constant(GC::Clear, int, mac_key_type) { fail(); return {}; } + static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; } NoShare() {} @@ -146,6 +161,8 @@ public: void operator^=(NoShare) { fail(); } NoShare operator+(const NoShare&) const { fail(); return {}; } + NoShare operator-(NoShare) const { fail(); return 0; } + NoShare operator*(NoValue) const { fail(); return 0; } NoShare operator+(int) const { fail(); return {}; } NoShare operator&(int) const { fail(); return {}; } @@ -153,6 +170,8 @@ public: NoShare lsb() const { fail(); return {}; } NoShare get_bit(int) const { fail(); return {}; } + + void invert(int, NoShare) { fail(); } }; } /* namespace GC */ diff --git a/GC/Processor.h b/GC/Processor.h index 492e306a..3703cac5 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -20,17 +20,6 @@ using namespace std; namespace GC { -class ExecutionStats : public map -{ -public: - ExecutionStats& operator+=(const ExecutionStats& other) - { - for (auto it : other) - (*this)[it.first] += it.second; - return *this; - } -}; - template class Processor : public ::ProcessorBase, public GC::RuntimeBranching { @@ -53,8 +42,6 @@ public: Memory C; Memory I; - ExecutionStats stats; - Timer xor_timer; Processor(Machine& machine); @@ -102,9 +89,12 @@ public: void ands(const vector& args) { and_(args, false); } void input(const vector& args); - void reveal(const vector& args); + void inputb(typename T::Input& input, ProcessorBase& input_processor, + const vector& args, int my_num); + void inputbvec(typename T::Input& input, ProcessorBase& input_processor, + const vector& args, int my_num); - void reveal(const ::BaseInstruction& instruction); + void reveal(const vector& args); void print_reg(int reg, int n, int size); void print_reg_plain(Clear& value); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 16c638fb..a91c24e0 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -79,6 +79,8 @@ template U GC::Processor::get_long_input(const int* params, 
ProcessorBase& input_proc, bool interactive) { + if (not T::actual_inputs) + return {}; U res = input_proc.get_input>(interactive, &params[1]).items[0]; int n_bits = *params; @@ -251,8 +253,12 @@ void Processor::and_(const vector& args, bool repeat) check_args(args, 4); for (size_t i = 0; i < args.size(); i += 4) { - assert(args[i] <= T::default_length); - S[args[i+1]].and_(args[i], S[args[i+2]], S[args[i+3]], repeat); + for (int j = 0; j < DIV_CEIL(args[i], T::default_length); j++) + { + int n = min(T::default_length, args[i] - j * T::default_length); + S[args[i + 1] + j].and_(n, S[args[i + 2] + j], + S[args[i + 3] + (repeat ? 0 : j)], repeat); + } complexity += args[i]; } } diff --git a/GC/Program.h b/GC/Program.h index 5afc5fed..8280c3f7 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -48,8 +48,8 @@ class Program unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } - const unsigned* direct_mem(RegType reg_type) const - { return &max_mem[reg_type]; } + unsigned direct_mem(RegType reg_type) const + { return max_mem[reg_type]; } template BreakType execute(Processor& Proc, U& dynamic_memory, int PC = -1) const; diff --git a/GC/Secret.h b/GC/Secret.h index d2596a10..f8a11b49 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -60,6 +60,8 @@ public: typedef NoShare bit_type; + typedef typename T::Input Input; + static string type_string() { return "evaluation secret"; } static string phase_name() { return T::name(); } @@ -71,6 +73,8 @@ public: static const bool is_real = true; + static const bool actual_inputs = T::actual_inputs; + static Secret input(party_id_t from, const int128& input, int n_bits = -1); static Secret input(Processor>& processor, const InputArgs& args); void random(int n_bits, int128 share); @@ -102,6 +106,10 @@ public: const vector& args) { T::inputb(processor, input_proc, args); } template + static void inputbvec(Processor& processor, ProcessorBase& input_proc, + const vector& args) + { T::inputbvec(processor, input_proc, args); } +
template static void reveal_inst(Processor& processor, const vector& args) { processor.reveal(args); } @@ -143,6 +151,13 @@ public: template void reveal(size_t n_bits, U& x); + template + void my_input(U& inputter, BitVec value, int n_bits); + template + void other_input(U& inputter, int from, int n_bits); + template + void finalize_input(U& inputter, int from, int n_bits); + int size() const { return registers.size(); } RegVector& get_regs() { return registers; } const RegVector& get_regs() const { return registers; } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 3ca22a4a..d2b16901 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -57,6 +57,33 @@ Secret Secret::input(party_id_t from, const int128& input, int n_bits) return res; } +template +template +void GC::Secret::my_input(U& inputter, BitVec value, int n_bits) +{ + resize_regs(n_bits); + for (int i = 0; i < n_bits; i++) + get_reg(i).my_input(inputter, value.get_bit(i), 1); +} + +template +template +void GC::Secret::other_input(U& inputter, int from, int n_bits) +{ + resize_regs(n_bits); + for (int i = 0; i < n_bits; i++) + get_reg(i).other_input(inputter, from); +} + +template +template +void GC::Secret::finalize_input(U& inputter, int from, int n_bits) +{ + resize_regs(n_bits); + for (int i = 0; i < n_bits; i++) + get_reg(i).finalize_input(inputter, from, 1); +} + template void Secret::random(int n_bits, int128 share) { diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index b92d3c56..1a7f8be3 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -41,9 +41,8 @@ public: typedef Memory DynamicMemory; typedef SwitchableOutput out_type; - static const bool expensive_triples = false; - static const bool is_real = true; + static const bool actual_inputs = true; static SwitchableOutput out; @@ -63,6 +62,8 @@ public: { inputb(processor, processor, args); } static void inputb(Processor& processor, ProcessorBase& input_processor, const vector& args); + static void inputbvec(Processor& processor, ProcessorBase& 
input_processor, + const vector& args); static void reveal_inst(Processor& processor, const vector& args); template @@ -75,6 +76,13 @@ public: void invert(int n, const U& x); void random_bit(); + + template + void my_input(T& inputter, BitVec value, int n_bits); + template + void other_input(T& inputter, int from, int n_bits = 1); + template + void finalize_input(T& inputter, int from, int n_bits); }; template @@ -169,6 +177,8 @@ public: typedef SemiHonestRepSecret small_type; typedef SemiHonestRepSecret whole_type; + static const bool expensive_triples = false; + static MC* new_mc(mac_key_type) { return new MC; } SemiHonestRepSecret() {} diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 6d95a231..f9a45361 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -98,6 +98,28 @@ void ShareSecret::store_clear_in_dynamic(Memory& mem, mem[access.address] = access.value; } +template +template +void GC::ShareSecret::my_input(T& inputter, BitVec value, int n_bits) +{ + inputter.add_mine(value, n_bits); +} + +template +template +void GC::ShareSecret::other_input(T& inputter, int from, int) +{ + inputter.add_other(from); +} + +template +template +void GC::ShareSecret::finalize_input(T& inputter, int from, + int n_bits) +{ + static_cast(*this) = inputter.finalize(from, n_bits).mask(n_bits); +} + template void ShareSecret::inputb(Processor& processor, ProcessorBase& input_processor, @@ -106,25 +128,46 @@ void ShareSecret::inputb(Processor& processor, auto& party = ShareThread::s(); typename U::Input input(*party.MC, party.DataF, *party.P); input.reset_all(*party.P); + processor.inputb(input, input_processor, args, party.P->my_num()); +} +template +void ShareSecret::inputbvec(Processor& processor, + ProcessorBase& input_processor, + const vector& args) +{ + auto& party = ShareThread::s(); + typename U::Input input(*party.MC, party.DataF, *party.P); + input.reset_all(*party.P); + processor.inputbvec(input, input_processor, args, party.P->my_num()); +} + +template 
+void Processor::inputb(typename T::Input& input, ProcessorBase& input_processor, + const vector& args, int my_num) +{ InputArgList a(args); - bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0; - int dl = U::default_length; + complexity += a.n_input_bits(); + bool interactive = a.n_interactive_inputs_from_me(my_num) > 0; + int dl = T::default_length; for (auto x : a) { - if (x.from == party.P->my_num()) + if (x.from == my_num) { - bigint whole_input = processor.template - get_long_input(x.params, + bigint whole_input = get_long_input(x.params, input_processor, interactive); for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) - input.add_mine(bigint(whole_input >> (i * dl)).get_si(), + { + auto& res = S[x.dest + i]; + res.my_input(input, bigint(whole_input >> (i * dl)).get_si(), min(dl, x.n_bits - i * dl)); + } } else for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) - input.add_other(x.from); + S[x.dest + i].other_input(input, x.from, + min(dl, x.n_bits - i * dl)); } if (interactive) @@ -138,9 +181,51 @@ void ShareSecret::inputb(Processor& processor, int n_bits = x.n_bits; for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) { - auto& res = processor.S[x.dest + i]; + auto& res = S[x.dest + i]; int n = min(dl, n_bits - i * dl); - res = input.finalize(from, n).mask(n); + res.finalize_input(input, from, n); + } + } +} + +template +void Processor::inputbvec(typename T::Input& input, ProcessorBase& input_processor, + const vector& args, int my_num) +{ + InputVecArgList a(args); + complexity += a.n_input_bits(); + bool interactive = a.n_interactive_inputs_from_me(my_num) > 0; + + for (auto x : a) + { + if (x.from == my_num) + { + bigint whole_input = get_long_input(x.params, + input_processor, interactive); + for (int i = 0; i < x.n_bits; i++) + { + auto& res = S[x.dest[i]]; + res.my_input(input, bigint(whole_input >> (i)).get_si() & 1, 1); + } + } + else + for (int i = 0; i < x.n_bits; i++) + S[x.dest[i]].other_input(input, x.from, 1); + } + + if 
(interactive) + cout << "Thank you" << endl; + + input.exchange(); + + for (auto x : a) + { + int from = x.from; + int n_bits = x.n_bits; + for (int i = 0; i < n_bits; i++) + { + auto& res = S[x.dest[i]]; + res.finalize_input(input, from, 1); } } } diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 9a61651e..a084f8c7 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -57,6 +57,9 @@ public: void pre_run(); void post_run() { ShareThread::post_run(); } + + size_t data_sent() + { return Thread::data_sent() + this->DataF.data_sent(); } }; template diff --git a/GC/Thread.h b/GC/Thread.h index 56fc87cc..d7734165 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -56,7 +56,7 @@ public: void join_tape(); void finish(); - int n_interactive_inputs_from_me(InputArgList& args); + virtual size_t data_sent(); }; template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 7e504755..5dba9910 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -91,15 +91,17 @@ void Thread::finish() } template -int Thread::n_interactive_inputs_from_me(InputArgList& args) +size_t GC::Thread::data_sent() { - return args.n_interactive_inputs_from_me(P->my_num()); + assert(P); + return P->comm_stats.total_data(); } } /* namespace GC */ -inline int InputArgList::n_interactive_inputs_from_me(int my_num) +template +inline int InputArgListBase::n_interactive_inputs_from_me(int my_num) { int res = 0; if (ArithmeticProcessor().use_stdin()) diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 9e44bcbf..9a47e43a 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -87,10 +87,12 @@ void ThreadMaster::run() NamedCommStats stats = P->comm_stats; ExecutionStats exe_stats; + size_t data_sent = P->comm_stats.total_data(); for (auto thread : threads) { stats += thread->P->comm_stats; exe_stats += thread->processor.stats; + data_sent += thread->data_sent(); delete thread; } @@ -108,7 +110,8 @@ void ThreadMaster::run() if (it->second.data > 0) cerr << it->first << " " << 1e-6 * it->second.data << " MB" 
<< endl; - cerr << "Time = " << timer.elapsed() << endl; + cerr << "Time = " << timer.elapsed() << " seconds" << endl; + cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl; } } /* namespace GC */ diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 30ec5a4a..6270ad34 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -151,6 +151,18 @@ public: for (auto& reg : this->get_regs()) reg.output(s, human); } + + template + void my_input(U& inputter, BitVec value, int n_bits) + { + inputter.add_mine(value, n_bits); + } + + template + void finalize_input(U& inputter, int from, int n_bits) + { + *this = inputter.finalize(from, n_bits).mask(n_bits); + } }; template diff --git a/GC/instructions.h b/GC/instructions.h index d75ee2c3..9143d01d 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -71,6 +71,7 @@ #define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \ + X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \ X(ANDM, processor.andm(instruction)) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ @@ -85,6 +86,7 @@ #define GC_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, EXTRA)) \ + X(INPUTBVEC, T::inputbvec(PROC, PROC, EXTRA)) \ X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \ X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \ X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \ @@ -130,7 +132,7 @@ X(PRINTINT, S0.out << I0) \ X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \ X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \ - X(RUN_TAPE, MACH->run_tape(R0, IMM, REG1)) \ + X(RUN_TAPE, MACH->run_tapes(EXTRA)) \ X(JOIN_TAPE, MACH->join_tape(R0)) \ X(USE, ) \ diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp new file mode 100644 index 00000000..24377bc9 --- /dev/null +++ b/Machines/emulate.cpp @@ -0,0 +1,59 @@ +/* + * emulate.cpp + * + */ + +#include "Protocols/FakeShare.h" +#include "Processor/Machine.h" +#include "Math/Z2k.h" +#include "Math/gf2n.h" +#include "Processor/RingOptions.h" + 
+#include "Processor/Machine.hpp" +#include "Math/Z2k.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/ShuffleSacrifice.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/FakeShare.hpp" + +SwitchableOutput GC::NoShare::out; + +int main(int argc, const char** argv) +{ + assert(argc > 1); + OnlineOptions online_opts; + Names N(0, 9999, vector({"localhost"})); + ez::ezOptionParser opt; + RingOptions ring_opts(opt, argc, argv); + opt.parse(argc, argv); + string progname; + if (opt.firstArgs.size() > 1) + progname = *opt.firstArgs.at(1); + else if (not opt.lastArgs.empty()) + progname = *opt.lastArgs.at(0); + else if (not opt.unknownArgs.empty()) + progname = *opt.unknownArgs.at(0); + else + { + string usage; + opt.getUsage(usage); + cerr << usage << endl; + exit(1); + } + + switch (ring_opts.R) + { + case 64: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; + case 128: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; + default: + cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl; + } +} diff --git a/Machines/mal-rep-bmr-party.cpp b/Machines/mal-rep-bmr-party.cpp index 3fda28fa..09cfc5fa 100644 --- a/Machines/mal-rep-bmr-party.cpp +++ b/Machines/mal-rep-bmr-party.cpp @@ -5,9 +5,8 @@ #include "Protocols/MaliciousRep3Share.h" -#include "Machines/Rep.hpp" - #include "BMR/RealProgramParty.hpp" +#include "Machines/Rep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/mal-shamir-bmr-party.cpp b/Machines/mal-shamir-bmr-party.cpp index 4f360a69..d086f305 100644 --- a/Machines/mal-shamir-bmr-party.cpp +++ b/Machines/mal-shamir-bmr-party.cpp @@ -3,9 +3,8 @@ * */ -#include "Machines/ShamirMachine.hpp" - #include "BMR/RealProgramParty.hpp" +#include "Machines/ShamirMachine.hpp" #include "Math/Z2k.hpp" int 
main(int argc, const char** argv) diff --git a/Machines/malicious-rep-bin-party.cpp b/Machines/malicious-rep-bin-party.cpp index a90f4b8e..8986b11f 100644 --- a/Machines/malicious-rep-bin-party.cpp +++ b/Machines/malicious-rep-bin-party.cpp @@ -15,7 +15,6 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" -#include "Processor/Machine.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" #include "Protocols/MAC_Check_Base.hpp" diff --git a/Machines/real-bmr-party.cpp b/Machines/real-bmr-party.cpp index 6872ef39..dd13d173 100644 --- a/Machines/real-bmr-party.cpp +++ b/Machines/real-bmr-party.cpp @@ -3,9 +3,8 @@ * */ -#include "Machines/SPDZ.cpp" - #include "BMR/RealProgramParty.hpp" +#include "Machines/SPDZ.hpp" int main(int argc, const char** argv) { diff --git a/Machines/rep-bmr-party.cpp b/Machines/rep-bmr-party.cpp index 579c84d7..f6bb669b 100644 --- a/Machines/rep-bmr-party.cpp +++ b/Machines/rep-bmr-party.cpp @@ -3,9 +3,8 @@ * */ -#include "Machines/Rep.hpp" - #include "BMR/RealProgramParty.hpp" +#include "Machines/Rep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/replicated-bin-party.cpp b/Machines/replicated-bin-party.cpp index 39e41567..2043d7f2 100644 --- a/Machines/replicated-bin-party.cpp +++ b/Machines/replicated-bin-party.cpp @@ -14,7 +14,6 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" -#include "Processor/Machine.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" #include "Protocols/MAC_Check_Base.hpp" diff --git a/Machines/replicated-field-party.cpp b/Machines/replicated-field-party.cpp index 77edb801..2dcb8cb1 100644 --- a/Machines/replicated-field-party.cpp +++ b/Machines/replicated-field-party.cpp @@ -4,12 +4,10 @@ */ #include "Math/gfp.hpp" -#include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/ReplicatedFieldMachine.hpp" #include "Machines/Rep.hpp" int main(int argc, const char** argv) { - ez::ezOptionParser opt; - ReplicatedMachine, 
Rep3Share>(argc, argv, - "replicated-field", opt); + ReplicatedFieldMachine(argc, argv); } diff --git a/Machines/semi-bmr-party.cpp b/Machines/semi-bmr-party.cpp new file mode 100644 index 00000000..b087f988 --- /dev/null +++ b/Machines/semi-bmr-party.cpp @@ -0,0 +1,12 @@ +/* + * semi-bmr-party.cpp + * + */ + +#include "BMR/RealProgramParty.hpp" +#include "Machines/Semi.hpp" + +int main(int argc, const char** argv) +{ + RealProgramParty>(argc, argv); +} diff --git a/Machines/shamir-bmr-party.cpp b/Machines/shamir-bmr-party.cpp index 719bbd3b..e6fe0ac8 100644 --- a/Machines/shamir-bmr-party.cpp +++ b/Machines/shamir-bmr-party.cpp @@ -3,9 +3,8 @@ * */ -#include "Machines/ShamirMachine.hpp" - #include "BMR/RealProgramParty.hpp" +#include "Machines/ShamirMachine.hpp" #include "Math/Z2k.hpp" int main(int argc, const char** argv) diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 695d9178..dfc60ffc 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -19,7 +19,6 @@ #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" -#include "Processor/Machine.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MAC_Check.hpp" #include "Protocols/MAC_Check_Base.hpp" diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp index e14685a4..0ff36835 100644 --- a/Machines/tiny-party.cpp +++ b/Machines/tiny-party.cpp @@ -19,7 +19,6 @@ #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" -#include "Processor/Machine.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MAC_Check.hpp" #include "Protocols/MAC_Check_Base.hpp" diff --git a/Makefile b/Makefile index c0e12092..4a9b4463 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ include $(wildcard *.d static/*.d) %.o: %.cpp $(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c -online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x +online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x offline: $(OT_EXE) Check-Offline.x @@ -193,6 +193,8 @@ 
malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) +emulate.x: GC/FakeSecret.o +semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(LIBSIMPLEOT): SimpleOT/Makefile $(MAKE) -C SimpleOT diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 4ef7d096..9c48a3d7 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -138,6 +138,8 @@ void write_online_setup(string dirname, const bigint& p) ofstream outf; outf.open(ss.str().c_str()); outf << p << endl; + if (!outf.good()) + throw file_error("cannot write to " + ss.str()); } void init_gf2n(int lg2) diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index 041a3f0b..32cbb5b7 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -10,6 +10,7 @@ class OnlineOptions; class bigint; +class PRNG; class ValueInterface { @@ -30,6 +31,8 @@ public: static int power_of_two(bool, int) { throw not_implemented(); } void normalize() {} + + void randomize_part(PRNG&, int) { throw not_implemented(); } }; #endif /* MATH_VALUEINTERFACE_H_ */ diff --git a/Math/Z2k.h b/Math/Z2k.h index 395f59d7..80cfa81b 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -117,7 +117,7 @@ public: Z2 operator-(const Z2& other) const; template - Z2 operator*(const Z2& other) const; + Z2<(K > L) ? K : L> operator*(const Z2& other) const; Z2 operator*(bool other) const { return other ? *this : Z2(); } Z2 operator*(int other) const { return *this * Z2(other); } @@ -171,6 +171,7 @@ public: void XOR(const Z2& a, const Z2& b); void randomize(PRNG& G, int n = -1); + void randomize_part(PRNG& G, int n); void almost_randomize(PRNG& G) { randomize(G); } void force_to_bit() { throw runtime_error("impossible"); } @@ -342,9 +343,9 @@ inline Z2 Z2::Mul(const Z2& x, const Z2& y) template template -inline Z2 Z2::operator*(const Z2& other) const +inline Z2<(K > L) ? 
K : L> Z2::operator*(const Z2& other) const { - return Z2::Mul(*this, other); + return Z2<(K > L) ? K : L>::Mul(*this, other); } template @@ -387,6 +388,14 @@ void Z2::randomize(PRNG& G, int n) normalize(); } +template +void Z2::randomize_part(PRNG& G, int n) +{ + *this = {}; + G.get_octets((octet*)a, DIV_CEIL(n, 8)); + a[DIV_CEIL(n, 64) - 1] &= uint64_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS); +} + template void Z2::pack(octetStream& o, int n) const { diff --git a/Math/gfp.h b/Math/gfp.h index dc745345..48e620c1 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -41,7 +41,7 @@ template void generate_prime_setup(string, int, int); #endif template -class gfp_ +class gfp_ : public ValueInterface { typedef modp_ modp_type; @@ -247,7 +247,7 @@ class gfp_ gfp_ operator^(const gfp_& x) { gfp_ res; res.XOR(*this, x); return res; } gfp_ operator|(const gfp_& x) { gfp_ res; res.OR(*this, x); return res; } gfp_ operator<<(int i) const { gfp_ res; res.SHL(*this, i); return res; } - gfp_ operator>>(int i) { gfp_ res; res.SHR(*this, i); return res; } + gfp_ operator>>(int i) const { gfp_ res; res.SHR(*this, i); return res; } gfp_& operator&=(const gfp_& x) { AND(*this, x); return *this; } gfp_& operator<<=(int i) { SHL(*this, i); return *this; } diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 6efb4f6c..246df321 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -193,7 +193,8 @@ void gfp_::reqbl(int n) { if ((int)n > 0 && pr() < bigint(1) << (n-1)) { - cout << "Tape requires prime of bit length " << n << endl; + cerr << "Tape requires prime of bit length " << n << endl; + cerr << "Run with '-lgp " << n << "'" << endl; throw invalid_params(); } else if ((int)n < 0) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 91ccc8e1..10490a62 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -450,17 +450,23 @@ void Player::pass_around(octetStream& o, octetStream& to_receive, int offset) co * size getting in the way */ template -void 
MultiPlayer::Broadcast_Receive(vector& o,bool donthash) const +void MultiPlayer::Broadcast_Receive_no_stats(vector& o) const { if (o.size() != sockets.size()) throw runtime_error("player numbers don't match"); - TimeScope ts(comm_stats["Broadcasting"].add(o[player_no])); for (int i=1; i +void MultiPlayer::Broadcast_Receive(vector& o,bool donthash) const +{ + TimeScope ts(comm_stats["Broadcasting"].add(o[player_no])); + Broadcast_Receive_no_stats(o); if (!donthash) { for (int i=0; i& o,bool donthash=false) const; + void Broadcast_Receive_no_stats(vector& o) const; // wait for available inputs void wait_for_available(vector& players, vector& result) const; diff --git a/OT/MascotParams.cpp b/OT/MascotParams.cpp index 868ea874..67a6d444 100644 --- a/OT/MascotParams.cpp +++ b/OT/MascotParams.cpp @@ -3,23 +3,7 @@ * */ -#include -#include "OT/NPartyTripleGenerator.h" -#include "OT/OTTripleSetup.h" -#include "Math/gf2n.h" -#include "Math/Setup.h" -#include "Protocols/Spdz2kShare.h" -#include "Tools/ezOptionParser.h" -#include "Math/Setup.h" -#include "Protocols/fake-stuff.h" -#include "Math/BitVec.h" - -#include "Protocols/fake-stuff.hpp" -#include "Math/Z2k.hpp" - -#include -#include -using namespace std; +#include "MascotParams.h" MascotParams::MascotParams() { diff --git a/OT/OTVole.h b/OT/OTVole.h index 2016b2bc..3f7294d7 100644 --- a/OT/OTVole.h +++ b/OT/OTVole.h @@ -51,18 +51,20 @@ public: // Both PRNG local_prng; - Row tmp; - - octetStream os; + vector oss; virtual void consistency_check (vector& os); void set_coeffs(__m128i* coefficients, PRNG& G, int num_elements) const; - void hash_row(octetStream& os, const Row& row, const __m128i* coefficients); - - void hash_row(octet* hash, const Row& row, const __m128i* coefficients); + template + void hash_row(octetStream& os, const U& row, const __m128i* coefficients); + template + void hash_row(octet* hash, const U& row, const __m128i* coefficients); + template + void hash_row(__m128i res[2], const U& row, const __m128i* 
coefficients); + static void add_mul(__m128i res[2], __m128i a, __m128i b); }; template diff --git a/OT/OTVole.hpp b/OT/OTVole.hpp index f8a84afc..1dbcdbe0 100644 --- a/OT/OTVole.hpp +++ b/OT/OTVole.hpp @@ -10,30 +10,28 @@ template void OTVoleBase::evaluate(vector& output, const vector& newReceiverInput) { const int N1 = newReceiverInput.size() + 1; output.resize(newReceiverInput.size()); - vector os(2); + auto& os = oss; + os.resize(2); + os[0].reset_write_head(); + os[1].reset_write_head(); if (this->ot_role & SENDER) { T extra; extra.randomize(local_prng); - vector _corr(newReceiverInput); - _corr.push_back(extra); - corr_prime = Row(_corr); + corr_prime.rows = newReceiverInput; + corr_prime.rows.push_back(extra); for (int i = 0; i < S; ++i) { - t0[i] = Row(N1); - t0[i].randomize(this->G_sender[i][0]); - t1[i] = Row(N1); - t1[i].randomize(this->G_sender[i][1]); - Row u = corr_prime + t1[i] + t0[i]; - u.pack(os[0]); + t0[i].randomize(this->G_sender[i][0], N1); + t1[i].randomize(this->G_sender[i][1], N1); + (corr_prime + t1[i] + t0[i]).pack(os[0]); } } send_if_ot_sender(this->player, os, this->ot_role); if (this->ot_role & RECEIVER) { for (int i = 0; i < S; ++i) { - t[i] = Row(N1); - t[i].randomize(this->G_receiver[i]); + t[i].randomize(this->G_receiver[i], N1); int choice_bit = this->baseReceiverInput.get_bit(i); if (choice_bit == 1) { @@ -84,38 +82,104 @@ void OTVoleBase::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) c } template -void OTVoleBase::hash_row(octetStream& os, const Row& row, const __m128i* coefficients) { +template +void OTVoleBase::hash_row(octetStream& os, const U& row, + const __m128i* coefficients) +{ octet hash[VOLE_HASH_SIZE] = {0}; this->hash_row(hash, row, coefficients); os.append(hash, VOLE_HASH_SIZE); } template -void OTVoleBase::hash_row(octet* hash, const Row& row, const __m128i* coefficients) { - int num_blocks = DIV_CEIL(row.size() * T::size(), 16); - - os.clear(); - for(auto& x : row.rows) - x.pack(os); - 
os.serialize(int128()); - - __m128i prods[2]; - avx_memzero(prods, sizeof(prods)); +template +void OTVoleBase::hash_row(octet* hash, const U& row, + const __m128i* coefficients) +{ __m128i res[2]; - avx_memzero(res, sizeof(res)); - - for (int i = 0; i < num_blocks; ++i) { - __m128i block; - os.unserialize(block); - mul128(block, coefficients[i], &prods[0], &prods[1]); - res[0] ^= prods[0]; - res[1] ^= prods[1]; - } + for (int i = 0; i < 2; i++) + res[i] = _mm_setzero_si128(); + hash_row(res, row, coefficients); crypto_generichash(hash, crypto_generichash_BYTES, (octet*) res, crypto_generichash_BYTES, NULL, 0); } +template +template +void OTVoleBase::hash_row(__m128i res[2], const U& row, + const __m128i* coefficients) +{ + auto coeff_base = coefficients; + int num_blocks = DIV_CEIL(row.size() * T::size(), 16); + __m128i buffer[T::size()]; + size_t next = 0; + while (next + 16 < row.size()) + { + for (int j = 0; j < 16; j++) + memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size()); + for (int j = 0; j < T::size(); j++) + add_mul(res, buffer[j], *coefficients++); + } + for (int j = 0; j < 16; j++) + if (next < row.size()) + memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size()); + for (int j = 0; j < num_blocks % T::size(); j++) + add_mul(res, buffer[j], *coefficients++); + assert(coefficients == coeff_base + num_blocks); +} + +template +void OTVoleBase::add_mul(__m128i res[2], __m128i a, __m128i b) +{ + __m128i prods[2]; + mul128(a, b, &prods[0], &prods[1]); + res[0] ^= prods[0]; + res[1] ^= prods[1]; +} + +template <> +template +inline +void OTVoleBase>::hash_row(__m128i res[2], + const U& row, const __m128i* coefficients) +{ + for (size_t i = 0; i < row.size(); i++) + { + __m128i block = int128(row[i].get_limb(1), row[i].get_limb(0)).a; + add_mul(res, block, coefficients[i]); + } +} + +template <> +template +inline +void OTVoleBase>::hash_row(__m128i res[2], const U& row, + const __m128i* coefficients) +{ + __m128i block; + 
size_t j; + for (j = 0; j < row.size() - 1; j += 2) + { + auto x = row[j]; + auto y = row[j + 1]; + block = int128(x.get_limb(1), x.get_limb(0)).a; + add_mul(res, block, *coefficients++); + block = int128(y.get_limb(0), x.get_limb(2)).a; + add_mul(res, block, *coefficients++); + block = int128(y.get_limb(2), y.get_limb(1)).a; + add_mul(res, block, *coefficients++); + } + if (j < row.size()) + { + auto x = row[j]; + block = int128(x.get_limb(1), x.get_limb(0)).a; + add_mul(res, block, *coefficients++); + block = int128(x.get_limb(2)).a; + add_mul(res, block, *coefficients++); + } +} + template void OTVoleBase::consistency_check(vector& os) { PRNG coef_prng_sender; @@ -144,17 +208,16 @@ void OTVoleBase::consistency_check(vector& os) { __m128i coefficients[num_blocks]; this->set_coeffs(coefficients, coef_prng_sender, num_blocks); - Row t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size()); for (int alpha = 0; alpha < S; ++alpha) { for (int i = 0; i < n_challenges(); i++) { int beta = get_challenge(coef_prng_sender, i); - t00 = t0[alpha] - t0[beta]; - t01 = t0[alpha] - t1[beta]; - t10 = t1[alpha] - t0[beta]; - t11 = t1[alpha] - t1[beta]; + auto t00 = t0[alpha] - t0[beta]; + auto t01 = t0[alpha] - t1[beta]; + auto t10 = t1[alpha] - t0[beta]; + auto t11 = t1[alpha] - t1[beta]; this->hash_row(os[0], t00, coefficients); this->hash_row(os[0], t01, coefficients); @@ -202,16 +265,16 @@ void OTVoleBase::consistency_check(vector& os) { int choice_alpha = this->baseReceiverInput.get_bit(alpha); int choice_beta = this->baseReceiverInput.get_bit(beta); - tmp = t[alpha] - t[beta]; + auto diff = t[alpha] - t[beta]; octet* choice_hash = hashes[choice_alpha][choice_beta]; octet diff_t[VOLE_HASH_SIZE] = {0}; - this->hash_row(diff_t, tmp, coefficients); + this->hash_row(diff_t, diff, coefficients); octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta]; octet other_diff[VOLE_HASH_SIZE] = {0}; - tmp = u[alpha] - u[beta]; - tmp -= t[alpha]; - tmp += t[beta]; + auto a 
= u[alpha] - u[beta]; + auto b = a - t[alpha]; + auto tmp = b + t[beta]; this->hash_row(other_diff, tmp, coefficients); if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) { diff --git a/OT/Row.h b/OT/Row.h index 21772c00..98b94512 100644 --- a/OT/Row.h +++ b/OT/Row.h @@ -5,7 +5,8 @@ #include "Math/gf2nlong.h" #define VOLE_HASH_SIZE crypto_generichash_BYTES -template class DeferredMinus; +template class DeferredMinus; +template class DeferredPlus; template class Row @@ -20,7 +21,10 @@ public: Row(const vector& _rows) : rows(_rows) {} - Row(DeferredMinus d) { *this = d; } + template + Row(DeferredMinus d) { *this = d; } + template + Row(DeferredPlus d) { *this = d; } bool operator==(const Row& other) const; bool operator!=(const Row& other) const { return not (*this == other); } @@ -31,23 +35,28 @@ public: Row& operator*=(const T& other); Row operator*(const T& other); - Row operator+(const Row & other); - DeferredMinus operator-(const Row & other); + DeferredPlus> operator+(const Row & other); + DeferredMinus> operator-(const Row& other); - Row& operator=(DeferredMinus d); + template + Row& operator=(const DeferredMinus& d); + template + Row& operator=(const DeferredPlus& d); Row operator<<(int i) const; // fine, since elements in vector are allocated contiguously const void* get_ptr() const { return rows[0].get_ptr(); } - void randomize(PRNG& G); + void randomize(PRNG& G, size_t size); void pack(octetStream& o) const; void unpack(octetStream& o); size_t size() const { return rows.size(); } + const T& operator[](size_t i) const { return rows[i]; } + template friend ostream& operator<<(ostream& o, const Row& x); }; @@ -55,17 +64,67 @@ public: template using Z2kRow = Row>; -template +template class DeferredMinus { public: - const Row& x; + const U& x; const Row& y; - DeferredMinus(const Row& x, const Row& y) : x(x), y(y) + DeferredMinus(const U& x, const Row& y) : x(x), y(y) { assert(x.size() == y.size()); } + + size_t size() const + { + return x.size(); + } + + 
T operator[](size_t i) const + { + return x[i] - y[i]; + } + + DeferredPlus operator+(const Row& other) + { + return {*this, other}; + } + + DeferredMinus operator-(const Row& other) + { + return {*this, other}; + } +}; + +template +class DeferredPlus +{ +public: + const U& x; + const Row& y; + + DeferredPlus(const U& x, const Row& y) : x(x), y(y) + { + assert(x.size() == y.size()); + } + + size_t size() const + { + return x.size(); + } + + T operator[](size_t i) const + { + return x[i] + y[i]; + } + + DeferredPlus operator+(const Row& other) + { + return {*this, other}; + } + + void pack(octetStream& o) const; }; #endif /* OT_ROW_H_ */ diff --git a/OT/Row.hpp b/OT/Row.hpp index 464c0ace..5e186951 100644 --- a/OT/Row.hpp +++ b/OT/Row.hpp @@ -46,34 +46,46 @@ Row Row::operator *(const T& other) } template -Row Row::operator +(const Row& other) +DeferredPlus> Row::operator +(const Row& other) { - Row res = other; - res += *this; - return res; + return {*this, other}; } template -DeferredMinus Row::operator -(const Row& other) +DeferredMinus> Row::operator -(const Row& other) { - return DeferredMinus(*this, other); + return {*this, other}; } template -Row& Row::operator=(DeferredMinus d) +template +Row& Row::operator=(const DeferredMinus& d) { size_t size = d.x.size(); rows.resize(size); for (size_t i = 0; i < size; i++) - rows[i] = d.x.rows[i] - d.y.rows[i]; + rows[i] = d[i]; return *this; } template -void Row::randomize(PRNG& G) +template +Row& Row::operator=(const DeferredPlus& d) { - for (size_t i = 0; i < this->size(); i++) - rows[i].randomize(G); + size_t size = d.x.size(); + rows.resize(size); + for (size_t i = 0; i < size; i++) + rows[i] = d[i]; + return *this; +} + +template +void Row::randomize(PRNG& G, size_t size) +{ + rows.clear(); + rows.reserve(size); + for (size_t i = 0; i < size; i++) + rows.push_back(G.get()); } template @@ -87,6 +99,14 @@ Row Row::operator<<(int i) const { return res; } +template +void DeferredPlus::pack(octetStream& o) const +{ + 
o.store(this->size()); + for (size_t i = 0; i < this->size(); i++) + (*this)[i].pack(o); +} + template void Row::pack(octetStream& o) const { @@ -100,9 +120,10 @@ void Row::unpack(octetStream& o) { size_t size; o.get(size); - this->rows.resize(size); - for (size_t i = 0; i < this->size(); i++) - rows[i].unpack(o); + rows.clear(); + rows.reserve(size); + for (size_t i = 0; i < size; i++) + rows.push_back(o.get()); } template diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 9185685d..125ee2b6 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -11,6 +11,9 @@ using namespace std; #include "Math/BitVec.h" #include "Data_Files.h" +#include "Protocols/Replicated.h" +#include "Protocols/MAC_Check_Base.h" +#include "Processor/Input.h" class Player; class DataPositions; @@ -26,17 +29,21 @@ template class ShareThread; } template -class DummyMC +class DummyMC : public MAC_Check_Base { public: - void POpen(vector&, vector&, Player&) + DummyMC() { - throw not_implemented(); } - void Check(Player& P) + template + DummyMC(U, int = 0, int = 0) { - (void) P; + } + + void exchange(const Player&) + { + throw not_implemented(); } DummyMC& get_part_MC() @@ -51,7 +58,8 @@ public: } }; -class DummyProtocol +template +class DummyProtocol : public ProtocolBase { public: Player& P; @@ -66,13 +74,11 @@ public: { } - template void init_mul(SubProcessor* = 0) { throw not_implemented(); } - template - void prepare_mul(const T&, const T&, int = 0) + typename T::clear prepare_mul(const T&, const T&, int = 0) { throw not_implemented(); } @@ -80,10 +86,10 @@ public: { throw not_implemented(); } - int finalize_mul(int = 0) + T finalize_mul(int = 0) { throw not_implemented(); - return 0; + return {}; } }; @@ -91,6 +97,13 @@ template class DummyLivePrep : public Preprocessing { public: + static void basic_setup(Player&) + { + } + static void teardown() + { + } + static void fail() { throw runtime_error( @@ -106,6 +119,11 @@ public: { } + 
DummyLivePrep(SubProcessor*, DataPositions& usage) : + Preprocessing(usage) + { + } + void set_protocol(typename T::Protocol&) { } @@ -177,7 +195,7 @@ public: throw not_implemented(); } template - static void input(SubProcessor& proc, vector regs) + static void input(SubProcessor& proc, vector regs, int) { (void) proc, (void) regs; throw not_implemented(); @@ -206,6 +224,14 @@ public: (void) a, (void) b; throw not_implemented(); } + static void raw_input(SubProcessor&, vector, int) + { + throw not_implemented(); + } + static void input_mixed(SubProcessor&, vector, int, bool) + { + throw not_implemented(); + } }; class NotImplementedOutput diff --git a/Processor/Instruction.h b/Processor/Instruction.h index f81cd87f..64e7b8f2 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -106,10 +106,12 @@ enum SQUARE = 0x52, INV = 0x53, INPUTMASK = 0x56, + INPUTMASKREG = 0x5C, PREP = 0x57, DABIT = 0x58, EDABIT = 0x59, SEDABIT = 0x5A, + RANDOMS = 0x5B, // Input INPUT = 0x60, INPUTFIX = 0xF0, @@ -143,6 +145,7 @@ enum SHRC = 0x81, SHLCI = 0x82, SHRCI = 0x83, + SHRSI = 0x84, // Branching and comparison JMP = 0x90, JMPNZ = 0x91, @@ -283,13 +286,16 @@ enum // Register types enum RegType { MODP, - GF2N, INT, SBIT, CBIT, DYN_SBIT, + SINT, + CINT, + SGF2N, + CGF2N, + NONE, MAX_REG_TYPE, - NONE }; enum SecrecyType { @@ -323,6 +329,8 @@ struct TempVars { class BaseInstruction { + friend class Program; + protected: int opcode; // The code int size; // Vector size @@ -346,10 +354,10 @@ public: bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } virtual int get_reg_type() const; - bool is_direct_memory_access(SecrecyType sec_type) const; + bool is_direct_memory_access() const; // Returns the memory size used if applicable and known - unsigned get_mem(RegType reg_type, SecrecyType sec_type) const; + unsigned get_mem(RegType reg_type) const; // Returns the maximal register used unsigned get_max_reg(int reg_type) const; diff --git a/Processor/Instruction.hpp 
b/Processor/Instruction.hpp index 2be96714..080ad5b8 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -7,6 +7,7 @@ #include "Processor/IntInput.h" #include "Processor/FixInput.h" #include "Processor/FloatInput.h" +#include "Processor/instructions.h" #include "Exceptions/Exceptions.h" #include "Tools/time-func.h" #include "Tools/parse.h" @@ -102,6 +103,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case MULINT: case DIVINT: case CONDPRINTPLAIN: + case INPUTMASKREG: r[0]=get_int(s); r[1]=get_int(s); r[2]=get_int(s); @@ -199,6 +201,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case ORCI: case SHLCI: case SHRCI: + case SHRSI: case SHLCBI: case SHRCBI: case NOTC: @@ -220,7 +223,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case USE: case USE_INP: case USE_EDABIT: - case RUN_TAPE: case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: case STOPPRIVATEOUTPUT: @@ -261,6 +263,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case INV2M: case CONDPRINTSTR: case CONDPRINTSTRB: + case RANDOMS: r[0]=get_int(s); n = get_int(s); break; @@ -310,6 +313,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RAWINPUT: case GRAWINPUT: case TRUNC_PR: + case RUN_TAPE: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; @@ -444,6 +448,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case ANDRS: case ANDS: case INPUTB: + case INPUTBVEC: case REVEAL: get_vector(get_int(s), start, s); break; @@ -544,15 +549,57 @@ int BaseInstruction::get_reg_type() const case USE_PREP: case GUSE_PREP: case USE_EDABIT: + case RUN_TAPE: // those use r[] not for registers return NONE; + case LDI: + case LDMC: + case STMC: + case LDMCI: + case STMCI: + case MOVC: + case ADDC: + case ADDCI: + case SUBC: + case SUBCI: + case SUBCFI: + case MULC: + case MULCI: + case DIVC: + case DIVCI: + case MODC: + 
case MODCI: + case LEGENDREC: + case DIGESTC: + case INV2M: + case OPEN: + case ANDC: + case XORC: + case ORC: + case ANDCI: + case XORCI: + case ORCI: + case NOTC: + case SHLC: + case SHRC: + case SHLCI: + case SHRCI: + case CONVINT: + return CINT; default: if (is_gf2n_instruction()) - return GF2N; + { + Instruction tmp; + tmp.opcode = opcode - 0x100; + if (tmp.get_reg_type() == CINT) + return CGF2N; + else + return SGF2N; + } else if (opcode >> 4 == 0x9) return INT; else - return MODP; + return SINT; } } @@ -564,27 +611,35 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const int size_offset = 0; int size = this->size; - if (opcode == DABIT) + // special treatment for instructions writing to different types + switch (opcode) { + case DABIT: if (reg_type == SBIT) return r[1] + size; - else if (reg_type == MODP) + else if (reg_type == SINT) return r[0] + size; else return 0; - } - else if (opcode == EDABIT or opcode == SEDABIT) - { + case EDABIT: + case SEDABIT: if (reg_type == SBIT) skip = 1; - else if (reg_type == MODP) + else if (reg_type == SINT) return r[0] + size; else return 0; - } - else if (get_reg_type() != reg_type) - { - return 0; + break; + case INPUTMASKREG: + if (reg_type == SINT) + return r[0] + size; + else if (reg_type == CINT) + return r[1] + size; + else + return 0; + default: + if (get_reg_type() != reg_type) + return 0; } switch (opcode) @@ -607,6 +662,9 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return r[0] + start[0] * start[2]; case CONV2DS: return r[0] + start[0] * start[1]; + case OPEN: + skip = 2; + break; case LDMSD: case LDMSDI: skip = 3; @@ -627,6 +685,20 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const offset = 3; size_offset = -2; break; + case INPUTBVEC: + { + int res = 0; + auto it = start.begin(); + while (it < start.end()) + { + int n = *it - 3; + it += 3; + assert(it + n <= start.end()); + for (int i = 0; i < n; i++) + res = max(res, *it++); + } + return res + 1; + } case ANDM: case NOTS: 
size = DIV_CEIL(n, 64); @@ -673,16 +745,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const } inline -unsigned BaseInstruction::get_mem(RegType reg_type, SecrecyType sec_type) const +unsigned BaseInstruction::get_mem(RegType reg_type) const { - if (get_reg_type() == reg_type and is_direct_memory_access(sec_type)) + if (get_reg_type() == reg_type and is_direct_memory_access()) return n + size; else return 0; } inline -bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const +bool BaseInstruction::is_direct_memory_access() const { switch (opcode) { @@ -690,12 +762,10 @@ bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const case STMS: case GLDMS: case GSTMS: - return sec_type == SECRET; case LDMC: case STMC: case GLDMC: case GSTMC: - return sec_type == CLEAR; case LDMINT: case STMINT: case LDMSB: @@ -731,16 +801,9 @@ __attribute__((always_inline)) #endif inline void Instruction::execute(Processor& Proc) const { - Proc.PC+=1; auto& Procp = Proc.Procp; auto& Proc2 = Proc.Proc2; - // binary instructions - typedef typename sint::bit_type T; - auto& processor = Proc.Procb; - auto& instruction = *this; - auto& Ci = Proc.get_Ci(); - // optimize some instructions switch (opcode) { @@ -772,102 +835,6 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Proc.get_S2_ref(r[0] + i).mul(Proc.read_S2(r[1] + i),Proc.read_C2(r[2] + i)); return; - case LDI: - Proc.temp.assign_ansp(n); - for (int i = 0; i < size; i++) - Proc.write_Cp(r[0] + i,Proc.temp.ansp); - return; - case LDMSI: - for (int i = 0; i < size; i++) - Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(Proc.read_Ci(r[1] + i))); - return; - case LDMS: - for (int i = 0; i < size; i++) - Proc.write_Sp(r[0] + i, Proc.machine.Mp.read_S(n + i)); - return; - case STMSI: - for (int i = 0; i < size; i++) - Proc.machine.Mp.write_S(Proc.read_Ci(r[1] + i), Proc.read_Sp(r[0] + i), Proc.PC); - return; - case STMS: - for (int i = 0; i < size; i++) - 
Proc.machine.Mp.write_S(n + i, Proc.read_Sp(r[0] + i), Proc.PC); - return; - case ADDC: - for (int i = 0; i < size; i++) - Proc.get_Cp_ref(r[0] + i).add(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i)); - return; - case ADDS: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i)); - return; - case ADDM: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i),Proc.P.my_num(),Proc.MCp.get_alphai()); - return; - case ADDCI: - Proc.temp.assign_ansp(n); - for (int i = 0; i < size; i++) - Proc.get_Cp_ref(r[0] + i).add(Proc.temp.ansp,Proc.read_Cp(r[1] + i)); - return; - case SUBS: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i)); - return; - case SUBSFI: - Proc.temp.assign_ansp(n); - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).sub(Proc.temp.ansp,Proc.read_Sp(r[1] + i),Proc.P.my_num(),Proc.MCp.get_alphai()); - return; - case SUBML: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i), Proc.read_Cp(r[2] + i), - Proc.P.my_num(), Proc.MCp.get_alphai()); - return; - case SUBMR: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Cp(r[1] + i), Proc.read_Sp(r[2] + i), - Proc.P.my_num(), Proc.MCp.get_alphai()); - return; - case MULM: - for (int i = 0; i < size; i++) - Proc.get_Sp_ref(r[0] + i).mul(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i)); - return; - case MULC: - for (int i = 0; i < size; i++) - Proc.get_Cp_ref(r[0] + i).mul(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i)); - return; - case MULCI: - Proc.temp.assign_ansp(n); - for (int i = 0; i < size; i++) - Proc.get_Cp_ref(r[0] + i).mul(Proc.temp.ansp,Proc.read_Cp(r[1] + i)); - return; - case SHRCI: - for (int i = 0; i < size; i++) - Proc.get_Cp_ref(r[0] + i).SHR(Proc.read_Cp(r[1] + i), n); - return; - case TRIPLE: - for (int i = 0; i < size; i++) - 
Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i), - Proc.get_Sp_ref(r[1] + i), Proc.get_Sp_ref(r[2] + i)); - return; - case BIT: - for (int i = 0; i < size; i++) - Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0] + i)); - return; - case LDINT: - for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, int(n)); - return; - case INCINT: - { - for (int i = 0; i < size; i++) - { - int inc = (i / start[0]) % start[1]; - Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1]) + inc * int(n)); - } - } - return; case SHUFFLE: for (int i = 0; i < size; i++) Proc.write_Ci(r[0] + i, Proc.read_Ci(r[1] + i)); @@ -877,29 +844,6 @@ inline void Instruction::execute(Processor& Proc) const swap(Proc.get_Ci_ref(r[0] + i), Proc.get_Ci_ref(r[0] + i + j)); } return; - case ADDINT: - for (int i = 0; i < size; i++) - Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) + Proc.read_Ci(r[2] + i); - return; - case SUBINT: - for (int i = 0; i < size; i++) - Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) - Proc.read_Ci(r[2] + i); - return; - case MULINT: - for (int i = 0; i < size; i++) - Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) * Proc.read_Ci(r[2] + i); - return; - case DIVINT: - for (int i = 0; i < size; i++) - Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) / Proc.read_Ci(r[2] + i); - return; - case CONVINT: - for (int i = 0; i < size; i++) - { - Proc.temp.assign_ansp(Proc.read_Ci(r[1] + i)); - Proc.get_Cp_ref(r[0] + i) = Proc.temp.ansp; - } - return; case CONVMODP: if (n == 0) { @@ -914,9 +858,6 @@ inline void Instruction::execute(Processor& Proc) const throw Processor_Error(to_string(n) + "-bit conversion impossible; " "integer registers only have 64 bits"); return; -#define X(NAME, CODE) case NAME: CODE; return; - COMBI_INSTRUCTIONS -#undef X } int r[3] = {this->r[0], this->r[1], this->r[2]}; @@ -1275,6 +1216,13 @@ inline void Instruction::execute(Processor& Proc) const case GINV: Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); break; + 
case RANDOMS: + Procp.protocol.randoms_inst(Procp, *this); + return; + case INPUTMASKREG: + Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2])); + Proc.write_Cp(r[1], Proc.temp.rrp); + break; case INPUTMASK: Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); if (n == Proc.P.my_num()) @@ -1388,6 +1336,9 @@ inline void Instruction::execute(Processor& Proc) const case GSHRCI: Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n); break; + case SHRSI: + Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n; + break; case GBITDEC: for (int j = 0; j < size; j++) { @@ -1649,7 +1600,7 @@ inline void Instruction::execute(Processor& Proc) const break; case RUN_TAPE: Proc.DataF.skip( - Proc.machine.run_tape(r[0], n, r[1], -1, &Proc.DataF.DataFp, + Proc.machine.run_tapes(start, &Proc.DataF.DataFp, &Proc.share_thread.DataF)); break; case JOIN_TAPE: @@ -1788,8 +1739,41 @@ void Program::execute(Processor& Proc) const { unsigned int size = p.size(); Proc.PC=0; + + auto& Procp = Proc.Procp; + + // binary instructions + typedef typename sint::bit_type T; + auto& processor = Proc.Procb; + auto& Ci = Proc.get_Ci(); + while (Proc.PC #include @@ -44,9 +45,6 @@ class Machine : public BaseMachine // Keep record of used offline data DataPositions pos; - int tn,numt; - bool usage_unknown; - void load_program(string threadname, string filename); public: @@ -71,6 +69,7 @@ class Machine : public BaseMachine OnlineOptions opts; atomic data_sent; + ExecutionStats stats; Machine(int my_number, Names& playerNames, string progname, string memtype, int lg2, bool direct, int opening_sum, @@ -79,9 +78,12 @@ class Machine : public BaseMachine const Names& get_N() { return N; } - DataPositions run_tape(int thread_number, int tape_number, int arg, - int line_number, Preprocessing* prep = 0, - Preprocessing* bit_prep = 0); + DataPositions run_tapes(const vector &args, Preprocessing *prep, + Preprocessing *bit_prep); + void fill_buffers(int thread_number, int tape_number, + 
Preprocessing *prep, + Preprocessing *bit_prep); + DataPositions run_tape(int thread_number, int tape_number, int arg); DataPositions join_tape(int thread_number); void run(); diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index bd602e4f..90212825 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -9,6 +9,7 @@ #include "Math/Setup.h" #include "Tools/mkpath.h" +#include "Tools/Bundle.h" #include #include @@ -22,7 +23,7 @@ Machine::Machine(int my_number, Names& playerNames, string progname_str, string memtype, int lg2, bool direct, int opening_sum, bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep, OnlineOptions opts) - : my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false), + : my_number(my_number), N(playerNames), direct(direct), opening_sum(opening_sum), receive_threads(receive_threads), max_broadcast(max_broadcast), use_encryption(use_encryption), live_prep(live_prep), opts(opts), @@ -40,9 +41,12 @@ Machine::Machine(int my_number, Names& playerNames, // make directory for outputs if necessary mkdir_p(PREP_DIR); - auto P = new PlainPlayer(N, 0xF00); - sint::LivePrep::basic_setup(*P); - delete P; + if (opts.live_prep) + { + auto P = new PlainPlayer(N, 0xF00); + sint::LivePrep::basic_setup(*P); + delete P; + } sint::read_or_generate_mac_key(prep_dir_prefix(), N, alphapi); sgf2n::read_or_generate_mac_key(prep_dir_prefix(), N, alpha2i); @@ -131,21 +135,29 @@ void Machine::load_program(string threadname, string filename) int i = progs.size() - 1; progs[i].parse(pinp); pinp.close(); - M2.minimum_size(GF2N, progs[i], threadname); - Mp.minimum_size(MODP, progs[i], threadname); - Mi.minimum_size(INT, progs[i], threadname); + M2.minimum_size(SGF2N, CGF2N, progs[i], threadname); + Mp.minimum_size(SINT, CINT, progs[i], threadname); + Mi.minimum_size(NONE, INT, progs[i], threadname); } template -DataPositions Machine::run_tape(int thread_number, int tape_number, - int arg, int line_number, 
Preprocessing* prep, +DataPositions Machine::run_tapes(const vector& args, + Preprocessing* prep, Preprocessing* bit_prep) +{ + assert(args.size() % 3 == 0); + for (unsigned i = 0; i < args.size(); i += 3) + fill_buffers(args[i], args[i + 1], prep, bit_prep); + DataPositions res(N.num_players()); + for (unsigned i = 0; i < args.size(); i += 3) + res.increase(run_tape(args[i], args[i + 1], args[i + 2])); + return res; +} + +template +void Machine::fill_buffers(int thread_number, int tape_number, + Preprocessing* prep, Preprocessing* bit_prep) { - if (size_t(thread_number) >= tinfo.size()) - throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size())); - if (size_t(tape_number) >= progs.size()) - throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size())); - // central preprocessing auto usage = progs[tape_number].get_offline_data_used(); if (sint::expensive and prep != 0 and OnlineOptions::singleton.bucket_size == 3) @@ -203,6 +215,16 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, #endif } } +} + +template +DataPositions Machine::run_tape(int thread_number, int tape_number, + int arg) +{ + if (size_t(thread_number) >= tinfo.size()) + throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size())); + if (size_t(tape_number) >= progs.size()) + throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size())); queues[thread_number]->schedule({tape_number, arg, pos}); //printf("Send signal to run program %d in thread %d\n",tape_number,thread_number); @@ -211,21 +233,10 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, { if (not opts.live_prep) { - // only one thread allowed - if (numt>1) - { cerr << "Line " << line_number << " has " << - numt << " threads but tape " << tape_number << + cerr << "Internally called tape " << tape_number << " has unknown 
offline data usage" << endl; - throw invalid_program(); - } - else if (line_number == -1) - { - cerr << "Internally called tape " << tape_number << - " has unknown offline data usage" << endl; - throw invalid_program(); - } + throw invalid_program(); } - usage_unknown = true; return DataPositions(N.num_players()); } else @@ -252,43 +263,12 @@ void Machine::run() proc_timer.start(); timer[0].start(); - bool flag=true; - usage_unknown=false; - int exec=0; - while (flag) - { inpf >> numt; - if (numt==0) - { flag=false; } - else - { for (int i=0; i> tn; - - // Cope with passing an integer parameter to a tape - int arg; - if (inpf.get() == ':') - inpf >> arg; - else - arg = 0; - - //cerr << "Run scheduled tape " << tn << " in thread " << i << endl; - pos.increase(run_tape(i, tn, arg, exec)); - } - // Make sure all terminate before we continue - auto new_pos = join_tape(0); - for (int i=1; i> _ >> _ >> _; + // run main tape + pos.increase(run_tape(0, 0, 0)); + join_tape(0); print_compiler(); @@ -317,6 +297,16 @@ void Machine::run() finish_timer.stop(); #ifdef VERBOSE + cerr << "Memory usage: "; + tinfo[0].print_usage(cerr, Mp.MS, "sint"); + tinfo[0].print_usage(cerr, Mp.MC, "cint"); + tinfo[0].print_usage(cerr, M2.MS, "sgf2n"); + tinfo[0].print_usage(cerr, M2.MS, "cgf2n"); + tinfo[0].print_usage(cerr, bit_memories.MS, "sbits"); + tinfo[0].print_usage(cerr, bit_memories.MC, "cbits"); + tinfo[0].print_usage(cerr, Mi.MC, "regint"); + cerr << endl; + for (unsigned int i = 0; i < join_timer.size(); i++) cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl; cerr << "Finish timer: " << finish_timer.elapsed() << endl; @@ -326,6 +316,15 @@ void Machine::run() print_timers(); cerr << "Data sent = " << data_sent / 1e6 << " MB" << endl; + PlainPlayer P(N, 0xFFF0); + Bundle bundle(P); + bundle.mine.store(data_sent.load()); + P.Broadcast_Receive_no_stats(bundle); + size_t global = 0; + for (auto& os : bundle) + global += os.get_int(8); + cerr << "Global data sent = " 
<< global / 1e6 << " MB" << endl; + #ifdef VERBOSE if (opening_sum < N.num_players() && !direct) cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl; @@ -374,6 +373,39 @@ void Machine::run() pos.print_cost(); #endif + if (not stats.empty()) + { + cerr << "Instruction statistics:" << endl; + set> sorted_stats; + for (auto& x : stats) + { + sorted_stats.insert({x.second, x.first}); + } + for (auto& x : sorted_stats) + { + auto opcode = x.second; + auto calls = x.first; + cerr << "\t"; + int n_fill = 15; + switch (opcode) + { +#define X(NAME, PRE, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; + ARITHMETIC_INSTRUCTIONS +#undef X +#define X(NAME, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; + COMBI_INSTRUCTIONS +#undef X + default: + cerr << hex << setw(5) << showbase << left << opcode; + n_fill -= 5; + cerr << setw(0); + } + for (int i = 0; i < n_fill; i++) + cerr << " "; + cerr << dec << calls << endl; + } + } + #ifndef INSECURE Data_Files df(*this); df.seekg(pos); diff --git a/Processor/Memory.h b/Processor/Memory.h index 4146ef78..5cb3ed6f 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -26,7 +26,7 @@ class Memory public: CheckVector MS; - vector MC; + CheckVector MC; void resize_s(int sz) { MS.resize(sz); } @@ -73,7 +73,8 @@ class Memory { (void)start, (void)end; cerr << "Memory protection not activated" << endl; } #endif - void minimum_size(RegType reg_type, const Program& program, string threadname); + void minimum_size(RegType secret_type, RegType clear_type, + const Program& program, string threadname); friend ostream& operator<< <>(ostream& s,const Memory& M); friend istream& operator>> <>(istream& s,Memory& M); diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 9686b83c..bee3b40d 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -4,10 +4,13 @@ #include template -void Memory::minimum_size(RegType reg_type, const Program& program, string 
threadname) +void Memory::minimum_size(RegType secret_type, RegType clear_type, + const Program &program, string threadname) { (void) threadname; - const unsigned* sizes = program.direct_mem(reg_type); + unsigned sizes[MAX_SECRECY_TYPE]; + sizes[SECRET]= program.direct_mem(secret_type); + sizes[CLEAR] = program.direct_mem(clear_type); if (sizes[SECRET] > size_s()) { #ifdef DEBUG_MEMORY diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index 5836e2bf..7fe6c7cd 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -28,6 +28,9 @@ class thread_info static void purge_preprocessing(Machine& machine); + template + static void print_usage(ostream& o, const vector& regs, string name); + void Sub_Main_Func(); }; diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 6f61dabb..bf4e8074 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -20,6 +20,15 @@ using namespace std; +template +template +void thread_info::print_usage(ostream &o, + const vector& regs, string name) +{ + if (regs.capacity()) + o << name << "=" << regs.capacity() << " "; +} + template void thread_info::Sub_Main_Func() { @@ -105,7 +114,7 @@ void thread_info::Sub_Main_Func() while (flag) { // Wait until I have a program to run wait_timer.start(); - auto job = queues->next(); + ThreadJob job = queues->next(); program = job.prognum; wait_timer.stop(); #ifdef DEBUG_THREADS @@ -279,6 +288,16 @@ void thread_info::Sub_Main_Func() cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl; cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl; + + cerr << "Register usage: "; + print_usage(cerr, Proc.Procp.get_S(), "sint"); + print_usage(cerr, Proc.Procp.get_C(), "cint"); + print_usage(cerr, Proc.Proc2.get_S(), "sgf2n"); + print_usage(cerr, Proc.Proc2.get_C(), "cgf2n"); + print_usage(cerr, Proc.Procb.S, "sbits"); + print_usage(cerr, Proc.Procb.C, "cbits"); + print_usage(cerr, Proc.get_Ci(), 
"regint"); + cerr << endl; #endif // wind down thread by thread @@ -287,6 +306,7 @@ void thread_info::Sub_Main_Func() prep_sent += Proc.Procp.bit_prep.data_sent(); for (auto& x : Proc.Procp.personal_bit_preps) prep_sent += x->data_sent(); + machine.stats += Proc.stats; delete processor; machine.data_sent += P.sent + prep_sent; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 410ea67c..29cc9f0a 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -163,6 +163,7 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, opt.parse(argc, argv); vector allArgs(opt.firstArgs); + allArgs.insert(allArgs.end(), opt.unknownArgs.begin(), opt.unknownArgs.end()); allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end()); string usage; vector badOptions; diff --git a/Processor/Processor.h b/Processor/Processor.h index ae276f76..ec8f25f3 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -26,7 +26,7 @@ class Program; template class SubProcessor { - vector C; + CheckVector C; CheckVector S; DataPositions bit_usage; @@ -70,11 +70,16 @@ public: int b); void conv2ds(const Instruction& instruction); - vector& get_S() + CheckVector& get_S() { return S; } + CheckVector& get_C() + { + return C; + } + T& get_S_ref(int i) { return S[i]; @@ -132,7 +137,7 @@ public: template class Processor : public ArithmeticProcessor { - int reg_max2,reg_maxp,reg_maxi; + int reg_max2, reg_maxi; // Data structure used for reading/writing data to/from a socket (i.e. 
an external party to SPDZ) octetStream socket_stream; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 47d803d3..750c2760 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -117,11 +117,11 @@ string Processor::get_filename(const char* prefix, bool use_number) template void Processor::reset(const Program& program,int arg) { - reg_max2 = program.num_reg(GF2N); - reg_maxp = program.num_reg(MODP); reg_maxi = program.num_reg(INT); - Proc2.resize(reg_max2); - Procp.resize(reg_maxp); + Proc2.get_S().resize(program.num_reg(SGF2N)); + Proc2.get_C().resize(program.num_reg(CGF2N)); + Procp.get_S().resize(program.num_reg(SINT)); + Procp.get_C().resize(program.num_reg(CINT)); Ci.resize(reg_maxi); this->arg = arg; Procb.reset(program); diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h index cf0a93c0..fe591c2e 100644 --- a/Processor/ProcessorBase.h +++ b/Processor/ProcessorBase.h @@ -11,6 +11,8 @@ #include using namespace std; +#include "Tools/ExecutionStats.h" + class ProcessorBase { // Stack @@ -24,6 +26,8 @@ protected: int arg; public: + ExecutionStats stats; + void pushi(long x) { stacki.push(x); } void popi(long& x) { x = stacki.top(); stacki.pop(); } diff --git a/Processor/Program.cpp b/Processor/Program.cpp index 34075eb0..f0eef143 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -10,8 +10,7 @@ void Program::compute_constants() for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) { max_reg[reg_type] = 0; - for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) - max_mem[reg_type][sec_type] = 0; + max_mem[reg_type] = 0; } for (unsigned int i=0; i> n) \ + X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \ + auto c = &Procp.get_S()[r[2]], \ + Procp.DataF.get_three(DATA_TRIPLE, *a++, *b++, *c++)) \ + X(BIT, auto dest = &Procp.get_S()[r[0]], \ + Procp.DataF.get_one(DATA_BIT, *dest++)) \ + X(LDINT, auto dest = &Proc.get_Ci()[r[0]], \ + *dest++ = int(n)) \ + X(ADDINT, 
auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \ + auto op2 = &Proc.get_Ci()[r[2]], \ + *dest++ = *op1++ + *op2++) \ + X(SUBINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \ + auto op2 = &Proc.get_Ci()[r[2]], \ + *dest++ = *op1++ - *op2++) \ + X(MULINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \ + auto op2 = &Proc.get_Ci()[r[2]], \ + *dest++ = *op1++ * *op2++) \ + X(DIVINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \ + auto op2 = &Proc.get_Ci()[r[2]], \ + *dest++ = *op1++ / *op2++) \ + X(INCINT, auto dest = &Proc.get_Ci()[r[0]]; auto base = Proc.get_Ci()[r[1]], \ + int inc = (i / start[0]) % start[1]; *dest++ = base + inc * int(n)) \ + X(CONVINT, auto dest = &Procp.get_C()[r[0]]; auto source = &Proc.get_Ci()[r[1]], \ + *dest++ = *source++) \ + + +#endif /* PROCESSOR_INSTRUCTIONS_H_ */ diff --git a/Programs/Circuits b/Programs/Circuits index 82dfda9d..90845282 160000 --- a/Programs/Circuits +++ b/Programs/Circuits @@ -1 +1 @@ -Subproject commit 82dfda9d12b6fd2865f21f02809f6e7a5323d0be +Subproject commit 908452826cbd67f9757850a4943871bf66574e48 diff --git a/Programs/Source/benchmark_net.mpc b/Programs/Source/benchmark_net.mpc index cf3931e7..ff9c94ff 100644 --- a/Programs/Source/benchmark_net.mpc +++ b/Programs/Source/benchmark_net.mpc @@ -2,11 +2,7 @@ import ml import util import math -if 'trunc_pr' in program.args: - program.use_trunc_pr = True -if 'split' in program.args: - program.use_split(3) - +program.options_from_args() program.options.cisc = True try: @@ -72,7 +68,7 @@ else: opt = ml.Optimizer() opt.layers = layers for layer in layers: - layer.input_from(0, raw='raw' in program.args) + layer.input_from(0) layers[0].X.input_from(1) start_timer(1) opt.forward(1) diff --git a/Programs/Source/logreg.mpc b/Programs/Source/logreg.mpc index e7cb4262..a73153ef 100644 --- a/Programs/Source/logreg.mpc +++ b/Programs/Source/logreg.mpc @@ -3,13 +3,7 @@ from Compiler import ml 
debug = False program.use_edabit(True) -program.use_trunc_pr = True - -if 'split' in program.args: - program.use_split(3) - -if 'split2' in program.args: - program.use_split(2) +program.options_from_args() sfix.set_precision(16, 31) cfix.set_precision(16, 31) @@ -29,9 +23,19 @@ dense = ml.Dense(12800, dim, 1) layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)] sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False) -sgd.reset([X_normal, X_pos]) -sgd.run(batch_size=batch) -# @for_range(1000) -# def _(i): -# sgd.backward() +if not ('forward' in program.args or 'backward' in program.args): + sgd.reset([X_normal, X_pos]) + sgd.run(batch_size=batch) + +if 'forward' in program.args: + @for_range(1000) + def _(i): + sgd.forward(N=batch) + +if 'backward' in program.args: + b = regint.Array(batch) + b.assign(regint.inc(batch)) + @for_range(1000) + def _(i): + sgd.backward(batch=b) diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index 73df18cb..700f38ae 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -97,3 +97,6 @@ test(c[0], 0) test(c[1], 1) test(c[2], 1) test(c[3], 0) + +test(sbit(sbits(3)), 1) +test(sbits(sbit(1)), 1) diff --git a/Protocols/FakeInput.h b/Protocols/FakeInput.h new file mode 100644 index 00000000..3c232520 --- /dev/null +++ b/Protocols/FakeInput.h @@ -0,0 +1,51 @@ +/* + * FakeProtocol.h + * + */ + +#ifndef PROTOCOLS_FAKEINPUT_H_ +#define PROTOCOLS_FAKEINPUT_H_ + +#include "Replicated.h" +#include "Processor/Input.h" + +template +class FakeInput : public InputBase +{ + PointerVector results; + +public: + FakeInput(SubProcessor&, typename T::MAC_Check&) + { + } + + void reset(int) + { + results.clear(); + } + + void add_mine(const typename T::open_type& x, int = -1) + { + results.push_back(x); + } + + void add_other(int) + { + } + + void send_mine() + { + } + + T finalize_mine() + { + return results.next(); + } + + void finalize_other(int, T&, octetStream&, int 
= -1) + { + throw not_implemented(); + } +}; + +#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */ diff --git a/Protocols/FakeMC.h b/Protocols/FakeMC.h new file mode 100644 index 00000000..d16dda1c --- /dev/null +++ b/Protocols/FakeMC.h @@ -0,0 +1,31 @@ +/* + * FakeMC.h + * + */ + +#ifndef PROTOCOLS_FAKEMC_H_ +#define PROTOCOLS_FAKEMC_H_ + +#include "MAC_Check_Base.h" + +template +class FakeMC : public MAC_Check_Base +{ +public: + FakeMC(T, int = 0, int = 0) + { + } + + void exchange(const Player&) + { + for (auto& x : this->secrets) + this->values.push_back(x); + } + + FakeMC& get_part_MC() + { + return *this; + } +}; + +#endif /* PROTOCOLS_FAKEMC_H_ */ diff --git a/Protocols/FakePrep.h b/Protocols/FakePrep.h new file mode 100644 index 00000000..277aa9f2 --- /dev/null +++ b/Protocols/FakePrep.h @@ -0,0 +1,82 @@ +/* + * FakePrep.h + * + */ + +#ifndef PROTOCOLS_FAKEPREP_H_ +#define PROTOCOLS_FAKEPREP_H_ + +#include "ReplicatedPrep.h" + +template +class FakePrep : public BufferPrep +{ + SeededPRNG G; + +public: + FakePrep(SubProcessor*, DataPositions& usage) : + BufferPrep(usage) + { + } + + FakePrep(DataPositions& usage, GC::ShareThread&) : + BufferPrep(usage) + { + } + + FakePrep(DataPositions& usage, int = 0) : + BufferPrep(usage) + { + } + + void set_protocol(typename T::Protocol&) + { + } + + void buffer_triples() + { + for (int i = 0; i < 1000; i++) + { + auto a = G.get(); + auto b = G.get(); + this->triples.push_back({{a, b, a * b}}); + } + } + + void buffer_squares() + { + for (int i = 0; i < 1000; i++) + { + auto a = G.get(); + this->squares.push_back({{a, a * a}}); + } + } + + void buffer_inverses() + { + for (int i = 0; i < 1000; i++) + { + auto a = G.get(); + T aa; + aa.invert(a); + this->inverses.push_back({{a, aa}}); + } + } + + void buffer_bits() + { + for (int i = 0; i < 1000; i++) + { + this->bits.push_back(G.get_bit()); + } + } + + void get_dabit_no_count(T& a, typename T::bit_type& b) + { + auto bit = G.get_bit(); + a = bit; + b = bit; + } +}; + +#endif /* 
PROTOCOLS_FAKEPREP_H_ */ diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h new file mode 100644 index 00000000..006dd753 --- /dev/null +++ b/Protocols/FakeProtocol.h @@ -0,0 +1,63 @@ +/* + * FakeProtocol.h + * + */ + +#ifndef PROTOCOLS_FAKEPROTOCOL_H_ +#define PROTOCOLS_FAKEPROTOCOL_H_ + +#include "Replicated.h" + +template +class FakeProtocol : public ProtocolBase +{ + PointerVector results; + SeededPRNG G; + +public: + Player& P; + + FakeProtocol(Player& P) : P(P) + { + } + + void init_mul(SubProcessor*) + { + results.clear(); + } + + typename T::clear prepare_mul(const T& x, const T& y, int = -1) + { + results.push_back(x * y); + return {}; + } + + void exchange() + { + } + + T finalize_mul(int = -1) + { + return results.next(); + } + + void randoms(T& res, int n_bits) + { + res.randomize_part(G, n_bits); + } + + int get_n_relevant_players() + { + return 1; + } + + void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { + for (size_t i = 0; i < regs.size(); i += 4) + for (int l = 0; l < size; l++) + proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l) + >> regs[i + 3]; + } +}; + +#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */ diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h new file mode 100644 index 00000000..204b9389 --- /dev/null +++ b/Protocols/FakeShare.h @@ -0,0 +1,80 @@ +/* + * FakeShare.h + * + */ + +#ifndef PROTOCOLS_FAKESHARE_H_ +#define PROTOCOLS_FAKESHARE_H_ + +#include "GC/FakeSecret.h" +#include "ShareInterface.h" +#include "Processor/NoLivePrep.h" +#include "FakeMC.h" +#include "FakeProtocol.h" +#include "FakePrep.h" +#include "FakeInput.h" + +template +class FakeShare : public T, public ShareInterface +{ + typedef FakeShare This; + +public: + typedef T mac_key_type; + typedef T open_type; + typedef T clear; + + typedef FakePrep LivePrep; + typedef FakeMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef FakeInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef FakeProtocol Protocol; + + typedef 
GC::FakeSecret bit_type; + + static string type_short() + { + return "emul"; + } + + static int threshold(int) + { + return 0; + } + + static T constant(T value, int = 0, T = 0) + { + return value; + } + + FakeShare() + { + } + + template + FakeShare(U other) : + T(other) + { + } + + void assign(T value, int = 0, T = 0) + { + *this = value; + } + + void add(T a, T b, int = 0, T = {}) + { + *this = a + b; + } + + void sub(T a, T b, int = 0, T = {}) + { + *this = a - b; + } + + static void split(vector& dest, const vector& regs, int n_bits, + const This* source, int n_inputs, Player& P); +}; + +#endif /* PROTOCOLS_FAKESHARE_H_ */ diff --git a/Protocols/FakeShare.hpp b/Protocols/FakeShare.hpp new file mode 100644 index 00000000..29e49db3 --- /dev/null +++ b/Protocols/FakeShare.hpp @@ -0,0 +1,50 @@ +/* + * FakeShare.cpp + * + */ + +#include "FakeShare.h" +#include "Math/Z2k.h" +#include "GC/square64.h" + +template +void FakeShare::split(vector& dest, + const vector& regs, int n_bits, const This* source, int n_inputs, + Player&) +{ + assert(n_bits <= 64); + int unit = GC::Clear::N_BITS; + for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) + { + int start = k * unit; + int m = min(unit, n_inputs - start); + + switch (regs.size() / n_bits) + { + case 3: + { + for (int i = 0; i < n_bits; i++) + for (int j = 1; j < 3; j++) + dest.at(regs.at(3 * i + j) + k) = {}; + + square64 square; + + for (int j = 0; j < m; j++) + { + square.rows[j] = (source[j + start]).get_limb(0); + } + + square.transpose(m, n_bits); + + for (int j = 0; j < n_bits; j++) + { + auto& dest_reg = dest.at(regs.at(3 * j) + k); + dest_reg = square.rows[j]; + } + break; + } + default: + not_implemented(); + } + } +} diff --git a/Protocols/MalRepRingPrep.h b/Protocols/MalRepRingPrep.h index f29303de..7e30152b 100644 --- a/Protocols/MalRepRingPrep.h +++ b/Protocols/MalRepRingPrep.h @@ -27,10 +27,20 @@ public: void buffer_inputs(int player); }; +template +class RingOnlyBitsFromSquaresPrep : public virtual 
BufferPrep +{ +public: + RingOnlyBitsFromSquaresPrep(SubProcessor* proc, DataPositions& usage); + + void buffer_bits(); +}; + // extra class to avoid recursion template class MalRepRingPrepWithBits: public virtual MaliciousRingPrep, - public virtual MalRepRingPrep + public virtual MalRepRingPrep, + public virtual RingOnlyBitsFromSquaresPrep { public: MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage); @@ -40,12 +50,20 @@ public: MaliciousRingPrep::set_protocol(protocol); } + void buffer_triples() + { + MalRepRingPrep::buffer_triples(); + } + void buffer_squares() { MalRepRingPrep::buffer_squares(); } - void buffer_bits(); + void buffer_bits() + { + RingOnlyBitsFromSquaresPrep::buffer_bits(); + } void get_dabit_no_count(T& a, typename T::bit_type& b) { diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 75a81f4f..ca7f3f70 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -20,12 +20,20 @@ MalRepRingPrep::MalRepRingPrep(SubProcessor*, DataPositions& usage) : { } +template +RingOnlyBitsFromSquaresPrep::RingOnlyBitsFromSquaresPrep(SubProcessor*, + DataPositions& usage) : + BufferPrep(usage) +{ +} + template MalRepRingPrepWithBits::MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), - MaliciousRingPrep(proc, usage), MalRepRingPrep(proc, usage) + MaliciousRingPrep(proc, usage), MalRepRingPrep(proc, usage), + RingOnlyBitsFromSquaresPrep(proc, usage) { } @@ -204,14 +212,14 @@ void ShuffleSacrifice::triple_sacrifice(vector>& triples, } template -void MalRepRingPrepWithBits::buffer_bits() +void RingOnlyBitsFromSquaresPrep::buffer_bits() { auto proc = this->proc; assert(proc != 0); - typedef MalRepRingShare BitShare; + typedef typename T::SquareToBitShare BitShare; typename BitShare::MAC_Check MC; DataPositions usage; - MalRepRingPrep prep(0, usage); + typename BitShare::SquarePrep prep(0, usage); SubProcessor bit_proc(MC, prep, 
proc->P); prep.set_proc(&bit_proc); bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep); diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index f67ae04e..63bfe63a 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -10,6 +10,7 @@ #include "Protocols/MaliciousRep3Share.h" template class MalRepRingPrepWithBits; +template class MalRepRingPrep; template class MalRepRingShare : public MaliciousRep3Share> @@ -29,6 +30,8 @@ public: typedef MalRepRingPrepWithBits LivePrep; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; + typedef MalRepRingShare SquareToBitShare; + typedef MalRepRingPrep SquarePrep; static string type_short() { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 27085f0c..def702d1 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -11,6 +11,7 @@ template class HashMaliciousRepMC; template class Beaver; template class MaliciousRepPrepWithBits; +template class MaliciousRepPrep; namespace GC { @@ -30,6 +31,7 @@ public: typedef ::PrivateOutput> PrivateOutput; typedef Rep3Share Honest; typedef MaliciousRepPrepWithBits LivePrep; + typedef MaliciousRepPrep TriplePrep; typedef MaliciousRep3Share prep_type; typedef T random_type; diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 2eb93969..6f5404db 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -11,6 +11,7 @@ #include "Protocols/MaliciousShamirMC.h" template class MaliciousRepPrepWithBits; +template class MaliciousRepPrep; namespace GC { @@ -30,6 +31,7 @@ public: typedef ::PrivateOutput PrivateOutput; typedef ShamirShare Honest; typedef MaliciousRepPrepWithBits LivePrep; + typedef MaliciousRepPrep TriplePrep; typedef T random_type; typedef GC::MaliciousCcdSecret bit_type; diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index 1aebfe99..7a56e693 100644 --- 
a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -10,7 +10,7 @@ template MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), OTPrep(proc, usage) + OTPrep(proc, usage) { this->params.amplify = true; this->params.generateMACs = true; diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index d7769917..990eaa33 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -11,7 +11,7 @@ #include "OT/MascotParams.h" template -class OTPrep : public virtual RingPrep +class OTPrep : public virtual BitPrep { public: typename T::TripleGenerator* triple_generator; @@ -27,30 +27,41 @@ public: NamedCommStats comm_stats(); }; +template +class MascotTriplePrep : public OTPrep, public RandomPrep +{ +public: + MascotTriplePrep(SubProcessor *proc, DataPositions &usage) : + BufferPrep(usage), BitPrep(proc, usage), + OTPrep(proc, usage) + { + } + + void buffer_triples(); + void buffer_inputs(int player); + + T get_random(); +}; + template class MascotPrep: public virtual MaliciousRingPrep, - public virtual OTPrep, - public RandomPrep + public virtual MascotTriplePrep { public: MascotPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), MaliciousRingPrep(proc, usage), - OTPrep(proc, usage) + MascotTriplePrep(proc, usage) { } virtual ~MascotPrep() { } - void buffer_triples(); - void buffer_inputs(int player); void buffer_bits() { throw runtime_error("use subclass"); } virtual void buffer_dabits(ThreadQueues* queues); void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues); - - T get_random(); }; template @@ -64,7 +75,7 @@ public: BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), MaliciousRingPrep(proc, usage), - OTPrep(proc, usage), MascotPrep(proc, usage) + MascotTriplePrep(proc, usage), MascotPrep(proc, usage) { } }; diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 
bef1401f..7cf95621 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -19,7 +19,7 @@ template OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), triple_generator(0) + triple_generator(0) { } @@ -33,7 +33,7 @@ OTPrep::~OTPrep() template void OTPrep::set_protocol(typename T::Protocol& protocol) { - RingPrep::set_protocol(protocol); + BitPrep::set_protocol(protocol); SubProcessor* proc = this->proc; assert(proc != 0); triple_generator = new typename T::TripleGenerator( @@ -45,7 +45,7 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) } template -void MascotPrep::buffer_triples() +void MascotTriplePrep::buffer_triples() { #ifdef INSECURE #ifdef FAKE_MASCOT_TRIPLES @@ -100,7 +100,7 @@ void MascotPrep::buffer_dabits(ThreadQueues* queues) } template -void MascotPrep::buffer_inputs(int player) +void MascotTriplePrep::buffer_inputs(int player) { auto& triple_generator = this->triple_generator; assert(triple_generator); @@ -112,7 +112,7 @@ void MascotPrep::buffer_inputs(int player) } template -T MascotPrep::get_random() +T MascotTriplePrep::get_random() { assert(this->proc); return BufferPrep::get_random_from_inputs(this->proc->P.num_players()); @@ -135,7 +135,7 @@ T BufferPrep::get_random_from_inputs(int nplayers) template size_t OTPrep::data_sent() { - size_t res = RingPrep::data_sent(); + size_t res = BitPrep::data_sent(); if (triple_generator) res += triple_generator->data_sent(); return res; diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 85adf8b8..9a90f772 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -25,6 +25,7 @@ public: typedef SignedZ2 clear; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; + typedef MalRepRingShare SquareToBitShare; typedef PostSacrifice Protocol; typedef HashMaliciousRepMC MAC_Check; @@ -53,6 +54,16 @@ public: PostSacriRepRingShare(const U& 
other) : super(other) { } + + template + static void split(vector& dest, const vector& regs, + int n_bits, const super* source, int n_inputs, Player& P) + { + if (regs.size() / n_bits != 3) + throw runtime_error("only secure with three-way split"); + + super::split(dest, regs, n_bits, source, n_inputs, P); + } }; #endif /* PROTOCOLS_POSTSACRIREPRINGSHARE_H_ */ diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index 3e1ec432..0ce53c7d 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -33,6 +33,8 @@ public: void check(); int get_n_relevant_players() { return internal.get_n_relevant_players(); } + + virtual void randoms(T& res, int n_bits) { randomizer.randoms(res, n_bits); } }; #endif /* PROTOCOLS_POSTSACRIFICE_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 82d1e6f3..440ec19e 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -13,6 +13,7 @@ #include "ShareInterface.h" template class ReplicatedPrep; +template class ReplicatedRingPrep; template class PrivateOutput; template @@ -30,6 +31,7 @@ public: typedef ReplicatedInput Input; typedef ::PrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; + typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; typedef GC::SemiHonestRepSecret bit_type; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 4f0760eb..d3545eb7 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -14,9 +14,9 @@ template class ReplicatedPrep2k; template -class Rep3Share2 : public Rep3Share> +class Rep3Share2 : public Rep3Share> { - typedef SignedZ2 T; + typedef Z2 T; public: typedef Replicated Protocol; @@ -26,6 +26,7 @@ public: typedef ::PrivateOutput PrivateOutput; typedef ReplicatedPrep2k LivePrep; typedef Rep3Share2 Honest; + typedef SignedZ2 clear; typedef GC::SemiHonestRepSecret bit_type; diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index c059dbba..5da6c9e9 100644 --- a/Protocols/Replicated.h +++ 
b/Protocols/Replicated.h @@ -24,6 +24,7 @@ template class Share; template class Rep3Share; template class MAC_Check_Base; template class Preprocessing; +class Instruction; class ReplicatedBase { @@ -71,6 +72,9 @@ public: virtual void trunc_pr(const vector& regs, int size, SubProcessor& proc) { (void) regs, (void) size; (void) proc; throw runtime_error("trunc_pr not implemented"); } + virtual void randoms(T&, int) { throw runtime_error("randoms not implemented"); } + virtual void randoms_inst(SubProcessor&, const Instruction&); + virtual void start_exchange() { exchange(); } virtual void stop_exchange() {} @@ -119,6 +123,7 @@ public: void trunc_pr(const vector& regs, int size, SubProcessor& proc); T get_random(); + void randoms(T& res, int n_bits); void start_exchange(); void stop_exchange(); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 945745e8..7262367b 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -66,7 +66,7 @@ ProtocolBase::~ProtocolBase() { #ifdef VERBOSE if (counter) - cerr << "Number of multiplications: " << counter << endl; + cerr << "Number of " << T::type_string() << " multiplications: " << counter << endl; #endif } @@ -222,6 +222,24 @@ T Replicated::get_random() return res; } +template +void ProtocolBase::randoms_inst(SubProcessor& proc, + const Instruction& instruction) +{ + for (int j = 0; j < instruction.get_size(); j++) + { + auto& res = proc.get_S_ref(instruction.get_r(0) + j); + randoms(res, instruction.get_n()); + } +} + +template +void Replicated::randoms(T& res, int n_bits) +{ + for (int i = 0; i < 2; i++) + res[i].randomize_part(shared_prngs[i], n_bits); +} + template void trunc_pr(const vector& regs, int size, SubProcessor>& proc) diff --git a/Protocols/ReplicatedFieldMachine.hpp b/Protocols/ReplicatedFieldMachine.hpp new file mode 100644 index 00000000..c54e3f85 --- /dev/null +++ b/Protocols/ReplicatedFieldMachine.hpp @@ -0,0 +1,39 @@ +/* + * ReplicatedFieldMachine.hpp + * + */ + +#ifndef 
PROTOCOLS_REPLICATEDFIELDMACHINE_HPP_ +#define PROTOCOLS_REPLICATEDFIELDMACHINE_HPP_ + +#include "ReplicatedMachine.hpp" + +template class T> +ReplicatedFieldMachine::ReplicatedFieldMachine(int argc, + const char **argv) +{ + ez::ezOptionParser opt; + OnlineOptions online_opts(opt, argc, argv, 0, true, true); + int n_limbs = DIV_CEIL(online_opts.lgp, 64); + switch (n_limbs) + { +#undef X +#define X(L) \ + case L: \ + ReplicatedMachine>, T>(argc, argv, opt, online_opts); \ + break; +#ifdef MORE_PRIMES + X(1) X(2) X(3) +#endif +#if GFP_MOD_SZ > 3 or not defined(MORE_PRIMES) + X(GFP_MOD_SZ) +#endif +#undef X + default: + cerr << "Not compiled for " << online_opts.lgp << "-bit primes" << endl; + cerr << "Compile with -DGFP_MOD_SZ=" << n_limbs << endl; + exit(1); + } +} + +#endif /* PROTOCOLS_REPLICATEDFIELDMACHINE_HPP_ */ diff --git a/Protocols/ReplicatedMachine.h b/Protocols/ReplicatedMachine.h index c5189e1c..2da565fe 100644 --- a/Protocols/ReplicatedMachine.h +++ b/Protocols/ReplicatedMachine.h @@ -13,6 +13,8 @@ template class ReplicatedMachine { public: + ReplicatedMachine(int argc, const char **argv, ez::ezOptionParser &opt, + OnlineOptions &online_opts, int n_players = 3); ReplicatedMachine(int argc, const char** argv, string name, ez::ezOptionParser& opt, int nplayers = 3); ReplicatedMachine(int argc, const char** argv, ez::ezOptionParser& opt, @@ -22,4 +24,11 @@ public: } }; +template class U> +class ReplicatedFieldMachine +{ +public: + ReplicatedFieldMachine(int argc, const char** argv); +}; + #endif /* PROTOCOLS_REPLICATEDMACHINE_H_ */ diff --git a/Protocols/ReplicatedMachine.hpp b/Protocols/ReplicatedMachine.hpp index 595a449d..58e00cdd 100644 --- a/Protocols/ReplicatedMachine.hpp +++ b/Protocols/ReplicatedMachine.hpp @@ -18,6 +18,14 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, (void) name; OnlineOptions online_opts(opt, argc, argv, 10000, true, T::clear::invertible); + ReplicatedMachine(argc, argv, opt, online_opts, nplayers); +} + 
+template +ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, + ez::ezOptionParser& opt, OnlineOptions& online_opts, + int nplayers) +{ OnlineOptions::singleton = online_opts; NetworkOptionsWithNumber network_opts(opt, argc, argv, nplayers, false); opt.add( diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index eb9b0fb2..a60f00a6 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -9,9 +9,10 @@ template SemiPrep::SemiPrep(SubProcessor* proc, DataPositions& usage) : - BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), + BufferPrep(usage), + BitPrep(proc, usage), OTPrep(proc, usage), + RingPrep(proc, usage), SemiHonestRingPrep(proc, usage) { this->params.set_passive(); diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 179d8c21..23769b13 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -14,8 +14,9 @@ class SemiPrep2k : public SemiPrep public: SemiPrep2k(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), + OTPrep(proc, usage), RingPrep(proc, usage), - OTPrep(proc, usage), SemiHonestRingPrep(proc, usage), + SemiHonestRingPrep(proc, usage), SemiPrep(proc, usage) { } diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 50321da2..1e1a8a1e 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -46,6 +46,7 @@ public: typedef ::PrivateOutput PrivateOutput; typedef SPDZ Protocol; typedef SemiPrep LivePrep; + typedef LivePrep TriplePrep; typedef SemiShare prep_type; typedef SemiMultiplier Multiplier; diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index cfe936f2..43e2b5e9 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -32,6 +32,8 @@ class Shamir : public ProtocolBase vector random; + typename T::open_type dotprod_share; + void buffer_random(); int threshold; @@ -77,6 +79,11 @@ public: T finalize(int n_input_players); + void init_dotprod(SubProcessor* proc); + void prepare_dotprod(const T& x, const T& y); + void 
next_dotprod(); + T finalize_dotprod(int length); + T get_random(); }; diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d1ede2a8..ccdd8356 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -151,6 +151,33 @@ T Shamir::finalize(int n_relevant_players) return res; } +template +void Shamir::init_dotprod(SubProcessor* proc) +{ + init_mul(proc); + dotprod_share = 0; +} + +template +void Shamir::prepare_dotprod(const T& x, const T& y) +{ + dotprod_share += x * y * rec_factor; +} + +template +void Shamir::next_dotprod() +{ + if (P.my_num() < n_mul_players) + resharing->add_mine(dotprod_share); + dotprod_share = 0; +} + +template +T Shamir::finalize_dotprod(int) +{ + return finalize_mul(); +} + template T Shamir::get_random() { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index c196a366..6491ce33 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -12,6 +12,7 @@ #include "ShareInterface.h" template class ReplicatedPrep; +template class ReplicatedRingPrep; namespace GC { @@ -34,6 +35,7 @@ public: typedef ShamirInput Input; typedef ::PrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; + typedef ReplicatedRingPrep TriplePrep; typedef ShamirShare Honest; typedef GC::CcdSecret bit_type; diff --git a/Protocols/Share.h b/Protocols/Share.h index c329b8c4..8b3e0097 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -23,6 +23,7 @@ template class MascotMultiplier; template class MascotFieldPrep; template class MascotTripleGenerator; template class MascotPrep; +template class MascotTriplePrep; union square128; @@ -127,7 +128,7 @@ class Share_ : public ShareInterface Share_ operator<<(int i) { return this->operator*(T(1) << i); } Share_& operator<<=(int i) { return *this = *this << i; } - Share_ operator>>(int i) { return {a >> i, mac >> i}; } + Share_ operator>>(int i) const { return {a >> i, mac >> i}; } void force_to_bit() { a.force_to_bit(); } @@ -173,6 +174,7 @@ public: typedef SPDZ Protocol; typedef 
MascotFieldPrep LivePrep; typedef MascotPrep RandomPrep; + typedef MascotTriplePrep TriplePrep; static const bool expensive = true; diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 478da7a8..a0b95475 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -20,6 +20,13 @@ class ShareInterface { public: typedef GC::NoShare part_type; + typedef GC::NoShare bit_type; + + static const bool needs_ot = false; + static const bool expensive = false; + static const bool expensive_triples = false; + + static string type_short() { return "undef"; } template static void split(vector, vector, int, T*, int, Player&) @@ -29,6 +36,9 @@ public: template static void read_or_generate_mac_key(const string&, const Names&, T&) {} + + template + static void generate_mac_key(T&, U&) {} }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index e0c23a93..71dbff23 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -18,7 +18,7 @@ class Spdz2kPrep : public virtual MascotPrep, public virtual RingOnlyPrep { typedef Spdz2kShare BitShare; DataPositions bit_pos; - MascotPrep* bit_prep; + MascotTriplePrep* bit_prep; SubProcessor* bit_proc; typename BitShare::MAC_Check* bit_MC; Sub_Data_Files* bit_DataF; diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index 1482dbfd..d46a4f52 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -16,7 +16,8 @@ Spdz2kPrep::Spdz2kPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), MaliciousRingPrep(proc, usage), - OTPrep(proc, usage), MascotPrep(proc, usage), + MascotTriplePrep(proc, usage), + MascotPrep(proc, usage), RingOnlyPrep(proc, usage) { this->params.amplify = false; @@ -51,7 +52,7 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) bit_pos = DataPositions(proc->P.num_players()); bit_DataF = new Sub_Data_Files(0, 0, "", bit_pos, 0); bit_proc 
= new SubProcessor(*bit_MC, *bit_DataF, proc->P); - bit_prep = new MascotPrep(bit_proc, bit_pos); + bit_prep = new MascotTriplePrep(bit_proc, bit_pos); bit_prep->params.amplify = false; bit_protocol = new typename BitShare::Protocol(proc->P); bit_prep->set_protocol(*bit_protocol); @@ -72,7 +73,8 @@ void MaliciousRingPrep::buffer_bits() // one of the two is not a zero divisor, so if the product is zero, one of them is too protocol.prepare_mul(one - bit, bit); protocol.exchange(); - vector checks(this->bits.size()); + vector checks; + checks.reserve(this->bits.size()); for (size_t i = 0; i < this->bits.size(); i++) checks.push_back(protocol.finalize_mul()); this->proc->MC.CheckFor(0, checks, protocol.P); diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 0dc8ca79..dc87450a 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -308,6 +308,13 @@ void generate_mac_keys(typename T::mac_type::Scalar& key, cout << "Final Key: " << key << endl; } +inline void check_files(ofstream* outf, int N) +{ + for (int i = 0; i < N; i++) + if (outf[i].fail()) + throw runtime_error("couldn't write to file"); +} + /* N = Number players * ntrip = Number triples needed * str = "2" or "p" @@ -347,6 +354,7 @@ void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, Sc[j].output(outf[j],false); } } + check_files(outf, N); for (int i=0; i ``` @@ -326,6 +332,9 @@ lines. If you run with any other protocol, you will need to remove CrypTFlow repository that includes the patch in https://github.com/mkskeller/EzPC/commit/2021be90d21dc26894be98f33cd10dd26769f479. +[The reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.ml) +contains further documentation on available layers. + ## Dishonest majority Some full implementations require oblivious transfer, which is @@ -373,7 +382,7 @@ al.](https://eprint.iacr.org/2016/944) CowGear denotes a covertly secure version of LowGear. 
The reason for this is the key generation that only achieves covert security. It is -possible however to run full LowGear for triple generation by using +possible however to run full LowGear for the offline phase by using `-s` with the desired security parameter. The same holds for ChaiGear, an adapted version of HighGear. Option `-T` activates [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs in @@ -425,7 +434,11 @@ argument to change that. ### Yao's garbled circuits -We use the implementation optimized for AES-NI by [Bellare et al.](https://eprint.iacr.org/2013/426) +We use half-gate garbling as described by [Zahur et +al.](https://eprint.iacr.org/2014/756.pdf). Alternatively, you can +activate the implementation optimized by [Bellare et +al.](https://eprint.iacr.org/2013/426) by adding `MY_CFLAGS += +-DFULL_GATES` to `CONFIG.mine`. Compile the virtual machine: @@ -555,6 +568,7 @@ lists the available schemes. | Program | Protocol | Dishonest Maj. | Malicious | \# parties | Script | | --- | --- | --- | --- | --- | --- | | `real-bmr-party.x` | MASCOT | Y | Y | 2 or more | `real-bmr.sh` | +| `semi-bmr-party.x` | Semi | Y | Y | 2 or more | `semi-bmr.sh` | | `shamir-bmr-party.x` | Shamir | N | N | 3 or more | `shamir-bmr.sh` | | `mal-shamir-bmr-party.x` | Shamir | N | Y | 3 or more | `mal-shamir-bmr.sh` | | `rep-bmr-party.x` | Replicated | N | N | 3 | `rep-bmr.sh` | @@ -657,6 +671,13 @@ e.g. if this machine is name `diffie` on the local network: The software uses TCP ports around 5000 by default, use the `-pn` argument to change that.
+### SPDZ2k + +Creating fake offline data for SPDZ2k requires to call +`Fake-Offline.x` directly instead of via `setup-online.sh`: + +`./Fake-Offline.x -Z -S ` + ### Honest-majority three-party computation of binary circuits with malicious security Compile the virtual machines: diff --git a/Scripts/fixed-rep-to-raw.py b/Scripts/fixed-rep-to-raw.py new file mode 100755 index 00000000..d3359970 --- /dev/null +++ b/Scripts/fixed-rep-to-raw.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +import sys, operator, struct + +try: + filename = sys.argv[3] +except: + filename = 'Player-Data/Private-Input-0' + +out = open(filename, 'bw') + +for line in open(sys.argv[1]): + line = (line.strip()) + if line: + x = (line.split(' ')) + for xx in x: + out.write(struct.pack('&2 echo Running $prefix $SPDZROOT/$bin $last_player $params $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 + wait } sleep 0.5 diff --git a/Scripts/semi-bmr.sh b/Scripts/semi-bmr.sh new file mode 100755 index 00000000..d86c8d46 --- /dev/null +++ b/Scripts/semi-bmr.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. 
$HERE/run-common.sh + +run_player semi-bmr-party.x $* || exit 1 diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index c6315612..52376b68 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +echo MY_CFLAGS += -DINSECURE >> CONFIG.mine +touch ECDSA/Fake-ECDSA.cpp + make -j4 ecdsa Fake-ECDSA.x secure.x run() diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 1476863f..0c799d7b 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -62,11 +62,21 @@ done ./compile.py tutorial -test_vm cowgear -T -test_vm chaigear -T -l 3 -c 2 +test_vm cowgear $run_opts -T +test_vm chaigear $run_opts -T -l 3 -c 2 + +if test $skip_binary; then + exit +fi ./compile.py -B 16 $compile_opts tutorial -for i in replicated mal-rep-bin semi-bin ccd mal-ccd yao tinier rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do - test_vm $i +for i in replicated mal-rep-bin semi-bin ccd mal-ccd; do + test_vm $i $run_opts +done + +test_vm yao + +for i in tinier rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do + test_vm $i $run_opts done diff --git a/Tools/BitVector.h b/Tools/BitVector.h index d92ddec5..a289373d 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -189,8 +189,10 @@ class BitVector void set(const FixedVec& a); bool get_bit(int i) const { +#ifdef CHECK_SIZE if (i >= (int)nbits) throw out_of_range("BitVector access: " + to_string(i) + "/" + to_string(nbits)); +#endif return (bytes[i/8] >> (i % 8)) & 1; } void set_bit(int i,unsigned int a) diff --git a/Tools/ExecutionStats.h b/Tools/ExecutionStats.h new file mode 100644 index 00000000..bd9748a5 --- /dev/null +++ b/Tools/ExecutionStats.h @@ -0,0 +1,23 @@ +/* + * ExecutionsStats.h + * + */ + +#ifndef TOOLS_EXECUTIONSTATS_H_ +#define TOOLS_EXECUTIONSTATS_H_ + +#include +using namespace std; + +class ExecutionStats : public map +{ +public: + ExecutionStats& operator+=(const ExecutionStats& other) + { + for (auto it : other) + (*this)[it.first] += 
it.second; + return *this; + } +}; + +#endif /* TOOLS_EXECUTIONSTATS_H_ */ diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index ab83f7fa..f96a294d 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -73,6 +73,7 @@ void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero) Sc[j].output(outf[j],false); } } + check_files(outf, N); for (int i=0; i>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); + make_basic>>({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); make_bits({}, nplayers, default_num, zero); + gf2n_short::reset(); + gf2n_short::init_field(40); + Z2<41> keyt; generate_mac_keys>(keyt, nplayers, prep_data_prefix); diff --git a/Yao/YaoEvalInput.h b/Yao/YaoEvalInput.h new file mode 100644 index 00000000..1a8c757b --- /dev/null +++ b/Yao/YaoEvalInput.h @@ -0,0 +1,33 @@ +/* + * YaoEvalInput.h + * + */ + +#ifndef YAO_YAOEVALINPUT_H_ +#define YAO_YAOEVALINPUT_H_ + +#include "YaoEvaluator.h" + +class YaoEvalInput +{ +public: + YaoEvaluator& evaluator; + BitVector inputs; + int i_bit; + octetStream os; + + YaoEvalInput() : + evaluator(YaoEvaluator::s()) + { + inputs.resize(0); + i_bit = 0; + } + + void exchange() + { + evaluator.ot_ext.extend_correlated(inputs.size(), inputs); + evaluator.player.receive(os); + } +}; + +#endif /* YAO_YAOEVALINPUT_H_ */ diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 9d5af9d7..a867afdb 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -7,6 +7,7 @@ #include "YaoEvalWire.h" #include "YaoGate.h" #include "YaoEvaluator.h" +#include "YaoEvalInput.h" #include "BMR/prf.h" #include "BMR/common.h" #include "GC/ArgTuples.h" @@ -14,6 +15,7 @@ #include "GC/Processor.hpp" #include "GC/Secret.hpp" #include "GC/Thread.hpp" +#include "GC/ShareSecret.hpp" #include "YaoCommon.hpp" ostream& YaoEvalWire::out = cout; @@ -133,58 +135,51 @@ void YaoEvalWire::and_(GC::Memory >& S, } } +template +void 
YaoEvalWire::my_input(T& inputter, bool value, int n_bits) +{ + assert(n_bits == 1); + auto& inputs = inputter.inputs; + size_t start = inputs.size(); + inputs.resize(start + 1); + inputs.set_bit(start, value); +} + +template +void YaoEvalWire::finalize_input(T& inputter, int from, int n_bits) +{ + assert(n_bits == 1); + + if (from == 1) + { + auto& i_bit = inputter.i_bit; + Key key; + inputter.os.unserialize(key); + set(key ^ inputter.evaluator.ot_ext.receiverOutputMatrix[i_bit], + inputter.inputs.get_bit(i_bit)); + i_bit++; + } + else + { + set(0); + } +} + void YaoEvalWire::inputb(GC::Processor >& processor, const vector& args) { - InputArgList a(args); - BitVector inputs; - inputs.resize(0); - auto& evaluator = YaoEvaluator::s(); - bool interactive = evaluator.n_interactive_inputs_from_me(a) > 0; + YaoEvalInput inputter; + processor.inputb(inputter, processor, args, inputter.evaluator.P->my_num()); + return; +} - for (auto x : a) - { - auto& dest = processor.S[x.dest]; - dest.resize_regs(x.n_bits); - if (x.from == 0) - { - for (auto& reg : dest.get_regs()) - { - reg.set(0); - } - } - else - { - long long input = processor.get_input(x.params, interactive); - size_t start = inputs.size(); - inputs.resize(start + x.n_bits); - for (int i = 0; i < x.n_bits; i++) - inputs.set_bit(start + i, (input >> i) & 1); - } - } - - if (interactive) - cout << "Thank you" << endl; - - evaluator.ot_ext.extend_correlated(inputs.size(), inputs); - octetStream os; - evaluator.player.receive(os); - int i_bit = 0; - - for (auto x : a) - { - if (x.from == 1) - { - for (auto& reg : processor.S[x.dest].get_regs()) - { - Key key; - os.unserialize(key); - reg.set(key ^ evaluator.ot_ext.receiverOutputMatrix[i_bit], - inputs.get_bit(i_bit)); - i_bit++; - } - } - } +void YaoEvalWire::inputbvec(GC::Processor >& processor, + ProcessorBase& input_processor, const vector& args) +{ + YaoEvalInput inputter; + processor.inputbvec(inputter, input_processor, args, + inputter.evaluator.P->my_num()); + 
return; } void YaoEvalWire::op(const YaoEvalWire& left, const YaoEvalWire& right, diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index d5787bc7..4de503ed 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -14,11 +14,15 @@ #include "YaoWire.h" class YaoEvaluator; +class YaoEvalInput; +class ProcessorBase; class YaoEvalWire : public YaoWire { public: typedef YaoEvaluator Party; + typedef YaoEvalInput Input; + typedef GC::Processor> Processor; static string name() { return "YaoEvalWire"; } @@ -51,6 +55,8 @@ public: static void inputb(GC::Processor>& processor, const vector& args); + static void inputbvec(Processor& processor, ProcessorBase& input_processor, + const vector& args); static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&); @@ -72,6 +78,11 @@ public: void public_input(bool value); void op(const YaoEvalWire& left, const YaoEvalWire& right, Function func); bool get_output(); + + template + void my_input(T&, bool value, int n_bits); + template + void finalize_input(T& inputter, int from, int n_bits); }; #endif /* YAO_YAOEVALWIRE_H_ */ diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 8ac34d78..4a8b24f8 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -17,6 +17,8 @@ class YaoEvaluator: public GC::Thread>, public YaoCommon { + typedef GC::Thread> super; + protected: static thread_local YaoEvaluator* singleton; @@ -56,6 +58,9 @@ public: int get_n_worker_threads() { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } + + size_t data_sent() + { return super::data_sent() + player.comm_stats.total_data(); } }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbleInput.h b/Yao/YaoGarbleInput.h new file mode 100644 index 00000000..8976e7bc --- /dev/null +++ b/Yao/YaoGarbleInput.h @@ -0,0 +1,29 @@ +/* + * YaoGarbleInput.h + * + */ + +#ifndef YAO_YAOGARBLEINPUT_H_ +#define YAO_YAOGARBLEINPUT_H_ + +#include "YaoGarbler.h" + +class YaoGarbleWire; + +class YaoGarbleInput +{ 
+public: + YaoGarbler& garbler; + + YaoGarbleInput() : + garbler(YaoGarbler::s()) + { + } + + void exchange() + { + garbler.receiver_input_keys.push_back({}); + } +}; + +#endif /* YAO_YAOGARBLEINPUT_H_ */ diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index eff25d5b..637ce4b4 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -6,16 +6,21 @@ #include "YaoGarbleWire.h" #include "YaoGate.h" #include "YaoGarbler.h" +#include "YaoGarbleInput.h" #include "GC/ArgTuples.h" #include "GC/Processor.hpp" #include "GC/Secret.hpp" #include "GC/Thread.hpp" +#include "GC/ShareSecret.hpp" #include "YaoCommon.hpp" void YaoGarbleWire::random() { - key_ = YaoGarbler::s().prng.get_bit(); + if (YaoGarbler::s().prng.get_bit()) + key_ = YaoGarbler::s().get_delta(); + else + key_ = 0; } void YaoGarbleWire::public_input(bool value) @@ -154,46 +159,17 @@ void YaoGarbleWire::and_(GC::Memory >& S, void YaoGarbleWire::inputb(GC::Processor>& processor, const vector& args) { - InputArgList a(args); - int n_evaluator_bits = 0; auto& garbler = YaoGarbler::s(); - bool interactive = garbler.n_interactive_inputs_from_me(a) > 0; - for (auto x : a) - { - auto& dest = processor.S[x.dest]; - dest.resize_regs(x.n_bits); - if (x.from == 0) - { - long long input = processor.get_input(x.params, interactive); - for (auto& reg : dest.get_regs()) - { - reg.public_input(input & 1); - input >>= 1; - } - } - else - { - n_evaluator_bits += x.n_bits; - } - } + YaoGarbleInput input; + processor.inputb(input, processor, args, garbler.P->my_num()); +} - if (interactive) - cout << "Thank you"; - - garbler.receiver_input_keys.push_back({}); - - for (auto x : a) - { - if (x.from == 1) - { - for (auto& reg : processor.S[x.dest].get_regs()) - { - reg.set(garbler.prng.get_doubleword(), 0); - assert(reg.mask() == 0); - garbler.receiver_input_keys.back().push_back(reg.full_key()); - } - } - } +void YaoGarbleWire::inputbvec(GC::Processor>& processor, + ProcessorBase& input_processor, const vector& 
args) +{ + auto& garbler = YaoGarbler::s(); + YaoGarbleInput input; + processor.inputbvec(input, input_processor, args, garbler.P->my_num()); } inline void YaoGarbler::store_gate(const YaoGate& gate) diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 196c8de6..885bde7c 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -14,11 +14,15 @@ #include class YaoGarbler; +class YaoGarbleInput; +class ProcessorBase; class YaoGarbleWire : public YaoWire { public: typedef YaoGarbler Party; + typedef YaoGarbleInput Input; + typedef GC::Processor> Processor; static string name() { return "YaoGarbleWire"; } @@ -50,6 +54,8 @@ public: static void inputb(GC::Processor>& processor, const vector& args); + static void inputbvec(Processor& processor, ProcessorBase& input_processor, + const vector& args); static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>&); @@ -83,6 +89,25 @@ public: void public_input(bool value); void op(const YaoGarbleWire& left, const YaoGarbleWire& right, Function func); char get_output(); + + template + void my_input(T&, bool value, int n_bits) + { + assert(n_bits == 1); + public_input(value); + } + + template + void finalize_input(T& inputter, int from, int n_bits) + { + assert(n_bits == 1); + if (from == 1) + { + set(inputter.garbler.prng.get_doubleword(), 0); + assert(mask() == 0); + inputter.garbler.receiver_input_keys.back().push_back(full_key()); + } + } }; inline void YaoGarbleWire::randomize(PRNG& prng) diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 754a0d3d..4f7d8698 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -112,3 +112,8 @@ void YaoGarbler::process_receiver_inputs() receiver_input_keys.pop_front(); } } + +size_t YaoGarbler::data_sent() +{ + return super::data_sent() + player.comm_stats.total_data(); +} diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 860ff311..f1239c22 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -24,6 +24,8 @@ class YaoGarbler: public 
GC::Thread>, friend class YaoGarbleWire; friend class YaoCommon; + typedef GC::Thread> super; + protected: static thread_local YaoGarbler* singleton; @@ -69,6 +71,8 @@ public: int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } + + size_t data_sent(); }; inline YaoGarbler& YaoGarbler::s() diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index 2a7ef170..d0e195d0 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -61,12 +61,16 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) OnlineOptions online_opts(opt, argc, argv); opt.parse(argc, argv); opt.syntax = "./yao-player.x [OPTIONS] "; - if (opt.lastArgs.size() == 1) + vector free_args = opt.firstArgs; + free_args.insert(free_args.end(), opt.unknownArgs.begin(), opt.unknownArgs.end()); + free_args.insert(free_args.end(), opt.lastArgs.begin(), opt.lastArgs.end()); + if (free_args.size() == 2) { - progname = *opt.lastArgs[0]; + progname = *free_args[1]; } else { + throw exception(); string usage; opt.getUsage(usage); cerr << usage; diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h index e6b8f61a..d726e080 100644 --- a/Yao/YaoWire.h +++ b/Yao/YaoWire.h @@ -25,6 +25,9 @@ public: { key_ = left.key_ ^ right.key_; } + + template + void other_input(T&, int) {} }; #endif /* YAO_YAOWIRE_H_ */ diff --git a/Yao/config.h b/Yao/config.h index 237469c3..adb9c732 100644 --- a/Yao/config.h +++ b/Yao/config.h @@ -10,12 +10,10 @@ //#define CHECK_BUFFER -#define HALF_GATES - class YaoFullGate; class YaoHalfGate; -#ifdef HALF_GATES +#ifndef FULL_GATES typedef YaoHalfGate YaoGate; #else typedef YaoFullGate YaoGate; diff --git a/compile.py b/compile.py index cd834a17..9bcc42f8 100755 --- a/compile.py +++ b/compile.py @@ -26,8 +26,6 @@ def main(): help="specify output file") parser.add_option("-a", "--asm-output", dest="asmoutfile", help="asm output file for debugging") - parser.add_option("-p", "--primesize", dest="param", default=-1, - help="bit length of modulus") parser.add_option("-g", 
"--galoissize", dest="galois", default=40, help="bit length of Galois field") parser.add_option("-d", "--debug", action="store_true", dest="debug", @@ -81,7 +79,7 @@ def main(): print('Note that -O/--optimize-hard currently has no effect') def compilation(): - prog = Compiler.run(args, options, param=int(options.param), + prog = Compiler.run(args, options, merge_opens=options.merge_opens, debug=options.debug) prog.write_bytes(options.outfile)