From 1669ce5bf538892f89d3c8dff9f7c712e17d8fac Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 28 May 2018 22:45:08 +0200 Subject: [PATCH] SPDZ-Yao. --- .gitignore | 1 + BMR/AndJob.h | 1 + BMR/CommonParty.cpp | 7 - BMR/CommonParty.h | 4 +- BMR/Gate.h | 1 + BMR/Key.h | 16 +- BMR/Party.cpp | 38 ++-- BMR/Party.h | 21 ++- BMR/Register.cpp | 179 ++++++++++++++++-- BMR/Register.h | 53 ++++-- BMR/TrustedParty.cpp | 19 +- BMR/TrustedParty.h | 5 +- BMR/common.h | 2 + BMR/network/Server.cpp | 2 +- BMR/network/utils.cpp | 7 - BMR/network/utils.h | 2 - BMR/prf.h | 7 + BMR/proto_utils.h | 2 - CONFIG | 2 +- Compiler/GC/instructions.py | 11 ++ Compiler/GC/types.py | 186 ++++++++++++++++--- Compiler/__init__.py | 6 +- Compiler/allocator.py | 4 + Compiler/comparison.py | 14 +- Compiler/compilerLib.py | 2 +- Compiler/config.py | 6 +- Compiler/floatingpoint.py | 2 +- Compiler/instructions.py | 14 +- Compiler/instructions_base.py | 7 +- Compiler/library.py | 56 ++---- Compiler/oram.py | 14 +- Compiler/program.py | 7 +- Compiler/types.py | 335 +++++++++++++++++++++++----------- GC/FakeSecret.cpp | 18 ++ GC/FakeSecret.h | 10 + GC/Instruction.cpp | 22 ++- GC/Instruction.h | 3 + GC/Machine.cpp | 5 + GC/Memory.cpp | 5 + GC/Processor.cpp | 47 ++++- GC/Processor.h | 9 + GC/Program.cpp | 17 ++ GC/Program.h | 1 + GC/Secret.cpp | 74 ++++++-- GC/Secret.h | 42 +++-- GC/Secret_inline.h | 46 ++--- GC/instructions.h | 2 + Makefile | 20 +- Math/bigint.h | 22 +++ Math/gfp.cpp | 16 +- Math/gfp.h | 9 +- Networking/Player.cpp | 28 ++- Networking/Player.h | 6 + OT/BitMatrix.cpp | 1 - OT/Tools.cpp | 2 +- Player-Online.cpp | 28 ++- Processor/Instruction.cpp | 55 +++--- Processor/Instruction.h | 11 +- Processor/Machine.cpp | 12 +- Processor/Processor.h | 3 + Programs/Source/test_gc.mpc | 11 +- README.md | 28 ++- Scripts/bmr-program-run.sh | 14 +- SimpleOT | 2 +- Tools/Config.cpp | 2 +- Tools/FlexBuffer.h | 29 ++- Tools/MMO.cpp | 11 -- Tools/MMO.h | 28 +++ Tools/Worker.h | 29 ++- Tools/octetStream.cpp 
| 26 ++- Tools/octetStream.h | 7 +- Tools/time-func.h | 13 ++ Yao/YaoAndJob.h | 58 ++++++ Yao/YaoEvalWire.cpp | 115 ++++++++++++ Yao/YaoEvalWire.h | 39 ++++ Yao/YaoEvaluator.cpp | 61 +++++++ Yao/YaoEvaluator.h | 63 +++++++ Yao/YaoGarbleWire.cpp | 204 +++++++++++++++++++++ Yao/YaoGarbleWire.h | 47 +++++ Yao/YaoGarbler.cpp | 70 +++++++ Yao/YaoGarbler.h | 76 ++++++++ Yao/YaoGate.cpp | 42 +++++ Yao/YaoGate.h | 78 ++++++++ Yao/YaoPlayer.cpp | 110 +++++++++++ Yao/YaoPlayer.h | 23 +++ Yao/YaoSimulator.cpp | 18 ++ Yao/YaoSimulator.h | 19 ++ yao-player.cpp | 11 ++ yao-simulate.cpp | 13 ++ 89 files changed, 2345 insertions(+), 449 deletions(-) create mode 100644 Yao/YaoAndJob.h create mode 100644 Yao/YaoEvalWire.cpp create mode 100644 Yao/YaoEvalWire.h create mode 100644 Yao/YaoEvaluator.cpp create mode 100644 Yao/YaoEvaluator.h create mode 100644 Yao/YaoGarbleWire.cpp create mode 100644 Yao/YaoGarbleWire.h create mode 100644 Yao/YaoGarbler.cpp create mode 100644 Yao/YaoGarbler.h create mode 100644 Yao/YaoGate.cpp create mode 100644 Yao/YaoGate.h create mode 100644 Yao/YaoPlayer.cpp create mode 100644 Yao/YaoPlayer.h create mode 100644 Yao/YaoSimulator.cpp create mode 100644 Yao/YaoSimulator.h create mode 100644 yao-player.cpp create mode 100644 yao-simulate.cpp diff --git a/.gitignore b/.gitignore index ca41754c..aa53a0df 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ config_mine.py *.rej *.tmp callgrind.out.* +HOSTS* # Vim .*.swp diff --git a/BMR/AndJob.h b/BMR/AndJob.h index 6679cffa..e329afe7 100644 --- a/BMR/AndJob.h +++ b/BMR/AndJob.h @@ -10,6 +10,7 @@ #include "GC/Secret.h" #include "Register.h" +#include "GarbledGate.h" #include using namespace std; diff --git a/BMR/CommonParty.cpp b/BMR/CommonParty.cpp index 239090e7..10929d9e 100644 --- a/BMR/CommonParty.cpp +++ b/BMR/CommonParty.cpp @@ -95,13 +95,6 @@ void CommonParty::next_gate(GarbledGate& gate) gate.init_inputs(gate_counter2, _N); } -void CommonParty::input(Register& reg, party_id_t from) -{ - 
(void)reg; - (void)from; - throw not_implemented(); -} - SendBuffer& CommonParty::get_buffer(MSG_TYPE type) { SendBuffer& buffer = buffers[type]; diff --git a/BMR/CommonParty.h b/BMR/CommonParty.h index 3adfdef3..7be0aeb5 100644 --- a/BMR/CommonParty.h +++ b/BMR/CommonParty.h @@ -100,13 +100,13 @@ public: int init(const char* netmap_file, int id); virtual void reset(); + virtual party_id_t get_id() { return -1; } + gate_id_t new_gate(); void next_gate(GarbledGate& gate); gate_id_t next_gate(int skip) { return gate_counter2 += skip; } size_t get_garbled_tbl_size() { return garbled_tbl_size; } - void input(Register& reg, party_id_t from); - SendBuffer& get_buffer(MSG_TYPE type); gf2n get_mac_key() { return mac_key; } diff --git a/BMR/Gate.h b/BMR/Gate.h index e1c7f4c1..e38bf698 100644 --- a/BMR/Gate.h +++ b/BMR/Gate.h @@ -9,6 +9,7 @@ #include #include "Key.h" +#include "common.h" #define NO_LAYER (-1) diff --git a/BMR/Key.h b/BMR/Key.h index e72e5ff8..647f0a71 100644 --- a/BMR/Key.h +++ b/BMR/Key.h @@ -42,7 +42,10 @@ public: void serialize(SendBuffer& output) const { output.serialize(r); } void serialize_no_allocate(SendBuffer& output) const { output.serialize_no_allocate(r); } - bool get_signal() { return _mm_cvtsi128_si64(r) & 1; } + bool get_signal() const { return _mm_cvtsi128_si64(r) & 1; } + void set_signal(bool signal); + + Key doubling(int i) const; template T get() const; @@ -79,6 +82,17 @@ inline __m128i Key::get() const return r; } +inline void Key::set_signal(bool signal) +{ + r &= ~_mm_cvtsi64x_si128(1); + r ^= _mm_cvtsi64x_si128(signal); +} + +inline Key Key::doubling(int i) const +{ + return _mm_sllv_epi64(r, _mm_set_epi64x(i, i)); +} + #else //__PRIME_FIELD__ is defined diff --git a/BMR/Party.cpp b/BMR/Party.cpp index c938dd99..65536344 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -131,11 +131,10 @@ void BaseParty::NewMessage(int from, ReceivedMsg& msg) #ifdef DEBUG_STEPS printf("TYPE_MASK_INPUTS\n"); #endif - 
input_masks.insert(input_masks.end(), message + sizeof(MSG_TYPE), message + len); -#ifdef DEBUG_COMM - cout << "got " << dec << input_masks.size() << " input masks" << endl; +#ifdef DEBUG_INPUT + cout << "received " << msg.left() << " input masks" << endl; #endif - input_mask = input_masks.begin(); + mask_input(msg); break; } case TYPE_MASK_OUTPUT: @@ -677,6 +676,15 @@ void Party::mask_output(ReceivedMsg& msg) #endif } +void Party::mask_input(ReceivedMsg& msg) +{ + input_masks.insert(input_masks.end(), msg.data() + sizeof(MSG_TYPE), msg.data() + msg.size()); +#ifdef DEBUG_COMM + cout << "got " << dec << input_masks.size() << " input masks" << endl; +#endif + input_mask = input_masks.begin(); +} + void Party::receive_keys(Key* keys) { resize_registers(); for (size_t w = 0; w < _W; w++) @@ -722,10 +730,10 @@ ProgramParty::ProgramParty(int argc, char** argv) : } _id = atoi(argv[1]); - ifstream file((string("Programs/Bytecode/") + argv[2] + "-0.bc").c_str()); - program.parse(file); + program.parse(string(argv[2]) + "-0"); machine.reset(program); processor.reset(program); + processor.open_input_file("user_inputs/user_" + to_string(_id - 1) + "_input.txt"); prf_machine.reset(*reinterpret_cast >* >(&program)); prf_processor.reset(*reinterpret_cast >* >(&program)); if (singleton) @@ -768,6 +776,8 @@ ProgramParty::ProgramParty(int argc, char** argv) : else threshold = 128; cout << "Threshold for multi-threaded evaluation: " << threshold << endl; + eval_threads = new Worker[N_EVAL_THREADS]; + and_jobs.resize(N_EVAL_THREADS); } ProgramParty::~ProgramParty() @@ -777,6 +787,7 @@ ProgramParty::~ProgramParty() delete P; if (MC) delete MC; + delete[] eval_threads; cout << "SPDZ loading: " << spdz_counters[SPDZ_LOAD] << endl; cout << "SPDZ storing: " << spdz_counters[SPDZ_STORE] << endl; cout << "SPDZ wire storage: " << 1e-9 * spdz_storage << " GB" << endl; @@ -808,9 +819,14 @@ void ProgramParty::load_garbled_circuit() throw runtime_error("no garbled circuit available"); if (not 
output_masks_store.pop(output_masks)) throw runtime_error("no output masks available"); + if (not input_masks_store.pop(input_masks)) + throw runtime_error("no input masks available"); #ifdef DEBUG_OUTPUT_MASKS cout << "loaded " << output_masks.left() << " output masks" << endl; #endif +#ifdef DEBUG_INPUT + cout << "loaded " << input_masks.left() << " input masks" << endl; +#endif } void ProgramParty::start_online_round() @@ -876,16 +892,6 @@ void ProgramParty::receive_all_keys(Register& reg, bool external) reg.keys[external][i] = *(keys_for_prf++); } -void ProgramParty::input_value(party_id_t from, char value) -{ - if (from == _id) - { - if (value and (1 - value)) - throw runtime_error("invalid input"); - } - throw not_implemented(); -} - void ProgramParty::receive_spdz_wires(ReceivedMsg& msg) { int op; diff --git a/BMR/Party.h b/BMR/Party.h index c09b289d..10d0b480 100644 --- a/BMR/Party.h +++ b/BMR/Party.h @@ -10,6 +10,8 @@ #include #include +#include + #include "Register.h" #include "GarbledGate.h" #include "network/Node.h" @@ -31,7 +33,7 @@ class BooleanCircuit; #ifndef N_EVAL_THREADS // Default Intel desktop processor has 8 half cores. // This is beneficial if only one AES available per full core. 
-#define N_EVAL_THREADS (8) +#define N_EVAL_THREADS (get_nprocs()) #endif @@ -61,9 +63,6 @@ protected: // int _num_evaluation_threads; struct timeval _start_online_net, _end_online_net; - vector input_masks; - vector::iterator input_mask; - Timer online_timer; Key delta; @@ -81,6 +80,7 @@ protected: virtual void _check_evaluate() = 0; virtual void mask_output(ReceivedMsg& msg) = 0; + virtual void mask_input(ReceivedMsg& msg) = 0; void done(); @@ -110,6 +110,9 @@ private: vector _input; + vector input_masks; + vector::iterator input_mask; + SendBuffer _external_values_msg; int _num_externals_msg_received; std::mutex _process_externals_mx; @@ -159,6 +162,7 @@ private: void start_online_round(); void mask_output(ReceivedMsg& msg); + void mask_input(ReceivedMsg& msg); int get_n_inputs(); }; @@ -177,10 +181,11 @@ class ProgramParty : public BaseParty size_t garbled_storage; vector spdz_counters; - Worker eval_threads[N_EVAL_THREADS]; - AndJob and_jobs[N_EVAL_THREADS]; + Worker* eval_threads; + vector and_jobs; ReceivedMsgStore output_masks_store; + ReceivedMsgStore input_masks_store; GC::Memory< GC::Secret::DynamicType > dynamic_memory; GC::Machine< GC::Secret > machine; @@ -212,6 +217,7 @@ class ProgramParty : public BaseParty void start_online_round(); void mask_output(ReceivedMsg& msg) { output_masks_store.push(msg); } + void mask_input(ReceivedMsg& msg) { input_masks_store.push(msg); } public: static ProgramParty* singleton; @@ -220,6 +226,7 @@ public: ReceivedMsgStore garbled_circuits; ReceivedMsg output_masks; + ReceivedMsg input_masks; MAC_Check* MC; Player* P; @@ -234,8 +241,6 @@ public: void reset(); - void input_value(party_id_t from, char value); - void get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire); void store_wire(const Register& reg); diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 20631f90..4ec9d951 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -237,6 +237,7 @@ void Register::eval(const Register& left, const Register& right, GarbledGate& 
ga phex(gate.input(ext_l, 1), 16); printf("right input"); phex(gate.input(ext_r, 1), 16); + unsigned g = gate.id; #endif Key k; @@ -400,7 +401,7 @@ void Register::garble(const Register& left, const Register& right, char maskr = right.mask; char masko = mask; #ifdef DEBUG - printf("\ngate %u, leftwire=%u, rightwire=%u, outwire=%u: func=%d%d%d%d, msk_l=%d, msk_r=%d, msk_o=%d\n" + printf("\ngate %u, leftwire=%lu, rightwire=%lu, outwire=%lu: func=%d%d%d%d, msk_l=%d, msk_r=%d, msk_o=%d\n" , g,gate->_left, gate->_right, gate->_out ,func[0],func[1],func[2],func[3], maskl, maskr, masko); #endif @@ -479,10 +480,11 @@ void PRFRegister::op(const ProgramRegister& left, const ProgramRegister& right, void PRFRegister::input(party_id_t from, char value) { + (void)from; (void)value; ProgramParty& party = *ProgramParty::singleton; party.receive_keys(*this); - party.input(*this, from); + party.store_wire(*this); #ifdef DEBUG cout << "(PRF) input from " << from << ":" << endl; keys.print(get_id()); @@ -511,12 +513,158 @@ void PRFRegister::random() #endif } +void EvalRegister::check_input(long long in, int n_bits) +{ + auto test = in >> (n_bits - 1); + if (n_bits == 1) + { + if (not (in == 0 or in == 1)) + throw runtime_error("input not a bit: " + to_string(in)); + } + else if (not (test == 0 or test == -1)) + { + throw runtime_error( + "input too large for a " + std::to_string(n_bits) + + "-bit signed integer: " + to_string(in)); + } +} + +class InputAccess +{ + party_id_t from; + size_t n_bits; + GC::Secret& dest; + GC::Processor >& processor; + ProgramParty& party; + +public: + InputAccess(int from, int n_bits, GC::Secret& dest, + GC::Processor >& processor) : + from(from), n_bits(n_bits), dest(dest), processor(processor), party( + ProgramParty::s()) + { + if (from > party.get_n_parties() or n_bits > 100) + throw runtime_error("invalid input parameters"); + } + + void prepare_masks(octetStream& os) + { + dest.resize_regs(n_bits); + for (auto& reg : dest.get_regs()) + 
party.load_wire(reg); + if (from == party.get_id()) + { + long long in; + processor.input_file >> in; + EvalRegister::check_input(in, n_bits); + for (size_t i = 0; i < n_bits; i++) + { + auto& reg = dest.get_reg(i); + reg.input_helper((in >> i) & 1, os); + } + } + } + + void received_masks(vector& oss) + { + size_t id = party.get_id() - 1; + for (auto& reg : dest.get_regs()) + { + if (party.get_id() != from) + { + char ext; + oss[from - 1].unserialize(ext); + reg.set_external(ext); + } + oss[id].serialize(reg.get_garbled_entry()[id]); + } + } + + void received_labels(vector& oss) + { + for (auto& reg : dest.get_regs()) + { + for (party_id_t id = 1; id < (size_t)party.get_n_parties() + 1; id++) + { + Key key; + if (id != party.get_id()) + { + oss[id - 1].unserialize(key); + reg.set_external_key(id, key); + } + } + } + } +}; + +template <> +void EvalRegister::inputb(GC::Processor >& processor, + const vector& args) +{ + auto& party = ProgramParty::s(); + vector oss(party.get_n_parties()); + octetStream& my_os = oss[party.get_id() - 1]; + vector accesses; + for (size_t j = 0; j < args.size(); j += 3) + { + accesses.push_back( + { args[j] + 1, args[j + 1], processor.S[args[j + 2]], processor }); + } + for (auto& access : accesses) + access.prepare_masks(my_os); + party.P->Broadcast_Receive(oss, true); + my_os.reset_write_head(); + for (auto& access : accesses) + access.received_masks(oss); + party.P->Broadcast_Receive(oss, true); + for (auto& access : accesses) + access.received_labels(oss); +} + +void EvalRegister::input_helper(char value, octetStream& os) +{ + set_mask(ProgramParty::s().input_masks.pop_front()); + set_external(get_mask() ^ value); + os.serialize(get_external()); +} + void EvalRegister::input(party_id_t from, char value) { - ProgramParty::s().input_value(from, value); + auto& party = ProgramParty::s(); + party.load_wire(*this); + octetStream os; + if (from == party.get_id()) + { + if (value and (1 - value)) + throw runtime_error("invalid input"); + 
input_helper(value, os); + party.P->send_all(os, true); + } + else + { + party.P->receive_player(from - 1, os); + char ext; + os.unserialize(ext); + set_external(ext); + } + vector oss(party.get_n_parties()); + size_t id = party.get_id() - 1; + oss[id].serialize(garbled_entry[id]); +#ifdef DEBUG_COMM + cout << "send " << garbled_entry[id] << ", " + << oss[id].get_length() << " bytes from " << id << endl; +#endif + party.P->Broadcast_Receive(oss, true); + for (size_t i = 0; i < (size_t)party.get_n_parties(); i++) + { + if (i != id) + oss[i].unserialize(garbled_entry[i]); + } + keys[external] = garbled_entry; #ifdef DEBUG cout << "(Input) input from " << from << ":" << endl; keys.print(get_id()); + cout << garbled_entry << endl; #endif } @@ -565,11 +713,13 @@ void RandomRegister::input(party_id_t from, char value) { (void)value; randomize(); + auto& party = TrustedProgramParty::s(); + party.store_wire(*this); + party.msg_input_masks[from - 1].push_back(get_mask()); #ifdef DEBUG cout << "(Random) input from " << from << ":" << endl; keys.print(get_id()); #endif - CommonParty::singleton->input(*this, from); } void RandomRegister::public_input(bool value) @@ -585,6 +735,13 @@ void RandomRegister::public_input(bool value) party.store_wire(*this); } +void GarbleRegister::input(party_id_t from, char value) +{ + (void)from; + (void)value; + TrustedProgramParty::s().load_wire(*this); +} + void GarbleRegister::public_input(bool value) { (void)value; @@ -672,7 +829,7 @@ void EvalRegister::XOR(const Register& left, const Register& right) garbled_entry = left.get_garbled_entry() ^ right.get_garbled_entry(); #ifdef DEBUG cout << "Eval XOR *" << get_id() << " = *" << left.get_id() << " ^ *" << right.get_id() << endl; - for (int i = 0; i < garbled_entry.size(); i++) + for (size_t i = 0; i < garbled_entry.size(); i++) cout << garbled_entry[i] << " = " << left.get_garbled_entry()[i] << " ^ " << right.get_garbled_entry()[i] << endl; #endif @@ -779,12 +936,12 @@ void 
EvalRegister::store(GC::Memory& mem, { GC::SpdzShare& dest = mem[access.address]; dest.assign_zero(); - const vector& sources = access.source.get_regs(); + const vector& sources = access.source.get_regs(); for (unsigned int i = 0; i < sources.size(); i++) { SpdzWire spdz_wire; party.get_spdz_wire(SPDZ_STORE, spdz_wire); - const Register& reg = sources[i]; + const EvalRegister& reg = sources[i]; Share tmp; gf2n ext = (int)reg.get_external(); //cout << "ext:" << ext << "/" << (int)reg.get_external() << " " << endl; @@ -909,7 +1066,7 @@ void EvalRegister::load(vector > >& acc { const GC::SpdzShare& source = mem[access.address]; Share mask; - vector& dests = access.dest.get_regs(); + vector& dests = access.dest.get_regs(); for (unsigned int i = 0; i < dests.size(); i++) { spdz_wires.push_back({}); @@ -945,7 +1102,7 @@ void EvalRegister::load(vector > >& acc for (size_t j = 0; j < accesses.size(); j++) { - vector& dests = accesses[j].dest.get_regs(); + vector& dests = accesses[j].dest.get_regs(); for (unsigned int i = 0; i < dests.size(); i++) { bool ext = masked[j].get_bit(i); @@ -961,7 +1118,7 @@ void EvalRegister::load(vector > >& acc int base = 0; for (auto access : accesses) { - vector& dests = access.dest.get_regs(); + vector& dests = access.dest.get_regs(); for (unsigned int i = 0; i < dests.size(); i++) for (int j = 0; j < party.get_n_parties(); j++) { @@ -1050,7 +1207,7 @@ void KeyTuple::randomize() { CommonParty::s().prng.get_octets((octet*)keys[i].data(), part_size()); #ifdef DEBUG - for (int j = 0; j < keys[i].size(); j++) + for (unsigned j = 0; j < keys[i].size(); j++) { keys[i][j] = { 0, (counter << 16) + (i << 8) + j }; counter++; diff --git a/BMR/Register.h b/BMR/Register.h index fb0e3ee2..32245f22 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -162,6 +162,7 @@ public: char get_output() { check_external(); check_mask(); return mask ^ external; } char get_output_no_check() { return mask ^ external; } const KeyVector& get_garbled_entry() const { 
return garbled_entry; } + const Key& get_garbled_wire(party_id_t i) const { return garbled_entry[i-1]; } void print_input(int id); void print() const { keys.print(get_id()); } void check_external() const; @@ -192,16 +193,12 @@ public: inline BlackHole& endl(BlackHole& b) { return b; } inline BlackHole& flush(BlackHole& b) { return b; } -class ProgramRegister : public Register +class Phase { public: typedef BlackHole out_type; static const BlackHole out; - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - static void check(const int128& value, word share, int128 mac) { (void)value; (void)share; (void)mac; } static void get_dyn_mask(GC::Mask& mask, int length, int mac_length) @@ -222,24 +219,29 @@ public: template static void andrs(T& processor, const vector& args) { processor.andrs(args); } + template + static void inputb(T& processor, const vector& args) { processor.input(args); } + static void check_input(long long in, int n_bits) { (void)in; (void)n_bits; } void input(party_id_t from, char value = -1) { (void)from; (void)value; } void public_input(bool value) { (void)value; } void random() {} + void output() {} +}; + +class ProgramRegister : public Phase, public Register +{ +public: + static Register new_reg(); + static Register tmp_reg() { return new_reg(); } + static Register and_reg() { return new_reg(); } + char get_output() { return 0; } + + ProgramRegister(const Register& reg) : Register(reg) {} }; -class FirstRoundRegister : public ProgramRegister -{ -public: -}; - -class SecondRoundRegister : public ProgramRegister -{ -public: -}; - -class PRFRegister : public FirstRoundRegister +class PRFRegister : public ProgramRegister { public: static string name() { return "PRF"; } @@ -248,6 +250,8 @@ public: static void load(vector >& accesses, const GC::Memory& source); + PRFRegister(const Register& reg) : ProgramRegister(reg) {} + void op(const ProgramRegister& left, const 
ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -256,7 +260,7 @@ public: void output(); }; -class EvalRegister : public SecondRoundRegister +class EvalRegister : public ProgramRegister { public: static string name() { return "Evaluation"; } @@ -278,6 +282,10 @@ public: template static void andrs(T& processor, const vector& args); + template + static void inputb(T& processor, const vector& args); + + EvalRegister(const Register& reg) : ProgramRegister(reg) {} void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -290,10 +298,12 @@ public: template static void store_clear_in_dynamic(GC::Memory& mem, const vector& accesses); + static void check_input(long long input, int n_bits); void input(party_id_t from, char value = -1); + void input_helper(char value, octetStream& os); }; -class GarbleRegister : public SecondRoundRegister +class GarbleRegister : public ProgramRegister { public: static string name() { return "Garbling"; } @@ -302,14 +312,17 @@ public: static void load(vector >& accesses, const GC::Memory& source); + GarbleRegister(const Register& reg) : ProgramRegister(reg) {} + void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); + void input(party_id_t from, char value = -1); void public_input(bool value); void random(); void output() {} }; -class RandomRegister : public FirstRoundRegister +class RandomRegister : public ProgramRegister { public: static string name() { return "Randomization"; } @@ -321,6 +334,8 @@ public: static void load(vector >& accesses, const GC::Memory& source); + RandomRegister(const Register& reg) : ProgramRegister(reg) {} + void randomize(); void op(const Register& left, const Register& right, Function func); diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 
ce3dac92..b5969b1a 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -61,9 +61,7 @@ TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : cerr << "Usage: " << argv[0] << " [netmap]" << endl; exit(1); } - - ifstream file((string("Programs/Bytecode/") + argv[1] + "-0.bc").c_str()); - program.parse(file); + program.parse(string(argv[1]) + "-0"); processor.reset(program); machine.reset(program); random_processor.reset(program.cast< GC::Secret >()); @@ -182,6 +180,15 @@ void TrustedParty::send_input_masks(party_id_t pid) _node->Send(pid, buffer); } +void TrustedProgramParty::send_input_masks(party_id_t pid) +{ + SendBuffer& buffer = msg_input_masks[pid-1]; +#ifdef DEBUG_ROUNDS + cout << "sending " << buffer.size() << " input masks to " << pid << endl; +#endif + _node->Send(pid, buffer); +} + void TrustedParty::send_output_masks() { prepare_output_regs(); @@ -388,6 +395,12 @@ bool TrustedProgramParty::_fill_keys() spdz_wires[i].resize(get_n_parties()); } msg_output_masks = get_buffer(TYPE_MASK_OUTPUT); + msg_input_masks.resize(get_n_parties()); + for (auto& buffer : msg_input_masks) + { + buffer.clear(); + fill_message_type(buffer, TYPE_MASK_INPUTS); + } return GC::DONE_BREAK == first_phase(program, random_processor, random_machine); } diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 834f220f..e9503192 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -19,6 +19,7 @@ class BaseTrustedParty : virtual public CommonParty { public: vector prf_outputs; + vector msg_input_masks; BaseTrustedParty(); virtual ~BaseTrustedParty() {} @@ -37,7 +38,7 @@ protected: std::atomic_uint _received_gc_received; std::atomic_uint n_received; - vector msg_keys, msg_input_masks; + vector msg_keys; int randomfd; @@ -130,7 +131,7 @@ private: bool _fill_keys(); void _launch_online(); - void send_input_masks(party_id_t pid) { (void)pid; } + void send_input_masks(party_id_t pid); void send_output_masks(); void garble(); diff --git a/BMR/common.h 
b/BMR/common.h index 740b9145..ae3de2d1 100644 --- a/BMR/common.h +++ b/BMR/common.h @@ -9,6 +9,7 @@ #define CIRCUIT_INC_COMMON_H_ #include +#include typedef unsigned long wire_id_t; typedef unsigned long gate_id_t; @@ -32,6 +33,7 @@ public: rep[i] = (int_rep << shift(i)) & 1; } uint8_t operator[](int i) { return rep[i]; } + bool call(bool left, bool right) { return rep[2 * left + right]; } }; template diff --git a/BMR/network/Server.cpp b/BMR/network/Server.cpp index 37639af9..fa3719d6 100644 --- a/BMR/network/Server.cpp +++ b/BMR/network/Server.cpp @@ -39,7 +39,7 @@ Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsig _servaddr.sin_port = htons(_port); if( 0 != bind(_servfd, (struct sockaddr *) &_servaddr, sizeof(_servaddr)) ) - printf("Server:: Error bind: \n%s\n",strerror(errno)); + printf("Server:: Error binding to %d: \n%s\n", _port, strerror(errno)); if(0 != listen(_servfd, _expected_clients)) printf("Server:: Error listen: \n%s\n",strerror(errno)); diff --git a/BMR/network/utils.cpp b/BMR/network/utils.cpp index f6704e6d..8ab3f637 100644 --- a/BMR/network/utils.cpp +++ b/BMR/network/utils.cpp @@ -15,13 +15,6 @@ #include "utils.h" -void fill_random(void* buffer, unsigned int length) -{ - int nullfd = open("/dev/urandom", O_RDONLY); - read(nullfd, (char*)buffer, length); - close(nullfd); -} - char cs(char* msg, unsigned int len, char result) { for(size_t i = 0; i < len; i++) result += msg[i]; diff --git a/BMR/network/utils.h b/BMR/network/utils.h index 6b4dbc69..c76a921b 100644 --- a/BMR/network/utils.h +++ b/BMR/network/utils.h @@ -9,8 +9,6 @@ #define NETWORK_TEST_UTILS_H_ -void fill_random(void* buffer, unsigned int length); - char cs(char* msg, unsigned int len, char result=0); void phex (const void *addr, int len); diff --git a/BMR/prf.h b/BMR/prf.h index 236528ab..231c595a 100644 --- a/BMR/prf.h +++ b/BMR/prf.h @@ -15,6 +15,13 @@ void PRF_single(const Key& key, char* input, char* output); +inline Key PRF_single(const Key& key, 
const Key& input) +{ + Key output; + PRF_single(key, (char*)&input, (char*)&output); + return output; +} + inline void PRF_chunk(const Key& key, char* input, char* output, int number) { __m128i* in = (__m128i*)input; diff --git a/BMR/proto_utils.h b/BMR/proto_utils.h index 6d321fbb..35d4a190 100644 --- a/BMR/proto_utils.h +++ b/BMR/proto_utils.h @@ -20,8 +20,6 @@ using namespace std; #define LOOPBACK_STR "LOOPBACK" -void fill_random(void* buffer, unsigned int length); - class SendBuffer; void fill_message_type(void* buffer, MSG_TYPE type); diff --git a/CONFIG b/CONFIG index 7cc1a5e3..1197541b 100644 --- a/CONFIG +++ b/CONFIG @@ -46,7 +46,7 @@ endif BOOST = -lboost_system -lboost_thread $(MY_BOOST) CXX = g++ -CFLAGS += $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 -mavx2 --std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = g++ diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index c7992501..fd169758 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -33,6 +33,8 @@ opcodes = dict( CONVCINT = 0x213, REVEAL = 0x214, STMSDCI = 0x215, + INPUTB = 0x216, + PRINTREGSIGNED = 0x220, ) class xors(base.Instruction): @@ -155,6 +157,11 @@ class reveal(base.Instruction): code = opcodes['REVEAL'] arg_format = ['int','cbw','sb'] +class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): + __slots__ = [] + code = opcodes['INPUTB'] + arg_format = tools.cycle(['p','int','sbw']) + class print_reg(base.IOInstruction): code = base.opcodes['PRINTREG'] arg_format = ['cb','i'] @@ -164,3 +171,7 @@ class print_reg(base.IOInstruction): class print_reg_plain(base.IOInstruction): code = base.opcodes['PRINTREGPLAIN'] arg_format = ['cb'] + +class 
print_reg_signed(base.IOInstruction): + code = opcodes['PRINTREGSIGNED'] + arg_format = ['int','cb'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 0eddb2f5..c045f6be 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -1,6 +1,7 @@ # (C) 2018 University of Bristol, Bar-Ilan University. See License.txt from Compiler.types import MemValue, read_mem_value, regint, Array +from Compiler.types import _bitint, _number, _fix from Compiler.program import Tape, Program from Compiler.exceptions import * from Compiler import util, oram, floatingpoint @@ -12,6 +13,7 @@ class bits(Tape.Register): size = 1 PreOp = staticmethod(floatingpoint.PreOpN) MemValue = staticmethod(lambda value: MemValue(value)) + decomposed = None @staticmethod def PreOR(l): return [1 - x for x in \ @@ -41,30 +43,30 @@ class bits(Tape.Register): return res hard_conv = conv @classmethod - def compose(cls, items, bit_length): - return cls.bit_compose(sum([item.bit_decompose(bit_length) for item in items], [])) + def compose(cls, items, bit_length=1): + return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], [])) @classmethod def bit_compose(cls, bits): if len(bits) == 1: return bits[0] bits = list(bits) res = cls.new(n=len(bits)) - cls.bitcom(res, *bits) + cls.bitcom(res, *(sbit.conv(bit) for bit in bits)) res.decomposed = bits return res def bit_decompose(self, bit_length=None): n = bit_length or self.n - if n > self.n: - raise Exception('wanted %d bits, only got %d' % (n, self.n)) - if n == 1: + suffix = [0] * (n - self.n) + if n == 1 and self.n == 1: return [self] + n = min(n, self.n) if self.decomposed is None or len(self.decomposed) < n: - res = [self.bit_type() for i in range(n)] + res = [self.bit_type() for i in range(self.n)] self.bitdec(self, *res) self.decomposed = res - return res + return res + suffix else: - return self.decomposed[:n] + return self.decomposed[:n] + suffix @classmethod def load_mem(cls, address, mem_type=None): res = cls() 
@@ -75,12 +77,13 @@ class bits(Tape.Register): return res def store_in_mem(self, address): self.store_inst[isinstance(address, (int, long))](self, address) - def __init__(self, value=None, n=None): + def __init__(self, value=None, n=None, size=None): + if size != 1 and size is not None: + raise Exception('invalid size for bit type: %s' % size) Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape) self.set_length(n or self.n) if value is not None: self.load_other(value) - self.decomposed = None def set_length(self, n): if n > self.max_length: print self.max_length @@ -95,8 +98,12 @@ class bits(Tape.Register): elif isinstance(self, type(other)) or isinstance(other, type(self)): self.mov(self, other) else: - raise CompilerError('cannot convert from %s to %s' % \ - (type(other), type(self))) + try: + other = self.bit_compose(other.bit_decompose()) + self.mov(self, other) + except: + raise CompilerError('cannot convert from %s to %s' % \ + (type(other), type(self))) def __repr__(self): return '%s(%d/%d)' % \ (super(bits, self).__repr__(), self.n, type(self).n) @@ -151,7 +158,7 @@ class cbits(bits): def print_reg(self, desc=''): inst.print_reg(self, desc) def print_reg_plain(self): - inst.print_reg_plain(self) + inst.print_reg_signed(self.n, self) output = print_reg_plain def reveal(self): return self @@ -183,6 +190,13 @@ class sbits(bits): inst.bit(res) return res @classmethod + def get_input_from(cls, player, n_bits=None): + if n_bits is None: + n_bits = cls.n + res = cls() + inst.inputb(player, n_bits, res) + return res + @classmethod def load_dynamic_mem(cls, address): res = cls() if isinstance(address, (long, int)): @@ -196,16 +210,19 @@ class sbits(bits): else: inst.stmsdi(self, cbits.conv(address)) def load_int(self, value): - if abs(value) < 2**31: - if (abs(value) > (1 << self.n)): - raise Exception('public value %d longer than %d bits' \ - % (value, self.n)) + if (abs(value) > (1 << self.n)): + raise Exception('public value %d longer than %d bits' 
\ + % (value, self.n)) + if self.n <= 32: inst.ldbits(self, self.n, value) - else: - value %= 2**self.n - if value >> 64 != 0: - raise NotImplementedError('public value too large') + elif self.n <= 64: self.load_other(regint(value)) + elif self.n <= 128: + lower = sbits.get_type(64)(value % 2**64) + upper = sbits.get_type(self.n - 64)(value >> 64) + self.mov(self, lower + (upper << 64)) + else: + raise NotImplementedError('more than 128 bits wanted') @read_mem_value def __add__(self, other): if isinstance(other, int): @@ -213,11 +230,17 @@ class sbits(bits): else: if not isinstance(other, sbits): other = sbits(other) - n = self.n - else: - n = max(self.n, other.n) + n = min(self.n, other.n) res = self.new(n=n) inst.xors(n, res, self, other) + max_n = max(self.n, other.n) + if max_n > n: + if self.n > n: + longer = self + else: + longer = other + bits = res.bit_decompose() + longer.bit_decompose()[n:] + res = self.bit_compose(bits) return res __radd__ = __add__ __sub__ = __add__ @@ -291,13 +314,41 @@ class sbits(bits): def equal(self, other, n=None): bits = (~(self + other)).bit_decompose() return reduce(operator.mul, bits) + def TruncPr(self, k, m, kappa=None): + if k > self.n: + raise Exception('TruncPr overflow: %d > %d' % (k, self.n)) + bits = self.bit_decompose() + res = self.get_type(k - m).bit_compose(bits[m:k]) + return res + @classmethod + def two_power(cls, n): + if n > cls.n: + raise Exception('two_power overflow: %d > %d' % (n, cls.n)) + res = cls() + if n == cls.n: + res.load_int(-1 << (n - 1)) + else: + res.load_int(1 << n) + return res class bit(object): n = 1 +def result_conv(x, y): + if util.is_constant(x): + if util.is_constant(y): + return lambda x: x + else: + return type(y).conv + if util.is_constant(y): + return type(x).conv + if type(x) is type(y): + return type(x).conv + return lambda x: x + class sbit(bit, sbits): def if_else(self, x, y): - return self * (x ^ y) ^ y + return result_conv(x, y)(self * (x ^ y) ^ y) class cbit(bit, cbits): pass 
@@ -350,3 +401,86 @@ class DynamicArray(Array): sbits.dynamic_array = DynamicArray cbits.dynamic_array = Array + +class sbitint(_bitint, _number, sbits): + n_bits = None + bin_type = None + types = {} + @classmethod + def get_type(cls, n): + if n in cls.types: + return cls.types[n] + sbits_type = sbits.get_type(n) + class _(sbitint, sbits_type): + # n_bits is used by _bitint + n_bits = n + bin_type = sbits_type + _.__name__ = 'sbitint' + str(n) + cls.types[n] = _ + return _ + @classmethod + def new(cls, value=None, n=None): + return cls.get_type(n)(value) + def set_length(*args): + pass + @classmethod + def bit_compose(cls, bits): + # truncate and extend bits + bits = bits[:cls.n] + bits += [0] * (cls.n - len(bits)) + return super(sbitint, cls).bit_compose(bits) + def force_bit_decompose(self, n_bits=None): + return sbits.bit_decompose(self, n_bits) + def TruncMul(self, other, k, m, kappa=None): + self_bits = self.bit_decompose() + other_bits = other.bit_decompose() + if len(self_bits) + len(other_bits) != k: + raise Exception('invalid parameters for TruncMul: ' + 'self:%d, other:%d, k:%d' % + (len(self_bits), len(other_bits), k)) + t = self.get_type(k) + a = t.bit_compose(self_bits + [self_bits[-1]] * (k - len(self_bits))) + b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits))) + product = a * b + res_bits = product.bit_decompose()[m:k] + return self.bit_compose(res_bits) + def Norm(self, k, f, kappa=None, simplex_flag=False): + absolute_val = abs(self) + #next 2 lines actually compute the SufOR for little indian encoding + bits = absolute_val.bit_decompose(k)[::-1] + suffixes = floatingpoint.PreOR(bits)[::-1] + z = [0] * k + for i in range(k - 1): + z[i] = suffixes[i] - suffixes[i+1] + z[k - 1] = suffixes[k-1] + z.reverse() + t2k = self.get_type(2 * k) + acc = t2k.bit_compose(z) + sign = self.bit_decompose()[-1] + signed_acc = sign.if_else(-acc, acc) + absolute_val_2k = t2k.bit_compose(absolute_val.bit_decompose()) + part_reciprocal = 
absolute_val_2k * acc + return part_reciprocal, signed_acc + def extend(self, n): + bits = self.bit_decompose() + bits += [bits[-1]] * (n - len(bits)) + return self.get_type(n).bit_compose(bits) + +class sbitfix(_fix): + float_type = type(None) + clear_type = staticmethod(lambda x: x) + @classmethod + def set_precision(cls, f, k=None): + super(cls, sbitfix).set_precision(f, k) + cls.int_type = sbitint.get_type(cls.k) + def __xor__(self, other): + return type(self)(self.v ^ other.v) + def __mul__(self, other): + if isinstance(other, sbit): + return type(self)(self.int_type(other * self.v)) + else: + return super(sbitfix, self).__mul__(other) + __rxor__ = __xor__ + __rmul__ = __mul__ + +sbitfix.set_precision(20, 41) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 145ad2aa..83923d63 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -1,6 +1,7 @@ # (C) 2018 University of Bristol, Bar-Ilan University. See License.txt import compilerLib, program, instructions, types, library, floatingpoint +import GC.types import inspect from config import * from compilerLib import run @@ -10,8 +11,9 @@ from compilerLib import run compilerLib.VARS = {} instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] -instr_classes += [t[1] for t in inspect.getmembers(types, inspect.isclass)\ - if t[1].__module__ == types.__name__] +for mod in (types, GC.types): + instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ + if t[1].__module__ == mod.__name__] instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ if t[1].__module__ == library.__name__] diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 96163a5a..9a552b38 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -164,6 +164,10 @@ class Merger: def do_merge(self, merges_iter): """ Merge an iterable of nodes in G, returning the number of merged instructions and the index of the merged instruction. 
""" + # sort merges, necessary for inputb + merge = list(merges_iter) + merge.sort() + merges_iter = iter(merge) instructions = self.instructions mergecount = 0 try: diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 13691bd5..6c7ea18e 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -64,14 +64,11 @@ def ld2i(c, n): inverse_of_two = {} -def divide_by_two(res, x): +def divide_by_two(res, x, m=1): """ Faster clear division by two using a cached value of 2^-1 mod p """ - from program import Program - import types - block = Program.prog.curr_block - if len(inverse_of_two) == 0 or block not in inverse_of_two: - inverse_of_two[block] = types.cint(1) / 2 - mulc(res, x, inverse_of_two[block]) + tmp = program.curr_block.new_reg('c') + inv2m(tmp, m) + mulc(res, x, tmp) def LTZ(s, a, k, kappa): """ @@ -104,8 +101,7 @@ def Trunc(d, a, k, m, kappa, signed): Mod2m(a_prime, a, k, m, kappa, signed) subs(t, a, a_prime) ldi(c[1], 1) - ld2i(c2m, m) - divc(c[2], c[1], c2m) + divide_by_two(c[2], c[1], m) mulm(d, t, c[2]) def TruncRoundNearest(a, k, m, kappa): diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index c52ab09b..02f9d0b6 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -70,7 +70,7 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ if prog.main_thread_running: prog.update_req(prog.curr_tape) print 'Program requires:', repr(prog.req_num) - print 'Cost:', prog.req_num.cost() + print 'Cost:', 0 if prog.req_num is None else prog.req_num.cost() print 'Memory size:', prog.allocated_mem # finalize the memory diff --git a/Compiler/config.py b/Compiler/config.py index 913e899c..5b93fe48 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -18,9 +18,9 @@ P_VALUES = { -1: 2147483713, \ 256: 57896044624266469032429686755131815517604980759976795324963608525438406557697, \ 512: 
6703903964971298549787012499123814115273848577471136527425966013026501536706464354255445443244279389455058889493431223951165286470575994074291745908195329 } -BIT_LENGTHS = { -1: 24, - 32: 24, - 64: 32, +BIT_LENGTHS = { -1: 32, + 32: 16, + 64: 16, 128: 64, 256: 64, 512: 64 } diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 336032be..39298afc 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -57,7 +57,7 @@ def bits(a,m): c[1][0] = a for i in range(1,m): subc(c[0][i], c[1][i-1], res[i-1]) - divci(c[1][i], c[0][i], 2) + comparison.divide_by_two(c[1][i], c[0][i]) modci(res[i], c[1][i], 2) return res diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 3b2f64bd..278eaa1c 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -295,7 +295,7 @@ class reqbl(base.Instruction): code = base.opcodes['REQBL'] arg_format = ['int'] -class time(base.Instruction): +class time(base.IOInstruction): r""" Output epoch time. """ code = base.opcodes['TIME'] arg_format = [] @@ -447,6 +447,12 @@ class modc(base.Instruction): def execute(self): self.args[0].value = self.args[1].value % self.args[2].value +@base.vectorize +class inv2m(base.Instruction): + __slots__ = [] + code = base.opcodes['INV2M'] + arg_format = ['cw','int'] + @base.vectorize class legendrec(base.Instruction): r""" Clear Legendre symbol computation, $c_i = (c_j / p)$. """ @@ -936,6 +942,12 @@ class print_int(base.IOInstruction): code = base.opcodes['PRINTINT'] arg_format = ['ci'] +@base.vectorize +class print_float_plain(base.IOInstruction): + __slots__ = [] + code = base.opcodes['PRINTFLOATPLAIN'] + arg_format = ['c', 'c', 'c', 'c'] + class print_char(base.IOInstruction): r""" Print a single character to stdout. 
""" code = base.opcodes['PRINTCHR'] diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 096a2a27..ad0c0ab6 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -81,6 +81,7 @@ opcodes = dict( MODCI = 0x37, LEGENDREC = 0x38, DIGESTC = 0x39, + INV2M = 0x3a, GMULBITC = 0x136, GMULBITM = 0x137, # Open @@ -159,8 +160,8 @@ opcodes = dict( PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, - WRITEFILESHARE = 0xBD, - READFILESHARE = 0xBE, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, GBITDEC = 0x184, GBITCOM = 0x185, # Secure socket @@ -406,7 +407,7 @@ class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): super(ImmediateModpAF, cls).check(arg) - if arg >= 2**31 or arg < -2**31: + if arg >= 2**32 or arg < -2**32: raise ArgumentError(arg, 'Immediate value outside of 32-bit range') class ImmediateGF2NAF(IntArgFormat): diff --git a/Compiler/library.py b/Compiler/library.py index 2c888497..1cbd7119 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -77,32 +77,7 @@ def print_str(s, *args): else: raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfix): - # print decimal representation of a clear fixed point number - # number is encoded as [left].[right] - left = val.v - sign = -1 * (val.v < 0) + 1 * (val.v >= 0) - positive_left = cint(sign) * left - right = positive_left % 2**val.f - @if_(sign == -1) - def block(): - print_str('-') - cint((positive_left - right + 1) >> val.f).print_reg_plain() - x = 0 - max_dec_base = 8 # max 32-bit precision - last_nonzero = 0 - for i,b in enumerate(reversed(right.bit_decompose(val.f))): - x += b * int(10**max_dec_base / 2**(i + 1)) - v = x - for i in range(max_dec_base): - t = v % 10 - b = (t > 0) - last_nonzero = (1 - b) * last_nonzero + b * i - v = (v - t) / 10 - print_plain_str('.') - @for_range(max_dec_base - 1 - last_nonzero) - def f(i): - print_str('0') - x.print_reg_plain() + val.print_plain() elif 
isinstance(val, sfix) or isinstance(val, sfloat): raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfloat): @@ -1269,33 +1244,34 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False): Goldschmidt method as presented in Catrina10, """ theta = int(ceil(log(k/3.5) / log(2))) - alpha = two_power(2*f) - w = AppRcr(b, k, f, kappa, simplex_flag) - x = alpha - b * w + alpha = b.get_type(2 * k).two_power(2*f) + w = AppRcr(b, k, f, kappa, simplex_flag).extend(2 * k) + x = alpha - b.extend(2 * k) * w - y = a * w - y = TruncPr(y, 2*k, f, kappa) + y = a.extend(2 *k) * w + y = y.TruncPr(2*k, f, kappa) for i in range(theta): - y = y * (alpha + x) + x = x.extend(2 * k) + y = y.extend(2 * k) * (alpha + x).extend(2 * k) x = x * x - y = TruncPr(y, 2*k, 2*f, kappa) - x = TruncPr(x, 2*k, 2*f, kappa) + y = y.TruncPr(2*k, 2*f, kappa) + x = x.TruncPr(2*k, 2*f, kappa) - y = y * (alpha + x) - y = TruncPr(y, 2*k, 2*f, kappa) + y = y.extend(2 * k) * (alpha + x).extend(2 * k) + y = y.TruncPr(k + 2*f, 2*f, kappa) return y def AppRcr(b, k, f, kappa, simplex_flag=False): """ Approximate reciprocal of [b]: Given [b], compute [1/b] """ - alpha = cint(int(2.9142 * 2**k)) - c, v = Norm(b, k, f, kappa, simplex_flag) + alpha = b.get_type(2 * k)(int(2.9142 * 2**k)) + c, v = b.Norm(k, f, kappa, simplex_flag) #v should be 2**{k - m} where m is the length of the bitwise repr of [b] d = alpha - 2 * c w = d * v - w = TruncPr(w, 2 * k, 2 * (k - f)) + w = w.TruncPr(2 * k, 2 * (k - f)) # now w * 2 ^ {-f} should be an initial approximation of 1/b return w @@ -1315,7 +1291,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): absolute_val = sign * b #next 2 lines actually compute the SufOR for little indian encoding - bits = absolute_val.bit_decompose(k)[::-1] + bits = absolute_val.bit_decompose(k, kappa)[::-1] suffixes = PreOR(bits)[::-1] z = [0] * k diff --git a/Compiler/oram.py b/Compiler/oram.py index a8e711dd..3923fb7d 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1148,7 
+1148,7 @@ class TreeORAM(AbstractORAM): Program.prog.curr_tape.\ start_new_basicblock(name='read_and_remove-%d-end' % self.size) return [MemValue(v) for v in read_value], MemValue(read_empty) - def add(self, entry, state=None, evict=None): + def add(self, entry, state=None, evict=True): if state is None: state = self.state.read() #print_reg(cint(0), 'add') @@ -1160,9 +1160,10 @@ class TreeORAM(AbstractORAM): *(self.value_type(i.read()) for i in entry.x)) maybe_stop_timer(4) #print 'pre-evict', self - maybe_start_timer(5) - self.evict() - maybe_stop_timer(5) + if evict: + maybe_start_timer(5) + self.evict() + maybe_stop_timer(5) #print 'post-evict', self def evict(self): #print 'evict root', id(self) @@ -1604,9 +1605,10 @@ class PackedORAMWithEmpty(AbstractORAM, PackedIndexStructure): def _read(self, index): res = PackedIndexStructure.__getitem__(self, index) return res[1:], 1 - res[0] - def access(self, index, new_value, write, new_empty=False): + def access(self, index, new_value, write, new_empty=False, evict=True): res = PackedIndexStructure.access(self, index, (1 - new_empty,) + \ - tuplify(new_value), write) + tuplify(new_value), write, \ + evict=evict) return res[1:], 1 - res[0] def read_and_maybe_remove(self, index): return self.read(index), 0 diff --git a/Compiler/program.py b/Compiler/program.py index 74cc4627..d24aa250 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -45,7 +45,7 @@ class Program(object): self.param = param self.bit_length = BIT_LENGTHS[param] print 'Default bit length:', self.bit_length - self.security = STAT_SEC[param] + self.security = 40 print 'Default security parameter:', self.security self.galois_length = int(options.galois) print 'Galois length:', self.galois_length @@ -385,6 +385,11 @@ class Program(object): self.security = security print 'Changed statistical security for comparison etc. 
to', security + def optimize_for_gc(self): + from Compiler.GC.instructions import * + self.to_merge = [ldmsdi, stmsdi, ldmsd, stmsd, stmsdci, xors, andrs] + self.stop_class = type(None) + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program, param=-1): diff --git a/Compiler/types.py b/Compiler/types.py index 1e61d6a5..2da7e727 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -111,10 +111,6 @@ def read_mem_value(operation): class _number(object): - @staticmethod - def bit_compose(bits): - return sum(b << i for i,b in enumerate(bits)) - def square(self): return self * self @@ -231,6 +227,10 @@ class _register(Tape.Register, _number): def prep_res(cls, other): return cls() + @staticmethod + def bit_compose(bits): + return sum(b << i for i,b in enumerate(bits)) + def __init__(self, reg_type, val, size): super(_register, self).__init__(reg_type, program.curr_tape, size=size) if isinstance(val, (int, long)): @@ -241,6 +241,9 @@ class _register(Tape.Register, _number): def sizeof(self): return self.size + def extend(self, n): + return self + class _clear(_register): __slots__ = [] @@ -971,6 +974,7 @@ class sint(_secret, _int): PreOp = staticmethod(floatingpoint.PreOpL) PreOR = staticmethod(floatingpoint.PreOR) + get_type = staticmethod(lambda n: sint) @vectorized_classmethod def get_random_int(cls, bits): @@ -1159,6 +1163,19 @@ class sint(_secret, _int): security = security or program.security return floatingpoint.BitDec(self, bit_length, bit_length, security) + def TruncMul(self, other, k, m, kappa=None): + return floatingpoint.TruncPr(self * other, k, m, kappa) + + def TruncPr(self, k, m, kappa=None): + return floatingpoint.TruncPr(self, k, m, kappa) + + def Norm(self, k, f, kappa=None, simplex_flag=False): + return library.Norm(self, k, f, kappa, simplex_flag) + + @staticmethod + def two_power(n): + return floatingpoint.two_power(n) + class sgf2n(_secret, _gf2n): __slots__ = [] 
instruction_type = 'gf2n' @@ -1276,24 +1293,22 @@ for t in (sint, sgf2n): t.default_type = t -class sgf2nint(sgf2n): +class _bitint(object): bits = None + log_rounds = False @classmethod - def compose(cls, bits): - bits = list(bits) - if len(bits) > cls.n_bits: - raise CompilerError('Too many bits') - res = cls() - res.bits = bits + [0] * (cls.n_bits - len(bits)) - gmovs(res, sum(b << i for i,b in enumerate(bits))) - return res - - @staticmethod - def bit_adder(a, b): + def bit_adder(cls, a, b): a, b = list(a), list(b) a += [0] * (len(b) - len(a)) b += [0] * (len(a) - len(b)) + if cls.log_rounds: + return cls.carry_lookahead_adder(a, b) + else: + return cls.carry_select_adder(a, b) + + @staticmethod + def carry_lookahead_adder(a, b): lower = [] for (ai,bi) in zip(a,b): if ai is 0 or bi is 0: @@ -1311,10 +1326,50 @@ class sgf2nint(sgf2n): carries = [] return lower + [ai + bi + carry for (ai,bi,carry) in zip(a,b,carries)] + @classmethod + def carry_select_adder(cls, a, b): + n = len(a) + for m in range(100): + if sum(range(m + 1)) + 1 >= n: + break + for k in range(m, 0, -1): + if sum(range(m, k - 1, -1)) + 1 >= n: + break + blocks = range(m, k, -1) + blocks.append(n - sum(blocks)) + blocks.reverse() + #print 'blocks:', blocks + if len(blocks) > 1 and blocks[0] > blocks[1]: + raise Exception('block size not increasing:', blocks) + if sum(blocks) != n: + raise Exception('blocks not summing up: %s != %s' % \ + (sum(blocks), n)) + res = [] + carry = 0 + for m in blocks: + aa = a[:m] + bb = b[:m] + a = a[m:] + b = b[m:] + cc = [cls.ripple_carry_adder(aa, bb, i) for i in (0,1)] + for i in range(m): + res.append(util.if_else(carry, cc[1][i], cc[0][i])) + carry = util.if_else(carry, cc[1][m], cc[0][m]) + return res + + @classmethod + def ripple_carry_adder(cls, a, b, carry=0): + res = [] + for aa, bb in zip(a, b): + cc, carry = cls.full_adder(aa, bb, carry) + res.append(cc) + res.append(carry) + return res + @staticmethod def full_adder(a, b, carry): s = a + b - return s + 
carry, util.or_op(a * b, s * carry) + return s + carry, util.if_else(s, carry, a) @staticmethod def half_adder(a, b): @@ -1336,35 +1391,36 @@ class sgf2nint(sgf2n): def load_int(self, other): if -2**(self.n_bits-1) <= other < 2**(self.n_bits-1): - sgf2n.load_int(self, other + 2**self.n_bits if other < 0 else other) + self.bin_type.load_int(self, other + 2**self.n_bits \ + if other < 0 else other) else: raise CompilerError('Invalid signed %d-bit integer: %d' % \ (self.n_bits, other)) - def load_other(self, other): - if isinstance(other, sgf2nint): - gmovs(self, self.compose(other.bit_decompose(self.n_bits))) - elif isinstance(other, sgf2n): - gmovs(self, other) - else: - gaddm(self, sgf2n(0), cgf2n(other)) - def add(self, other): - if type(other) == sgf2n: + if type(other) == self.bin_type: raise CompilerError('Unclear addition') a = self.bit_decompose() b = util.bit_decompose(other, self.n_bits) return self.compose(self.bit_adder(a, b)) def mul(self, other): - if type(other) == sgf2n: + if type(other) == self.bin_type: raise CompilerError('Unclear multiplication') self_bits = self.bit_decompose() if isinstance(other, (int, long)): other_bits = util.bit_decompose(other, self.n_bits) bit_matrix = [[x * y for y in self_bits] for x in other_bits] else: - other = sgf2n(other) + try: + other_bits = other.bit_decompose() + if len(other_bits) == 1: + return type(self)(other_bits[0] * self) + if len(self_bits) != len(other_bits): + raise NotImplementedError('Multiplication of different lengths') + except AttributeError: + pass + other = self.bin_type(other) products = [x * other for x in self_bits] bit_matrix = [util.bit_decompose(x, self.n_bits) for x in products] columns = [filter(None, (bit_matrix[j][i-j] \ @@ -1418,7 +1474,7 @@ class sgf2nint(sgf2n): def bit_decompose(self, n_bits=None, *args): if self.bits is None: - self.bits = sgf2n(self).bit_decompose(self.n_bits) + self.bits = self.force_bit_decompose(self.n_bits) if n_bits is None: return self.bits[:] else: @@ 
-1460,9 +1516,51 @@ class sgf2nint(sgf2n): def __gt__(self, other): return 1 - (self <= other) + def __eq__(self, other): + diff = self ^ other + diff_bits = [1 - x for x in diff.bit_decompose()] + return floatingpoint.KMul(diff_bits) + + def __ne__(self, other): + return 1 - (self == other) + def __neg__(self): return 1 + self.compose(1 ^ b for b in self.bit_decompose()) + def __abs__(self): + return self.bit_decompose()[-1].if_else(-self, self) + + less_than = lambda self, other, *args, **kwargs: self < other + greater_than = lambda self, other, *args, **kwargs: self > other + less_equal = lambda self, other, *args, **kwargs: self <= other + greater_equal = lambda self, other, *args, **kwargs: self >= other + equal = lambda self, other, *args, **kwargs: self == other + not_equal = lambda self, other, *args, **kwargs: self != other + +class sgf2nint(_bitint, sgf2n): + bin_type = sgf2n + + @classmethod + def compose(cls, bits): + bits = list(bits) + if len(bits) > cls.n_bits: + raise CompilerError('Too many bits') + res = cls() + res.bits = bits + [0] * (cls.n_bits - len(bits)) + gmovs(res, sum(b << i for i,b in enumerate(bits))) + return res + + def load_other(self, other): + if isinstance(other, sgf2nint): + gmovs(self, self.compose(other.bit_decompose(self.n_bits))) + elif isinstance(other, sgf2n): + gmovs(self, other) + else: + gaddm(self, sgf2n(0), cgf2n(other)) + + def force_bit_decompose(self, n_bits=None): + return sgf2n(self).bit_decompose(n_bits) + class sgf2nuint(sgf2nint): def load_int(self, other): if 0 <= other < 2**self.n_bits: @@ -1777,11 +1875,22 @@ class cfix(_number): else: raise TypeError('Incompatible fixed point types in division') -class sfix(_number): + def print_plain(self): + sign = self.v < 0 + abs_v = sign.if_else(-self.v, self.v) + print_float_plain(cint(abs_v), cint(self.f - self.k + 1), \ + cint(0), cint(sign)) + +class _fix(_number): """ Shared fixed point type. 
""" __slots__ = ['v', 'f', 'k', 'size'] - reg_type = 's' kappa = 40 + + @property + @classmethod + def reg_type(cls): + return cls.int_type.reg_type + @classmethod def set_precision(cls, f, k = None): cls.f = f @@ -1789,6 +1898,8 @@ class sfix(_number): if k is None: cls.k = 2 * f else: + if k < f: + raise CompilerError('bit length cannot be less than precision') cls.k = k def conv(self): @@ -1798,14 +1909,14 @@ class sfix(_number): def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of n values input by a client. Assumes client has already run bit shift to convert fixed point to integer.""" - sint_inputs = sint.receive_from_client(n, client_id, ClientMessageType.TripleShares) - return map(sfix, sint_inputs) + sint_inputs = cls.int_type.receive_from_client(n, client_id, ClientMessageType.TripleShares) + return map(cls, sint_inputs) @vectorized_classmethod def load_mem(cls, address, mem_type=None): res = [] - res.append(sint.load_mem(address)) - return sfix(*res) + res.append(cls.int_type.load_mem(address)) + return cls(*res) @classmethod def load_sint(cls, v): @@ -1820,27 +1931,29 @@ class sfix(_number): k = self.k # warning: don't initialize a sfix from a sint, this is only used in internal methods; # for external initialization use load_int. 
- if isinstance(_v, sint): + if _v is None: + self.v = self.int_type(0) + elif isinstance(_v, self.int_type): self.v = _v elif isinstance(_v, cfix.scalars): - self.v = sint(int(round(_v * (2 ** f))), size=self.size) - elif isinstance(_v, sfloat): + self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size) + elif isinstance(_v, self.float_type): p = (f + _v.p) b = (p >= 0) a = b*(_v.v << (p)) + (1-b)*(_v.v >> (-p)) self.v = (1-2*_v.s)*a - elif isinstance(_v, sfix): + elif isinstance(_v, type(self)): self.v = _v.v elif isinstance(_v, MemFix): #this is a memvalue object - self.v = _v.v - # elif _v == None: - # self.v = sint(0) - self.kappa = sfix.kappa + self.v = _v @vectorize def load_int(self, v): - self.v = sint(v) * (2**self.f) + self.v = self.int_type(v) << self.f + + def conv(self): + return self def store_in_mem(self, address): self.v.store_in_mem(address) @@ -1851,25 +1964,28 @@ class sfix(_number): @vectorize def add(self, other): other = parse_type(other) - if isinstance(other, (sfix, cfix)): - return sfix(self.v + other.v) + if isinstance(other, (_fix, cfix)): + return type(self)(self.v + other.v) elif isinstance(other, cfix.scalars): tmp = cfix(other) return self + tmp else: - raise CompilerError('Invalid type %s for sfix.__add__' % type(other)) + raise CompilerError('Invalid type %s for _fix.__add__' % type(other)) @vectorize def mul(self, other): other = parse_type(other) - if isinstance(other, (sfix, cfix)): - val = floatingpoint.TruncPr(self.v * other.v, self.k * 2, self.f, self.kappa) - return sfix(val) + if isinstance(other, _fix): + val = self.v.TruncMul(other.v, self.k * 2, self.f, self.kappa) + return type(self)(val) + elif isinstance(other, cfix): + res = type(self)((self.v * other.v) >> self.f) + return res elif isinstance(other, cfix.scalars): scalar_fix = cfix(other) return self * scalar_fix else: - raise CompilerError('Invalid type %s for sfix.__mul__' % type(other)) + raise CompilerError('Invalid type %s for _fix.__mul__' % 
type(other)) @vectorize def __sub__(self, other): @@ -1878,7 +1994,7 @@ class sfix(_number): @vectorize def __neg__(self): - return sfix(-self.v) + return type(self)(-self.v) def __rsub__(self, other): return -self + other @@ -1886,7 +2002,7 @@ class sfix(_number): @vectorize def __eq__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.equal(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1894,7 +2010,7 @@ class sfix(_number): @vectorize def __le__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.less_equal(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1902,7 +2018,7 @@ class sfix(_number): @vectorize def __lt__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.less_than(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1910,7 +2026,7 @@ class sfix(_number): @vectorize def __ge__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.greater_equal(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1918,7 +2034,7 @@ class sfix(_number): @vectorize def __gt__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.greater_than(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1926,7 +2042,7 @@ class sfix(_number): @vectorize def __ne__(self, other): other = parse_type(other) - if isinstance(other, (cfix, sfix)): + if isinstance(other, (cfix, _fix)): return self.v.not_equal(other.v, self.k, self.kappa) else: raise NotImplementedError @@ -1934,20 +2050,24 @@ class sfix(_number): @vectorize def __div__(self, other): other = parse_type(other) - if isinstance(other, sfix): - return sfix(library.FPDiv(self.v, 
other.v, self.k, self.f, self.kappa)) + if isinstance(other, _fix): + return type(self)(library.FPDiv(self.v, other.v, self.k, self.f, self.kappa)) elif isinstance(other, cfix): - return sfix(library.sint_cint_division(self.v, other.v, self.k, self.f, self.kappa)) + return type(self)(library.sint_cint_division(self.v, other.v, self.k, self.f, self.kappa)) else: raise TypeError('Incompatible fixed point types in division') @vectorize def compute_reciprocal(self): - return sfix(library.FPDiv(cint(2) ** self.f, self.v, self.k, self.f, self.kappa, True)) + return type(self)(library.FPDiv(cint(2) ** self.f, self.v, self.k, self.f, self.kappa, True)) def reveal(self): val = self.v.reveal() - return cfix(val) + return self.clear_type(val) + +class sfix(_fix): + int_type = sint + clear_type = cfix # this is for 20 bit decimal precision # with 40 bitlength of entire number @@ -2253,6 +2373,8 @@ class cfloat(object): def print_float_plain(self): print_float_plain(self.v, self.p, self.z, self.s) +sfix.float_type = sfloat + _types = { 'c': cint, 's': sint, @@ -2271,6 +2393,7 @@ class Array(object): self.value_type = value_type if address is None: self.address = self._malloc() + self.address_cache = {} def _malloc(self): return program.malloc(self.length, self.value_type) @@ -2285,7 +2408,9 @@ class Array(object): if index >= self.length or index < 0: raise IndexError('index %s, length %s' % \ (str(index), str(self.length))) - return self.address + index + if (program.curr_block, index) not in self.address_cache: + self.address_cache[program.curr_block, index] = self.address + index + return self.address_cache[program.curr_block, index] def get_slice(self, index): if index.stop is None and self.length is None: @@ -2350,61 +2475,49 @@ class Array(object): self[i] = mem_value return self + def get_mem_value(self, index): + return MemValue(self[index], self.get_address(index)) + sint.dynamic_array = Array sgf2n.dynamic_array = Array -class Matrix(object): - def __init__(self, 
rows, columns, value_type, address=None): - self.rows = rows - self.columns = columns - if value_type in _types: - value_type = _types[value_type] - self.value_type = value_type - self.address = Array(rows * columns, value_type, address).address - - def __getitem__(self, index): - return Array(self.columns, self.value_type, \ - self.address + index * self.columns) - - def __len__(self): - return self.rows - - def assign_all(self, value): - @library.for_range(len(self)) - def f(i): - self[i].assign_all(value) - return self - - def get_address(self): - return self.address - class SubMultiArray(object): def __init__(self, sizes, value_type, address, index): self.sizes = sizes self.value_type = value_type self.address = address + index * reduce(operator.mul, self.sizes) + self.sub_cache = {} def __getitem__(self, index): - if len(self.sizes) == 2: - return Array(self.sizes[1], self.value_type, \ - self.address + index * self.sizes[0]) - else: - return SubMultiArray(self.sizes[1:], self.value_type, \ - self.address, index) + if index not in self.sub_cache: + if len(self.sizes) == 2: + self.sub_cache[index] = \ + Array(self.sizes[1], self.value_type, \ + self.address + index * self.sizes[1]) + else: + self.sub_cache[index] = \ + SubMultiArray(self.sizes[1:], self.value_type, \ + self.address, index) + return self.sub_cache[index] -class MultiArray(object): + def assign_all(self, value): + @library.for_range(self.sizes[0]) + def f(i): + self[i].assign_all(value) + return self + +class MultiArray(SubMultiArray): def __init__(self, sizes, value_type): - self.sizes = sizes - self.value_type = value_type self.array = Array(reduce(operator.mul, sizes), \ value_type) + SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0) if len(sizes) < 2: raise CompilerError('Use Array') - def __getitem__(self, index): - return SubMultiArray(self.sizes[1:], self.value_type, \ - self.array.address, index) +class Matrix(MultiArray): + def __init__(self, rows, columns, 
value_type): + MultiArray.__init__(self, [rows, columns], value_type) class VectorArray(object): def __init__(self, length, value_type, vector_size, address=None): @@ -2536,7 +2649,7 @@ class _mem(_number): class MemValue(_mem): __slots__ = ['last_write_block', 'reg_type', 'register', 'address', 'deleted'] - def __init__(self, value): + def __init__(self, value, address=None): self.last_write_block = None if isinstance(value, int): self.value_type = regint @@ -2546,9 +2659,12 @@ class MemValue(_mem): else: self.value_type = type(value) self.reg_type = self.value_type.reg_type - self.address = program.malloc(1, self.value_type) self.deleted = False - self.write(value) + if address is None: + self.address = program.malloc(1, self.value_type) + self.write(value) + else: + self.address = address def delete(self): program.free(self.address, self.reg_type) @@ -2631,7 +2747,14 @@ class MemFloat(_mem): class MemFix(_mem): def __init__(self, *args): - value = sfix(*args) + arg_type = type(*args) + if arg_type == sfix: + value = sfix(*args) + elif arg_type == cfix: + value = cfix(*args) + else: + raise CompilerError('MemFix init argument error') + self.reg_type = value.v.reg_type self.v = MemValue(value.v) def write(self, *args): diff --git a/GC/FakeSecret.cpp b/GC/FakeSecret.cpp index 4206cec2..ecaa3e53 100644 --- a/GC/FakeSecret.cpp +++ b/GC/FakeSecret.cpp @@ -55,4 +55,22 @@ void FakeSecret::store_clear_in_dynamic(Memory& mem, mem[access.address] = access.value; } +FakeSecret FakeSecret::input(int from, ifstream& input_file, int n_bits) +{ + long long int in; + input_file >> in; + return input(from, in, n_bits); +} + +FakeSecret FakeSecret::input(int from, const int128& input, int n_bits) +{ + (void)from; + (void)n_bits; + FakeSecret res; + res.a = ((__uint128_t)input.get_upper() << 64) + input.get_lower(); + if (res.a > ((__uint128_t)1 << n_bits)) + throw out_of_range("input too large"); + return res; +} + } /* namespace GC */ diff --git a/GC/FakeSecret.h 
b/GC/FakeSecret.h index 7ddb1390..eab3c8b8 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -16,6 +16,7 @@ #include "Math/gf2n.h" #include +#include namespace GC { @@ -44,6 +45,12 @@ public: template static void andrs(T& processor, const vector& args) { processor.andrs(args); } + template + static void inputb(T& processor, const vector& args) + { processor.input(args); } + + static FakeSecret input(int from, ifstream& input_file, int n_bits); + static FakeSecret input(int from, const int128& input, int n_bits); FakeSecret() : a(0) {} FakeSecret(const Integer& x) : a(x.get()) {} @@ -67,9 +74,12 @@ public: void xor_(int n, const FakeSecret& x, const T& y) { (void)n; a = x.a ^ y.a; } void andrs(int n, const FakeSecret& x, const FakeSecret& y) { (void)n; a = x.a * y.a; } + void random_bit() { a = random() % 2; } void reveal(Clear& x) { x = a; } + + int size() { return -1; } }; } /* namespace GC */ diff --git a/GC/Instruction.cpp b/GC/Instruction.cpp index b519aaff..445f6b82 100644 --- a/GC/Instruction.cpp +++ b/GC/Instruction.cpp @@ -19,6 +19,9 @@ #include "GC/Instruction_inline.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + namespace GC { @@ -78,6 +81,7 @@ template int GC::Instruction::get_max_reg(int reg_type) const { int skip; + int offset = 0; switch (opcode) { case LDMSD: @@ -88,13 +92,22 @@ int GC::Instruction::get_max_reg(int reg_type) const case STMSDI: skip = 2; break; + case ANDRS: + case XORS: + skip = 4; + offset = 1; + break; + case INPUTB: + skip = 3; + offset = 2; + break; default: return BaseInstruction::get_max_reg(reg_type); } int m = 0; if (reg_type == SBIT) - for (size_t i = 0; i < start.size(); i += skip) + for (size_t i = offset; i < start.size(); i += skip) m = max(m, start[i] + 1); return m; } @@ -192,8 +205,13 @@ void Instruction::parse(istream& s, int pos) case STMSDCI: case XORS: case ANDRS: + case INPUTB: get_vector(get_int(s), start, s); break; + case PRINTREGSIGNED: + n = get_int(s); + get_ints(r, s, 1); + break; 
default: ostringstream os; os << "Invalid instruction " << showbase << hex << opcode @@ -224,5 +242,7 @@ template class Instruction< Secret >; template class Instruction< Secret >; template class Instruction< Secret >; template class Instruction< Secret >; +template class Instruction< Secret >; +template class Instruction< Secret >; } /* namespace GC */ diff --git a/GC/Instruction.h b/GC/Instruction.h index 7a45c406..f9e840e7 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -84,6 +84,9 @@ enum CONVCINT = 0x213, REVEAL = 0x214, STMSDCI = 0x215, + INPUTB = 0x216, + // don't write + PRINTREGSIGNED = 0x220, }; } /* namespace GC */ diff --git a/GC/Machine.cpp b/GC/Machine.cpp index ba5bc100..ec036cf7 100644 --- a/GC/Machine.cpp +++ b/GC/Machine.cpp @@ -10,6 +10,9 @@ #include "GC/Program.h" #include "Secret.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + namespace GC { @@ -41,5 +44,7 @@ template class Machine< Secret >; template class Machine< Secret >; template class Machine< Secret >; template class Machine< Secret >; +template class Machine< Secret >; +template class Machine< Secret >; } /* namespace GC */ diff --git a/GC/Memory.cpp b/GC/Memory.cpp index 6718511b..4aecda87 100644 --- a/GC/Memory.cpp +++ b/GC/Memory.cpp @@ -13,6 +13,9 @@ #include #include "Secret.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + namespace GC { @@ -30,6 +33,8 @@ template class Memory< Secret >; template class Memory< Secret >; template class Memory< Secret >; template class Memory< Secret >; +template class Memory< Secret >; +template class Memory< Secret >; template class Memory< AuthValue >; template class Memory< SpdzShare >; diff --git a/GC/Processor.cpp b/GC/Processor.cpp index 565dd393..b74f1bd0 100644 --- a/GC/Processor.cpp +++ b/GC/Processor.cpp @@ -14,6 +14,9 @@ using namespace std; #include "Secret.h" #include "Access.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + namespace GC { @@ -53,6 +56,14 @@ void 
Processor::reset(const Program& program) machine.reset(program); } +template +void GC::Processor::open_input_file(const string& name) +{ + cout << "opening " << name << endl; + input_file.open(name); + input_file.exceptions(ios::badbit | ios::eofbit); +} + template void Processor::bitdecc(const vector& regs, const Clear& x) { @@ -113,10 +124,17 @@ void GC::Processor::store_dynamic_indirect(const vector& args) complexity += accesses.size() / 2 * T::default_length; } -void check_args(const vector& args, int n) +template +int GC::Processor::check_args(const vector& args, int n) { if (args.size() % n != 0) throw runtime_error("invalid number of arguments"); + int total = 0; + for (size_t i = 0; i < args.size(); i += n) + { + total += args[i]; + } + return total; } template @@ -133,7 +151,8 @@ template void Processor::xors(const vector& args) { check_args(args, 4); - for (size_t i = 0; i < args.size(); i += 4) + size_t n_args = args.size(); + for (size_t i = 0; i < n_args; i += 4) { S[args[i+1]].xor_(args[i], S[args[i+2]], S[args[i+3]]); #ifndef FREE_XOR @@ -153,6 +172,19 @@ void Processor::andrs(const vector& args) } } +template +void Processor::input(const vector& args) +{ + check_args(args, 3); + for (size_t i = 0; i < args.size(); i += 3) + { + S[args[i+2]] = T::input(args[i] + 1, input_file, args[i+1]); +#ifdef DEBUG_INPUT + cout << "input to " << args[i+2] << "/" << &S[args[i+2]] << endl; +#endif + } +} + template void Processor::print_reg(int reg, int n) { @@ -170,6 +202,15 @@ void Processor::print_reg_plain(Clear& value) T::out << hex << showbase << value << dec << flush; } +template +void Processor::print_reg_signed(unsigned n_bits, Clear& value) +{ + unsigned n_shift = 0; + if (n_bits > 1) + n_shift = sizeof(value.get()) * 8 - n_bits; + T::out << dec << (value.get() << n_shift >> n_shift) << flush; +} + template void Processor::print_chr(int n) { @@ -187,5 +228,7 @@ template class Processor< Secret >; template class Processor< Secret >; template class 
Processor< Secret >; template class Processor< Secret >; +template class Processor< Secret >; +template class Processor< Secret >; } /* namespace GC */ diff --git a/GC/Processor.h b/GC/Processor.h index 938330c3..64b06fa7 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -27,10 +27,13 @@ template class Processor : public ::ProcessorBase { static Processor* singleton; + public: static Processor& s(); static int get_PC(); + static int check_args(const vector& args, int n); + Machine& machine; unsigned int PC; @@ -39,6 +42,8 @@ public: // rough measure for the memory usage size_t complexity; + ifstream input_file; + Memory S; Memory C; Memory I; @@ -47,6 +52,7 @@ public: ~Processor(); void reset(const Program& program); + void open_input_file(const string& name); void bitcoms(T& x, const vector& regs) { x.bitcom(S, regs); } void bitdecs(const vector& regs, const T& x) { x.bitdec(S, regs); } @@ -64,8 +70,11 @@ public: void xors(const vector& args); void andrs(const vector& args); + void input(const vector& args); + void print_reg(int reg, int n); void print_reg_plain(Clear& value); + void print_reg_signed(unsigned n_bits, Clear& value); void print_chr(int n); void print_str(int n); }; diff --git a/GC/Program.cpp b/GC/Program.cpp index 1c110fa7..b293df2c 100644 --- a/GC/Program.cpp +++ b/GC/Program.cpp @@ -11,6 +11,9 @@ #include +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + #ifdef MAX_INLINE #include "Instruction_inline.h" #endif @@ -47,6 +50,16 @@ void Program::compute_constants() } } +template +void Program::parse(const string& bytecode_name) +{ + string filename = "Programs/Bytecode/" + bytecode_name + ".bc"; + ifstream s(filename.c_str()); + if (s.bad() or s.fail()) + throw runtime_error("Cannot open " + filename); + parse(s); +} + template void Program::parse(istream& s) { @@ -57,6 +70,8 @@ void Program::parse(istream& s) CALLGRIND_STOP_INSTRUMENTATION; while (!s.eof()) { + if (s.bad() or s.fail()) + throw runtime_error("error reading program"); 
instr.parse(s, pos); p.push_back(instr); //cerr << "\t" << instr << endl; @@ -120,5 +135,7 @@ template class Program< Secret >; template class Program< Secret >; template class Program< Secret >; template class Program< Secret >; +template class Program< Secret >; +template class Program< Secret >; } /* namespace GC */ diff --git a/GC/Program.h b/GC/Program.h index b0cebe05..b386eef9 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -46,6 +46,7 @@ class Program Program(); // Read in a program + void parse(const string& bytecoode_name); void parse(istream& s); int get_offline_data_used() const { return offline_data_used; } diff --git a/GC/Secret.cpp b/GC/Secret.cpp index 940ce97e..c4111f12 100644 --- a/GC/Secret.cpp +++ b/GC/Secret.cpp @@ -12,6 +12,9 @@ #include "Secret.h" #include "Secret_inline.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoEvalWire.h" + namespace GC { @@ -36,12 +39,27 @@ ostream& operator<<(ostream& o, const AuthValue& auth_value) return o; } +template +Secret Secret::input(party_id_t from, ifstream& input_file, int n_bits) +{ + long long int in; + if (from == CommonParty::s().get_id()) + { + input_file >> in; + T::check_input(in, n_bits); + } + return input(from, in, n_bits); +} + template Secret Secret::input(party_id_t from, const int128& input, int n_bits) { Secret res; if (n_bits < 0) n_bits = default_length; +#ifdef DEBUG_INPUT + cout << "input " << input << endl; +#endif for (int i = 0; i < n_bits; i++) { res.get_new_reg().input(from, input.get_bit(i)); @@ -50,7 +68,7 @@ Secret Secret::input(party_id_t from, const int128& input, int n_bits) #endif } #ifdef DEBUG_INPUT - cout << endl; + cout << " input" << endl; #endif if ((size_t)n_bits != res.registers.size()) { @@ -58,6 +76,12 @@ Secret Secret::input(party_id_t from, const int128& input, int n_bits) throw runtime_error("wrong bit length in input()"); } #ifdef DEBUG_INPUT + for (auto& reg : res.registers) + cout << (int)reg.get_mask_no_check(); + cout << " mask " << endl; + for (auto& 
reg : res.registers) + cout << (int)reg.get_external_no_check(); + cout << " ext " << endl; int128 a; res.reveal(a); cout << " input " << hex << a << "(" << res.size() << ") from " << from @@ -106,7 +130,11 @@ void Secret::random(int n_bits, int128 share) template void Secret::random_bit() { +#ifdef NO_INPUT + return random(1, 0); +#else return random(1, CommonParty::s().prng.get_uchar() & 1); +#endif } template @@ -175,7 +203,7 @@ void Secret::store(Memory& mem, } template -void Secret::output(Register& reg) +void Secret::output(T& reg) { cast(reg).output(); } @@ -196,7 +224,7 @@ Secret Secret::carryless_mult(const Secret& x, const Secret& y) { int start = max((size_t)0, i - y.registers.size() + 1); int stop = min(i + 1, x.registers.size()); - Register sum = AND(x.get_reg(start), y.get_reg(i - start)); + T sum = AND(x.get_reg(start), y.get_reg(i - start)); #ifdef DEBUG_DYNAMIC2 output(sum); cout << "carryless " << i << " " << start << " " << i - start << @@ -207,7 +235,7 @@ Secret Secret::carryless_mult(const Secret& x, const Secret& y) #endif for (int j = start + 1; j < stop; j++) { - Register product = AND(x.get_reg(j), y.get_reg(i - j)); + T product = AND(x.get_reg(j), y.get_reg(i - j)); sum = XOR(sum, product); #ifdef DEBUG_DYNAMIC2 cout << "carryless " << @@ -246,18 +274,6 @@ Secret::Secret() } -template -T& GC::Secret::get_reg(int i) -{ - return *reinterpret_cast(®isters.at(i)); -} - -template -const T& GC::Secret::get_reg(int i) const -{ - return *reinterpret_cast(®isters.at(i)); -} - template T& GC::Secret::get_new_reg() @@ -366,7 +382,10 @@ template void Secret::bitdec(Memory& S, const vector& regs) const { if (regs.size() > registers.size()) - throw out_of_range("not enough bits for bit decomposition"); + throw out_of_range( + "not enough bits for bit decomposition: " + + to_string(regs.size()) + " > " + + to_string(registers.size())); for (unsigned int i = 0; i < regs.size(); i++) { Secret& secret = S[regs[i]]; @@ -380,7 +399,8 @@ template void 
Secret::reveal(U& x) { #ifdef DEBUG_OUTPUT - cout << "output: "; + cout << "revealing " << this << " with min(" << 8 * sizeof(U) << "," + << registers.size() << ") bits" << endl; #endif x = 0; for (unsigned int i = 0; i < min(8 * sizeof(U), registers.size()); i++) @@ -393,7 +413,13 @@ void Secret::reveal(U& x) #endif } #ifdef DEBUG_OUTPUT - cout << endl; + cout << " output" << endl; + for (auto& reg : registers) + cout << (int)reg.get_mask_no_check(); + cout << " mask" << endl; + for (auto& reg: registers) + cout << (int)reg.get_external_no_check(); + cout << " ext" << endl; #endif #ifdef DEBUG_VALUES cout << typeid(T).name() << " " << &x << endl; @@ -405,14 +431,24 @@ void Secret::reveal(U& x) #endif } +template +MAYBE_INLINE void Secret::resize_regs(int n) +{ + registers.resize(n, T::new_reg()); +} + template class Secret; template class Secret; template class Secret; template class Secret; +template class Secret; +template class Secret; template void Secret::reveal(Clear& x); template void Secret::reveal(Clear& x); template void Secret::reveal(Clear& x); template void Secret::reveal(Clear& x); +template void Secret::reveal(Clear& x); +template void Secret::reveal(Clear& x); } /* namespace GC */ diff --git a/GC/Secret.h b/GC/Secret.h index 7ef1e939..713aea1f 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -9,18 +9,30 @@ #define GC_SECRET_H_ #include "BMR/Register.h" -#include "BMR/CommonParty.h" #include "BMR/AndJob.h" #include "GC/Clear.h" #include "GC/Memory.h" #include "GC/Access.h" +#include "GC/Processor.h" #include "Math/Share.h" +#include + namespace GC { +template +inline void XOR(T& res, const T& left, const T& right) +{ +#ifdef FREE_XOR + Secret::cast(res).XOR(Secret::cast(left), Secret::cast(right)); +#else + Secret::cast(res).op(Secret::cast(left), Secret::cast(right), 0x0110); +#endif +} + class AuthValue { public: @@ -51,7 +63,7 @@ public: template class Secret { - CheckVector registers; + CheckVector registers; T& get_new_reg(); @@ -69,10 +81,11 @@ 
public: static typename T::out_type out; - static T& cast(Register& reg) { return *reinterpret_cast(®); } - static const T& cast(const Register& reg) { return *reinterpret_cast(®); } + static T& cast(T& reg) { return *reinterpret_cast(®); } + static const T& cast(const T& reg) { return *reinterpret_cast(®); } static Secret input(party_id_t from, const int128& input, int n_bits = -1); + static Secret input(party_id_t from, ifstream& input_file, int n_bits = -1); void random(int n_bits, int128 share); void random_bit(); static Secret reconstruct(const int128& x, int length); @@ -81,13 +94,15 @@ public: { T::store_clear_in_dynamic(mem, accesses); } void store(Memory& mem, size_t address); static Secret carryless_mult(const Secret& x, const Secret& y); - static void output(Register& reg); + static void output(T& reg); static void load(vector< ReadAccess< Secret > >& accesses, const Memory& mem); static void store(Memory& mem, vector< WriteAccess< Secret > >& accesses); static void andrs(Processor< Secret >& processor, const vector& args) { T::andrs(processor, args); } + static void inputb(Processor< Secret >& processor, const vector& args) + { T::inputb(processor, args); } Secret(); Secret(const Integer& x) { *this = x; } @@ -105,19 +120,24 @@ public: Secret operator+(const Secret x) const; Secret& operator+=(const Secret x) { *this = *this + x; return *this; } - void xor_(int n, const Secret& x, const Secret& y); + void xor_(int n, const Secret& x, const Secret& y) + { + resize_regs(n); + for (int i = 0; i < n; i++) + XOR(registers[i], x.get_reg(i), y.get_reg(i)); + } void andrs(int n, const Secret& x, const Secret& y); template void reveal(U& x); int size() const { return registers.size(); } - CheckVector& get_regs() { return registers; } - const CheckVector& get_regs() const { return registers; } + CheckVector& get_regs() { return registers; } + const CheckVector& get_regs() const { return registers; } - const T& get_reg(int i) const; - T& get_reg(int i); - void 
resize_regs(int n) { registers.resize(n, T::new_reg()); } + const T& get_reg(int i) const { return *reinterpret_cast(®isters.at(i)); } + T& get_reg(int i) { return *reinterpret_cast(®isters.at(i)); } + void resize_regs(int n); }; template diff --git a/GC/Secret_inline.h b/GC/Secret_inline.h index 075506c3..794d3858 100644 --- a/GC/Secret_inline.h +++ b/GC/Secret_inline.h @@ -19,25 +19,15 @@ namespace GC { template -inline void XOR(Register& res, const Register& left, const Register& right) +inline T XOR(const T& left, const T& right) { -#ifdef FREE_XOR - Secret::cast(res).XOR(Secret::cast(left), Secret::cast(right)); -#else - Secret::cast(res).op(Secret::cast(left), Secret::cast(right), 0x0110); -#endif -} - -template -inline Register XOR(const Register& left, const Register& right) -{ - Register res(T::new_reg()); + T res(T::new_reg()); XOR(res, left, right); return res; } template -inline void AND(Register& res, const Register& left, const Register& right) +inline void AND(T& res, const T& left, const T& right) { #ifdef KEY_SIGNAL #ifdef DEBUG_REGS @@ -49,9 +39,9 @@ inline void AND(Register& res, const Register& left, const Register& right) } template -inline Register AND(const Register& left, const Register& right) +inline T AND(const T& left, const T& right) { - Register res = T::new_reg(); + T res = T::new_reg(); AND(res, left, right); return res; } @@ -60,30 +50,20 @@ template inline Secret GC::Secret::operator+(const Secret x) const { Secret res; - res.xor_(max(registers.size(), x.registers.size()), *this, x); - return res; -} - -template -MAYBE_INLINE void Secret::xor_(int n, const Secret& x, const Secret& y) -{ - int min_n = min((size_t)n, min(x.registers.size(), y.registers.size())); - resize_regs(min_n); - for (int i = 0; i < min_n; i++) - { - XOR(registers[i], x.get_reg(i), y.get_reg(i)); - } - + int min_n = min(registers.size(), x.registers.size()); + int n = max(registers.size(), x.registers.size()); + res.xor_(min_n, *this, x); if (min_n < n) { - 
const vector* more_regs; - if (y.registers.size() < x.registers.size()) + const vector* more_regs; + if (registers.size() < x.registers.size()) more_regs = &x.registers; else - more_regs = &y.registers; - registers.insert(registers.end(), more_regs->begin() + min_n, + more_regs = ®isters; + res.registers.insert(res.registers.end(), more_regs->begin() + min_n, more_regs->begin() + min((size_t)n, more_regs->size())); } + return res; } template diff --git a/GC/instructions.h b/GC/instructions.h index eef212d6..3758ff4a 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -47,6 +47,7 @@ X(XORC, C0.xor_(C1, C2)) \ X(XORCI, C0.xor_(C1, N)) \ X(ANDRS, T::andrs(P, EXTRA)) \ + X(INPUTB, T::inputb(P, EXTRA)) \ X(ADDC, C0 = C1 + C2) \ X(ADDCI, C0 = C1 + N) \ X(MULCI, C0 = C1 * N) \ @@ -74,6 +75,7 @@ X(REVEAL, S1.reveal(C0)) \ X(PRINTREG, P.print_reg(R0, N)) \ X(PRINTREGPLAIN, P.print_reg_plain(C0)) \ + X(PRINTREGSIGNED, P.print_reg_signed(N, C0)) \ X(PRINTCHR, P.print_chr(N)) \ X(PRINTSTR, P.print_str(N)) \ X(LDINT, I0 = int(N)) \ diff --git a/Makefile b/Makefile index 0d670b03..c8370de6 100644 --- a/Makefile +++ b/Makefile @@ -26,8 +26,9 @@ OT_EXE = ot.x ot-offline.x endif COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) -COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(OT) -BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(GC) +COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) +YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) +BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(GC) $(YAO) LIB = libSPDZ.a @@ -59,6 +60,8 @@ externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x bmr: bmr-program-party.x bmr-program-tparty.x +yao: yao-simulate.x yao-player.x + she-offline: Check-Offline.x spdz2-offline.x overdrive: simple-offline.x pairwise-offline.x cnc-offline.x @@ -102,10 +105,10 @@ gc-emulate.x: $(GC) $(COMMON) 
$(PROCESSOR) gc-emulate.cpp $(BMR) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) bmr-program-party.x: $(BMR) bmr-program-party.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + $(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS) bmr-program-tparty.x: $(BMR) bmr-program-tparty.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + $(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS) bmr-clean: -rm BMR/*.o BMR/*/*.o GC/*.o @@ -133,5 +136,14 @@ spdz2-offline.x: $(COMMON) $(FHEOFFLINE) spdz2-offline.cpp $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) endif +yao-simulate.x: $(YAO) $(BMR) yao-simulate.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + +yao-player.x: $(YAO) $(BMR) yao-player.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + +yao-clean: + -rm Yao/*.o + clean: -rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o diff --git a/Math/bigint.h b/Math/bigint.h index 22b209a8..e9713c6f 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -33,6 +33,10 @@ public: bigint(const T& x) : mpz_class(x) {} bigint(const gfp& x); + bigint& operator=(int n); + bigint& operator=(long n); + bigint& operator=(word n); + void allocate_slots(const bigint& x) { *this = x; } int get_min_alloc() { return get_mpz_t()->_mp_alloc; } @@ -72,6 +76,24 @@ public: size_t report_size(ReportType type) const; }; +inline bigint& bigint::operator=(int n) +{ + mpz_class::operator=(n); + return *this; +} + +inline bigint& bigint::operator=(long n) +{ + mpz_class::operator=(n); + return *this; +} + +inline bigint& bigint::operator=(word n) +{ + mpz_class::operator=(n); + return *this; +} + /********************************** * Utility Functions * diff --git a/Math/gfp.cpp b/Math/gfp.cpp index fa1ed168..b56a52fd 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -19,7 +19,7 @@ void gfp::AND(const gfp& x,const gfp& y) to_bigint(bi1,x); to_bigint(bi2,y); mpz_and(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); - to_gfp(*this, bi1); + convert_destroy(bi1); } void gfp::OR(const gfp& x,const gfp& y) @@ -28,7 +28,7 @@ void 
gfp::OR(const gfp& x,const gfp& y) to_bigint(bi1,x); to_bigint(bi2,y); mpz_ior(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); - to_gfp(*this, bi1); + convert_destroy(bi1); } void gfp::XOR(const gfp& x,const gfp& y) @@ -37,7 +37,7 @@ void gfp::XOR(const gfp& x,const gfp& y) to_bigint(bi1,x); to_bigint(bi2,y); mpz_xor(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t()); - to_gfp(*this, bi1); + convert_destroy(bi1); } void gfp::AND(const gfp& x,const bigint& y) @@ -45,7 +45,7 @@ void gfp::AND(const gfp& x,const bigint& y) bigint bi; to_bigint(bi,x); mpz_and(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); - to_gfp(*this, bi); + convert_destroy(bi); } void gfp::OR(const gfp& x,const bigint& y) @@ -53,7 +53,7 @@ void gfp::OR(const gfp& x,const bigint& y) bigint bi; to_bigint(bi,x); mpz_ior(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); - to_gfp(*this, bi); + convert_destroy(bi); } void gfp::XOR(const gfp& x,const bigint& y) @@ -61,7 +61,7 @@ void gfp::XOR(const gfp& x,const bigint& y) bigint bi; to_bigint(bi,x); mpz_xor(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t()); - to_gfp(*this, bi); + convert_destroy(bi); } @@ -76,7 +76,7 @@ void gfp::SHL(const gfp& x,int n) bigint bi; to_bigint(bi,x,false); mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + convert_destroy(bi); } else assign(x); @@ -97,7 +97,7 @@ void gfp::SHR(const gfp& x,int n) bigint bi; to_bigint(bi,x); mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + convert_destroy(bi); } else assign(x); diff --git a/Math/gfp.h b/Math/gfp.h index e981f358..62c6b25d 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -50,9 +50,9 @@ class gfp void assign(const gfp& g) { a=g.a; } void assign_zero() { assignZero(a,ZpD); } void assign_one() { assignOne(a,ZpD); } - void assign(word aa) { bigint b=aa; to_gfp(*this,b); } - void assign(long aa) { bigint b=aa; to_gfp(*this,b); } - void assign(int aa) { bigint b=aa; 
to_gfp(*this,b); } + void assign(word aa) { bigint b=aa; convert_destroy(b); } + void assign(long aa) { bigint b=aa; convert_destroy(b); } + void assign(int aa) { bigint b=aa; convert_destroy(b); } void assign(const char* buffer) { a.assign(buffer, ZpD.get_t()); } modp get() const { return a; } @@ -92,7 +92,7 @@ class gfp bool is_zero() const { return isZero(a,ZpD); } - bool is_one() const { return isOne(a,ZpD); } + bool is_one() const { return isOne(a,ZpD); } bool is_bit() const { return is_zero() or is_one(); } bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } bool operator==(const gfp& y) const { return equal(y); } @@ -198,6 +198,7 @@ class gfp void unpack(octetStream& o) { a.unpack(o,ZpD); } + void convert_destroy(bigint& x) { a.convert_destroy(x, ZpD); } // Convert representation to and from a bigint number friend void to_bigint(bigint& ans,const gfp& x,bool reduce=true) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index bb0259c8..bc3b83c8 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -230,6 +230,26 @@ void Player::setup_sockets(const vector& names,const vector& ports, } +void Player::send_long(int i, long a) const +{ + send(sockets[i], (octet*)&a, sizeof(long)); +} + +long Player::receive_long(int i) const +{ + long res; + receive(sockets[i], (octet*)&res, sizeof(long)); + return res; +} + +long Player::peek_long(int i) const +{ + long res; + recv(sockets[i], &res, sizeof(res), MSG_PEEK); + return res; +} + + void Player::send_to(int player,const octetStream& o,bool donthash) const { TimeScope ts(comm_stats["Sending directly"].add(o)); @@ -263,6 +283,13 @@ void Player::receive_player(int i,octetStream& o,bool donthash) const { blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); } } +void Player::receive_player(int i, FlexBuffer& buffer) const +{ + octetStream os; + receive_player(i, os, true); + buffer = os; +} + void Player::exchange(int other, octetStream& o) const { @@ -320,7 +347,6 @@ void 
Player::Check_Broadcast() const blk_SHA1_Init(&ctx); } - void Player::wait_for_available(vector& players, vector& result) const { fd_set rfds; diff --git a/Networking/Player.h b/Networking/Player.h index 573c7d24..3e24286d 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -17,6 +17,7 @@ using namespace std; #include "Tools/octetStream.h" +#include "Tools/FlexBuffer.h" #include "Networking/sockets.h" #include "Networking/ServerSocket.h" #include "Tools/sha1.h" @@ -158,11 +159,16 @@ public: void send_int(int i,int a) const { send(sockets[i],a); } void receive_int(int i,int& a) const { receive(sockets[i],a); } + void send_long(int i, long a) const; + long receive_long(int i) const; + long peek_long(int i) const; + // Send an octetStream to all other players // -- And corresponding receive virtual void send_all(const octetStream& o,bool donthash=false) const; void send_to(int player,const octetStream& o,bool donthash=false) const; virtual void receive_player(int i,octetStream& o,bool donthash=false) const; + void receive_player(int i,FlexBuffer& buffer) const; // exchange data with minimal memory usage void exchange(int other, octetStream& o) const; diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index c02820ee..a846d337 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -136,7 +136,6 @@ inline void square32::transpose(square128& output, int x, int y) #endif #ifdef __AVX2__ -#warning Using AVX2 for transpose typedef square32 subsquare; #define N_SUBSQUARES 4 #else diff --git a/OT/Tools.cpp b/OT/Tools.cpp index 4448e811..46b7733a 100644 --- a/OT/Tools.cpp +++ b/OT/Tools.cpp @@ -12,7 +12,7 @@ void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len) vector Open_seed(2); G.get_octetStream(seed_strm[0], len); - Commit(Comm_seed[0], Open_seed[0], seed_strm[0], player.my_num()); + Commit(Comm_seed[0], Open_seed[0], seed_strm[0], player.my_real_num()); player.send_receive_player(Comm_seed); player.send_receive_player(Open_seed); diff --git 
a/Player-Online.cpp b/Player-Online.cpp index 0e843ea2..d13f67d9 100644 --- a/Player-Online.cpp +++ b/Player-Online.cpp @@ -4,6 +4,7 @@ #include "Math/Setup.h" #include "Tools/ezOptionParser.h" #include "Tools/Config.h" +#include "Networking/Server.h" #include #include @@ -138,6 +139,17 @@ int main(int argc, const char** argv) "-c", // Flag token. "--player-to-player-commsec" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of players. Server.x is not used if given. " + "Player 0 must run on host given with -h then. " + "Default: Use Server.x or IP file", // Help description. + "-N", // Flag token. + "--nparties" // Flag token. + ); opt.parse(argc, argv); @@ -219,12 +231,23 @@ int main(int argc, const char** argv) } Names playerNames; + Server* server = 0; if (ipFileName.size() > 0) { if (my_port != Names::DEFAULT_PORT) throw runtime_error("cannot set port number when using IP file"); playerNames.init(playerno, pnbase, ipFileName); } else { - playerNames.init(playerno, pnbase, my_port, hostname.c_str()); + if (opt.get("-N")->isSet) + { + if (my_port != Names::DEFAULT_PORT) + throw runtime_error("cannot set port number when not using Server.x"); + int nplayers; + opt.get("-N")->getInt(nplayers); + server = Server::start_networking(playerNames, mynum, nplayers, + hostname, pnbase); + } + else + playerNames.init(playerno, pnbase, my_port, hostname.c_str()); } playerNames.set_keys(keys); @@ -236,6 +259,9 @@ int main(int argc, const char** argv) opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, opt.get("--threads")->isSet, max_broadcast).run(); + if (server) + delete server; + cerr << "Command line:"; for (int i = 0; i < argc; i++) cerr << " " << argv[i]; diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index fd24e50d..9f981581 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -244,6 +244,7 @@ void 
BaseInstruction::parse_operands(istream& s, int pos) case INPUTMASK: case GINPUTMASK: case ACCEPTCLIENTCONNECTION: + case INV2M: r[0]=get_int(s); n = get_int(s); break; @@ -546,7 +547,7 @@ void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) { switch (opcode) { case LDI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); Proc.write_Cp(r[0],Proc.temp.ansp); break; case GLDI: @@ -554,7 +555,7 @@ void Instruction::execute(Processor& Proc) const Proc.write_C2(r[0],Proc.temp.ans2); break; case LDSI: - { Proc.temp.ansp.assign(n); + { Proc.temp.assign_ansp(n); if (Proc.P.my_num()==0) Proc.get_Sp_ref(r[0]).set_share(Proc.temp.ansp); else @@ -852,7 +853,7 @@ void Instruction::execute(Processor& Proc) const to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); - to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.temp.ansp.convert_destroy(Proc.temp.aa); Proc.write_Cp(r[0],Proc.temp.ansp); break; case LEGENDREC: @@ -876,7 +877,7 @@ void Instruction::execute(Processor& Proc) const case DIVCI: if (n == 0) throw Processor_Error("Division by immediate zero"); - to_gfp(Proc.temp.ansp,n%gfp::pr()); + Proc.temp.assign_ansp(n); Proc.temp.ansp.invert(); Proc.temp.ansp.mul(Proc.read_Cp(r[1])); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -889,9 +890,18 @@ void Instruction::execute(Processor& Proc) const Proc.temp.ans2.mul(Proc.read_C2(r[1])); Proc.write_C2(r[0],Proc.temp.ans2); break; + case INV2M: + if (Proc.inverses2m.find(n) == Proc.inverses2m.end()) + { + to_gfp(Proc.inverses2m[n], bigint(1) << n); + Proc.inverses2m[n].invert(); + } + Proc.write_Cp(r[0], Proc.inverses2m[n]); + break; case MODCI: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); - to_gfp(Proc.temp.ansp, mpz_fdiv_ui(Proc.temp.aa.get_mpz_t(), n)); + Proc.temp.aa = mpz_fdiv_ui(Proc.temp.aa.get_mpz_t(), n); + Proc.temp.ansp.convert_destroy(Proc.temp.aa); 
Proc.write_Cp(r[0],Proc.temp.ansp); break; case GMULBITC: @@ -911,7 +921,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case ADDCI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Proc.temp.ansp.add(Proc.temp.ansp,Proc.read_Cp(r[1])); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -929,7 +939,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case ADDSI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Sansp.add(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); Proc.write_Sp(r[0],Sansp); @@ -947,7 +957,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case SUBCI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Proc.temp.ansp.sub(Proc.read_Cp(r[1]),Proc.temp.ansp); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -965,7 +975,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case SUBSI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Sansp.sub(Proc.read_Sp(r[1]),Proc.temp.ansp,Proc.P.my_num()==0,Proc.MCp.get_alphai()); Proc.write_Sp(r[0],Sansp); @@ -983,7 +993,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case SUBCFI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Proc.temp.ansp.sub(Proc.temp.ansp,Proc.read_Cp(r[1])); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -1001,7 +1011,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case SUBSFI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Sansp.sub(Proc.temp.ansp,Proc.read_Sp(r[1]),Proc.P.my_num()==0,Proc.MCp.get_alphai()); Proc.write_Sp(r[0],Sansp); @@ -1019,7 +1029,7 @@ void Instruction::execute(Processor& Proc) const #endif break; case MULCI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Proc.temp.ansp.mul(Proc.temp.ansp,Proc.read_Cp(r[1])); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -1037,7 +1047,7 @@ void Instruction::execute(Processor& Proc) const #endif 
break; case MULSI: - Proc.temp.ansp.assign(n); + Proc.temp.assign_ansp(n); #ifdef DEBUG Sansp.mul(Proc.read_Sp(r[1]),Proc.temp.ansp); Proc.write_Sp(r[0],Sansp); @@ -1103,7 +1113,7 @@ void Instruction::execute(Processor& Proc) const #ifdef DEBUG printf("Enter your input : \n"); #endif - word x; + long x; cin >> x; t.assign(x); t.sub(t,rr); @@ -1269,7 +1279,7 @@ void Instruction::execute(Processor& Proc) const Proc.temp.aa2 = 1; Proc.temp.aa2 <<= n; Proc.temp.aa += Proc.temp.aa2; - to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.temp.ansp.convert_destroy(Proc.temp.aa); Proc.write_Cp(r[0],Proc.temp.ansp); break; case GNOTC: @@ -1437,7 +1447,8 @@ void Instruction::execute(Processor& Proc) const Proc.get_Ci_ref(r[0]) = Proc.read_Ci(r[1]) / Proc.read_Ci(r[2]); break; case CONVINT: - Proc.get_Cp_ref(r[0]).assign(Proc.read_Ci(r[1])); + Proc.temp.assign_ansp(Proc.read_Ci(r[1])); + Proc.get_Cp_ref(r[0]) = Proc.temp.ansp; break; case GCONVINT: Proc.get_C2_ref(r[0]).assign((word)Proc.read_Ci(r[1])); @@ -1483,6 +1494,12 @@ void Instruction::execute(Processor& Proc) const cout << Proc.read_C2(r[0]) << flush; } break; + case PRINTINT: + if (Proc.P.my_num() == 0) + { + cout << Proc.read_Ci(r[0]) << flush; + } + break; case PRINTFLOATPLAIN: if (Proc.P.my_num() == 0) { @@ -1508,12 +1525,6 @@ void Instruction::execute(Processor& Proc) const cout << res << flush; } break; - case PRINTINT: - if (Proc.P.my_num() == 0) - { - cout << Proc.read_Ci(r[0]) << flush; - } - break; case PRINTSTR: if (Proc.P.my_num() == 0) { diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 2d095a48..7a23a89f 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -91,6 +91,7 @@ enum MODCI = 0x37, LEGENDREC = 0x38, DIGESTC = 0x39, + INV2M = 0x3a, // Open STARTOPEN = 0xA0, STOPOPEN = 0xA1, @@ -165,8 +166,8 @@ enum PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, - WRITEFILESHARE = 0xBD, - READFILESHARE = 0xBE, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, // GF(2^n) 
versions @@ -283,6 +284,12 @@ struct TempVars { // GINPUT and GLDSI gf2n rr2,t2,tmp2; gf2n xi2; + // assign without allocation + void assign_ansp(int n) + { + aa = n; + ansp.convert_destroy(aa); + } }; diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp index 9582b9aa..0f857754 100644 --- a/Processor/Machine.cpp +++ b/Processor/Machine.cpp @@ -78,10 +78,10 @@ Machine::Machine(int my_number, Names& playerNames, exit(1); } - sprintf(filename, "Programs/Schedules/%s.sch",progname.c_str()); - cerr << "Opening file " << filename << endl; - inpf.open(filename); - if (inpf.fail()) { throw file_error("Missing '" + string(filename) + "'. Did you compile '" + progname + "'?"); } + string fname = "Programs/Schedules/" + progname + ".sch"; + cerr << "Opening file " << fname << endl; + inpf.open(fname); + if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); } int nprogs; inpf >> nthreads; @@ -95,10 +95,10 @@ Machine::Machine(int my_number, Names& playerNames, // Load in the programs progs.resize(nprogs,N.num_players()); - char threadname[1024]; + string threadname; for (int i=0; i> threadname; - sprintf(filename,"Programs/Bytecode/%s.bc",threadname); + string filename = "Programs/Bytecode/" + threadname + ".bc"; cerr << "Loading program " << i << " from " << filename << endl; ifstream pinp(filename); if (pinp.fail()) { throw file_error(filename); } diff --git a/Processor/Processor.h b/Processor/Processor.h index 2dfae97b..b67a0dea 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -115,6 +115,9 @@ class Processor : public ProcessorBase ExternalClients external_clients; Binary_File_IO binary_file_io; + // avoid re-computation of expensive division + map inverses2m; + static const int reg_bytes = 4; void reset(const Program& program,int arg); // Reset the state of the processor diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index 73160ca2..43d25529 100644 --- 
a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -1,8 +1,6 @@ # (C) 2018 University of Bristol, Bar-Ilan University. See License.txt -program.options.merge_opens = False - -from Compiler.GC.types import sbits, sbit, cbits +program.optimize_for_gc() def test(a, b, value_type=None): try: @@ -35,7 +33,7 @@ test(sbit(1).if_else(1, 2), 1) test(sbit(0).if_else(2, 1), 1) test(sbit(1).if_else(2, 1), 2) -test(sbits.compose((sbits(2), sbits(1, n=2)), 2), 6) +test(sbits.compose((sbits(2, n=2), sbits(1, n=2)), 2), 6) x = MemValue(sbits(1234)) program.curr_tape.start_new_basicblock() @@ -67,3 +65,8 @@ b = sbits.get_random_bit() test(b * (1 - b), 0) bits = [sbits.get_random_bit() for i in range(40)] print_ln('random: %s', sbits.bit_compose(bits).reveal()) + +k = 41 +a = int(2.9142 * 2**k) +alpha = sbitint.get_type(2 * k)(a) +test(sbits.bit_compose((alpha >> 64).bit_decompose()[:64]), 0) diff --git a/README.md b/README.md index bc4e8588..8be0941f 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,30 @@ (C) 2018 University of Bristol, Bar-Ilan University. See License.txt -This repository contains the code to benchmark ORAM in SPDZ-BMR as used for the [Eurocrypt 2018 paper](https://eprint.iacr.org/2017/981) by Marcel Keller and Avishay Yanay. +This repository contains code to run computation with Yao's garbled circuits optimized for AES-NI by [Bellare et al.](https://eprint.iacr.org/2013/426). #### Preface: -This implementation only allows to benchmark the data-dependent phase. The data-independent and function-independent phases are emulated insecurely. The software should be considered an academic prototype, and we will only give advice on re-running the examples below. +The main purpose of this software is to provide a quick way to benchmark the computation of some programs written in a subset of the SPDZ high-level language (using purely `sint` and `sfix`) with Yao's garbled circuits. Private inputs are not supported. 
#### Requirements: - GCC (tested with 7.2) or LLVM (tested with 3.8) - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.11 - - CPU supporting AES-NI and PCLMUL + - CPU supporting AES-NI, PCLMUL and AVX2 - Python 2.x - If using macOS, Sierra or later -#### To compile: +#### Compile the VM: -1) Edit `CONFIG` or `CONFIG.mine`: +Run `make yao` (use the flag -j for faster compilation multiple threads). - - Add the following line at the top: `MY_CFLAGS = -DINSECURE` - - For processors without AVX (e.g., Intel Atom) or for optimization, set `ARCH = -march=`. +#### Compile the circuit: -2) Run `make bmr` (use the flag -j for faster compilation multiple threads). Remember to run `make clean` first after changing `CONFIG` or `CONFIG.mine`. - -#### Configure the parameters: - -1) Edit `Program/Source/gc_oram.mpc` to change size and to choose Circuit ORAM or linear scan without ORAM. -2) Run `./compile.py -D gc_oram`. +Run `./compile.py -D ` to compile the `Programs/Source/.mpc`. See `gc_tutorial.mpc` and `gc_fixed_point_tutorial.mpc` for examples. #### Run the protocol: -- Run everything locally: `Scripts/bmr-program-run.sh gc_oram`. -- Run on different hosts: `Scripts/bmr-program-run-remote.sh gc_oram [...]` - -To run with more than two parties, change `CFLAGS = -DN_PARTIES=` in `CONFIG`, and compile again after `make clean`. +- Run everything locally: `./yao-simulate.x ` +- Run on different hosts: + - Garbler: ```./yao-player.x -p 0 ``` + - Evaluator: ```./yao-player.x -p 1 -h ``` diff --git a/Scripts/bmr-program-run.sh b/Scripts/bmr-program-run.sh index 8a88669f..28467026 100755 --- a/Scripts/bmr-program-run.sh +++ b/Scripts/bmr-program-run.sh @@ -1,4 +1,12 @@ -#!/bin/sh +#!/bin/bash + +while getopts t: opt; do + case $opt in + t) threshold=$OPTARG ;; + esac +done + +shift $[OPTIND-1] # (C) 2018 University of Bristol, Bar-Ilan University. 
See License.txt @@ -28,6 +36,6 @@ done $prefix ./bmr-program-tparty.x $prog $netmap 2>&1 &> bmr-log/t & for i in $(seq $[n_players-1]); do - $prefix ./bmr-program-party.x $i $prog $netmap 2>&1 &> bmr-log/$i & + $prefix ./bmr-program-party.x $i $prog $netmap $threshold 2>&1 &> bmr-log/$i & done -$prefix ./bmr-program-party.x $n_players $prog $netmap 2>&1 | tee bmr-log/$n_players +$prefix ./bmr-program-party.x $n_players $prog $netmap $threshold 2>&1 | tee bmr-log/$n_players diff --git a/SimpleOT b/SimpleOT index 52b43a25..cbe71d3e 160000 --- a/SimpleOT +++ b/SimpleOT @@ -1 +1 @@ -Subproject commit 52b43a250922fb45f0bfff73ba2f9b9a11c1784c +Subproject commit cbe71d3eac29836ac58711474f985f2db7e5af41 diff --git a/Tools/Config.cpp b/Tools/Config.cpp index d2c629da..68dd74b2 100644 --- a/Tools/Config.cpp +++ b/Tools/Config.cpp @@ -122,7 +122,7 @@ namespace Config { pubkeys[i].resize(crypto_sign_PUBLICKEYBYTES); infile.read((char*)&pubkeys[i][0],pubkeys[i].size()); } - } catch (ConfigError e) { + } catch (ConfigError& e) { pubkeys.resize(0); } diff --git a/Tools/FlexBuffer.h b/Tools/FlexBuffer.h index 5e26c6e8..2682c95e 100644 --- a/Tools/FlexBuffer.h +++ b/Tools/FlexBuffer.h @@ -10,7 +10,7 @@ #include "Tools/avx_memcpy.h" #include "Tools/time-func.h" - +#include "Tools/octetStream.h" #include #include #include @@ -19,6 +19,8 @@ using namespace std; class FlexBuffer { + friend class octetStream; + protected: char* buf, *ptr; size_t len, max_len; @@ -29,6 +31,7 @@ public: FlexBuffer(const FlexBuffer&); ~FlexBuffer() { del(); } void operator=(FlexBuffer& msg); + void operator=(octetStream& os); char* data() { return buf; } const char* data() const { return buf; } size_t size() const { return len; } @@ -40,6 +43,7 @@ class ReceivedMsg : public virtual FlexBuffer friend class ReceivedMsgStore; public: + void operator=(FlexBuffer& msg) { FlexBuffer::operator=(msg); } void reset_head() { ptr = buf; } void resize(size_t new_len); void unserialize(void* output, size_t size); @@ 
-56,6 +60,9 @@ public: class SendBuffer : public virtual FlexBuffer { public: + void operator=(FlexBuffer& msg) { FlexBuffer::operator=(msg); } + char* end() { return buf + len; } + void skip(size_t n) { len += n; } void resize(size_t new_len); void resize_copy(size_t new_max_len); void clear() { len = 0; } @@ -65,9 +72,11 @@ public: void serialize(const T& source) { serialize(&source, sizeof(T)); } void serialize(const void* source, size_t size); void allocate(size_t size); + char* allocate_and_skip(size_t size); template void serialize_no_allocate(const T& source); void serialize_no_allocate(const void* source, size_t size); + void send(int socket_num); }; class LocalBuffer : public ReceivedMsg, public SendBuffer @@ -113,6 +122,16 @@ inline void FlexBuffer::operator=(FlexBuffer& msg) } } +inline void FlexBuffer::operator=(octetStream& os) +{ + del(); + buf = (char*)os.get_data(); + ptr = (char*)os.get_data() + os.get_ptr(); + len = os.get_length(); + max_len = os.get_max_length(); + os.reset(); +} + inline void ReceivedMsg::resize(size_t new_len) { if (new_len > max_len) @@ -196,6 +215,14 @@ inline void SendBuffer::serialize_no_allocate(const T& source) serialize_no_allocate(&source, sizeof(T)); } +inline char* SendBuffer::allocate_and_skip(size_t size) +{ + allocate(size); + char* res = end(); + skip(size); + return res; +} + inline void SendBuffer::serialize_no_allocate(const void* source, size_t size) { avx_memcpy(buf + len, source, size); diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index bf8ea26c..0ca31393 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -27,17 +27,6 @@ void MMO::setIV(octet key[AES_BLK_SIZE]) } -template -void MMO::encrypt_and_xor(void* output, const void* input, const octet* key) -{ - __m128i in[N], out[N]; - avx_memcpy(in, input, sizeof(in)); - ecb_aes_128_encrypt(out, in, key); - for (int i = 0; i < N; i++) - out[i] = _mm_xor_si128(out[i], in[i]); - avx_memcpy(output, out, sizeof(out)); -} - template void MMO::encrypt_and_xor(void* 
output, const void* input, const octet* key, const int* indices) diff --git a/Tools/MMO.h b/Tools/MMO.h index 3e7af72b..9e76d488 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -9,6 +9,7 @@ #define TOOLS_MMO_H_ #include "Tools/aes.h" +#include "BMR/Key.h" // Matyas-Meyer-Oseas hashing class MMO @@ -32,6 +33,33 @@ public: void hashBlockWise(octet* output, octet* input); template void outputOneBlock(octet* output); + Key hash(const Key& input); + template + void hash(Key* output, const Key* input); }; +template +inline void MMO::encrypt_and_xor(void* output, const void* input, const octet* key) +{ + __m128i in[N], out[N]; + avx_memcpy(in, input, sizeof(in)); + ecb_aes_128_encrypt(out, in, key); + for (int i = 0; i < N; i++) + out[i] = _mm_xor_si128(out[i], in[i]); + avx_memcpy(output, out, sizeof(out)); +} + +inline Key MMO::hash(const Key& input) +{ + Key res; + encrypt_and_xor<1>(&res.r, &input.r, IV); + return res; +} + +template +inline void MMO::hash(Key* output, const Key* input) +{ + encrypt_and_xor(output, input, IV); +} + #endif /* TOOLS_MMO_H_ */ diff --git a/Tools/Worker.h b/Tools/Worker.h index 07c0b418..66c5976b 100644 --- a/Tools/Worker.h +++ b/Tools/Worker.h @@ -16,7 +16,9 @@ class Worker pthread_t thread; WaitQueue input; WaitQueue output; - Timer timer; + Timer timer, wall_timer; + Timer request_timer, done_timer; + int n_jobs; static void* run_thread(void* worker) { @@ -29,8 +31,16 @@ class Worker T* job = 0; while (input.pop(job)) { - TimeScope ts(timer); - output.push(job->run()); +#ifdef WORKER_TIMINGS + request_timer.stop(); + TimeScope ts(timer), ts2(wall_timer); +#endif + int res = job->run(); +#ifdef WORKER_TIMINGS + done_timer.start(); +#endif + output.push(res); + n_jobs++; } } @@ -38,17 +48,27 @@ public: Worker() : timer(CLOCK_THREAD_CPUTIME_ID) { pthread_create(&thread, 0, Worker::run_thread, this); + n_jobs = 0; } ~Worker() { input.stop(); pthread_join(thread, 0); +#ifdef WORKER_TIMINGS cout << "Worker time: " << timer.elapsed() << endl; + 
cout << "Worker wall time: " << wall_timer.elapsed() << endl; + cout << "Request time: " << request_timer.elapsed() << endl; + cout << "Done time: " << done_timer.elapsed() << endl; + cout << "Run jobs: " << n_jobs << endl; +#endif } void request(T& job) { +#ifdef WORKER_TIMINGS + request_timer.start(); +#endif input.push(&job); } @@ -56,6 +76,9 @@ public: { int res = 0; output.pop(res); +#ifdef WORKER_TIMINGS + done_timer.stop(); +#endif return res; } }; diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index 9688b238..64acf765 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -12,8 +12,15 @@ #include "Networking/data.h" #include "Math/bigint.h" #include "Tools/time-func.h" +#include "Tools/FlexBuffer.h" +void octetStream::reset() +{ + data = 0; + len = mxlen = ptr = 0; +} + void octetStream::clear() { if (data) @@ -37,16 +44,6 @@ void octetStream::assign(const octetStream& os) } -void octetStream::swap(octetStream& os) -{ - const size_t size = sizeof(octetStream); - char tmp[size]; - memcpy(tmp, this, size); - memcpy(this, &os, size); - memcpy(&os, tmp, size); -} - - octetStream::octetStream(size_t maxlen) { mxlen=maxlen; len=0; ptr=0; @@ -63,6 +60,15 @@ octetStream::octetStream(const octetStream& os) ptr=os.ptr; } +octetStream::octetStream(FlexBuffer& buffer) +{ + mxlen = buffer.capacity(); + len = buffer.size(); + data = (octet*)buffer.data(); + ptr = buffer.ptr - buffer.data(); + buffer.reset(); +} + void octetStream::hash(octetStream& output) const { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 2db7a7ce..f94cf275 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -33,12 +33,17 @@ using namespace std; class bigint; +class FlexBuffer; class octetStream { + friend class FlexBuffer; + size_t len,mxlen,ptr; // len is the "write head", ptr is the "read head" octet *data; + void reset(); + public: void resize(size_t l); @@ -46,10 +51,10 @@ class octetStream void clear(); void assign(const octetStream& os); - void 
swap(octetStream& os); octetStream() : len(0), mxlen(0), ptr(0), data(0) {} octetStream(size_t maxlen); + octetStream(FlexBuffer& buffer); octetStream(const octetStream& os); octetStream& operator=(const octetStream& os) { if (this!=&os) { assign(os); } diff --git a/Tools/time-func.h b/Tools/time-func.h index 916871fa..253c4e82 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -6,6 +6,7 @@ #include /* Wait for Process Termination */ #include #include +#include #include "Exceptions/Exceptions.h" @@ -43,6 +44,18 @@ public: ~TimeScope() { timer.stop(); } }; +class DoubleTimer +{ + Timer wall, thread; + +public: + DoubleTimer() : thread(CLOCK_THREAD_CPUTIME_ID) {} + void start() { wall.start(); thread.start(); } + void stop() { wall.stop(); thread.stop(); } + string elapsed() + { return to_string(thread.elapsed()) + "/" + to_string(wall.elapsed()); } +}; + inline Timer& Timer::start() { if (running) diff --git a/Yao/YaoAndJob.h b/Yao/YaoAndJob.h new file mode 100644 index 00000000..f90d3cb6 --- /dev/null +++ b/Yao/YaoAndJob.h @@ -0,0 +1,58 @@ +/* + * YaoAndJob.h + * + */ + +#ifndef YAO_YAOANDJOB_H_ +#define YAO_YAOANDJOB_H_ + +#include "YaoGarbleWire.h" +#include "Tools/Worker.h" + +class YaoGate; + +class YaoAndJob +{ + GC::Memory< GC::Secret >* S; + const vector* args; + size_t start, end, n_gates; + YaoGate* gate; + long counter; + PRNG prng; + map timers; + +public: + Worker worker; + + YaoAndJob() : S(0), args(0), start(0), end(0), n_gates(0), gate(0), + counter(0) { prng.ReSeed(); } + + ~YaoAndJob() + { + for (auto& x : timers) + cout << x.first << " time:" << x.second.elapsed() << endl; + } + + void dispatch(GC::Memory >& S, const vector& args, + size_t start, size_t end, size_t n_gates, + YaoGate* gate, long counter) + { + this->S = &S; + this->args = &args; + this->start = start; + this->end = end; + this->n_gates = n_gates; + this->gate = gate; + this->counter = counter; + worker.request(*this); + } + + int run() + { + YaoGarbleWire::andrs(*S, *args, 
start, end, n_gates, gate, counter, + prng, timers); + return 0; + } +}; + +#endif /* YAO_YAOANDJOB_H_ */ diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp new file mode 100644 index 00000000..0572f151 --- /dev/null +++ b/Yao/YaoEvalWire.cpp @@ -0,0 +1,115 @@ +/* + * YaoEvalWire.cpp + * + */ + +#include "YaoEvalWire.h" +#include "YaoGate.h" +#include "YaoEvaluator.h" +#include "BMR/prf.h" +#include "BMR/common.h" + +ostream& YaoEvalWire::out = cout; + +void YaoEvalWire::random() +{ + set(0); +} + +void YaoEvalWire::public_input(bool value) +{ + (void)value; + set(0); +} + +void YaoEvalWire::andrs(GC::Processor >& processor, + const vector& args) +{ + int total_ands = processor.check_args(args, 4); + if (total_ands < 10) + return processor.andrs(args); + processor.complexity += total_ands; + Key* labels; + Key* hashes; + vector label_vec, hash_vec; + size_t n_hashes = total_ands; + Key label_arr[1000], hash_arr[1000]; + if (total_ands < 1000) + { + labels = label_arr; + hashes = hash_arr; + } + else + { + label_vec.resize(n_hashes); + hash_vec.resize(n_hashes); + labels = label_vec.data(); + hashes = hash_vec.data(); + } + size_t i_label = 0; + size_t n_args = args.size(); + auto& evaluator = YaoEvaluator::s(); + for (size_t i = 0; i < n_args; i += 4) + { + const Key& right_key = processor.S[args[i + 3]].get_reg(0).key; + for (auto& left_wire: processor.S[args[i + 2]].get_regs()) + { + long counter = ++evaluator.counter; + labels[i_label++] = YaoGate::E_input(left_wire.key, right_key, + counter); + } + } + MMO& mmo = evaluator.mmo; + size_t i; + for (i = 0; i + 8 <= n_hashes; i += 8) + mmo.hash<8>(&hashes[i], &labels[i]); + for (; i < n_hashes; i++) + hashes[i] = mmo.hash(labels[i]); + size_t j = 0; + for (size_t i = 0; i < n_args; i += 4) + { + YaoEvalWire& right_wire = processor.S[args[i + 3]].get_reg(0); + auto& out = processor.S[args[i + 1]]; + out.resize_regs(args[i]); + int n = args[i]; + for (int k = 0; k < n; k++) + { + auto& left_wire = 
processor.S[args[i + 2]].get_reg(k); + YaoGate gate; + evaluator.load_gate(gate); + gate.eval(out.get_reg(k), hashes[j++], + gate.get_entry(left_wire.external, right_wire.external)); + } + } +} + +void YaoEvalWire::op(const YaoEvalWire& left, const YaoEvalWire& right, + Function func) +{ + (void)func; + YaoGate gate; + YaoEvaluator::s().load_gate(gate); + YaoEvaluator::s().counter++; + gate.eval(*this, left, right); +} + +void YaoEvalWire::XOR(const YaoEvalWire& left, const YaoEvalWire& right) +{ + external = left.external ^ right.external; + key = left.key ^ right.key; +} + +bool YaoEvalWire::get_output() +{ + bool res = external ^ YaoEvaluator::s().output_masks.pop_front(); +#ifdef DEBUG + cout << "output " << res << endl; +#endif + return res; +} + +void YaoEvalWire::set(const Key& key) +{ + this->key = key; + external = key.get_signal(); +} diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h new file mode 100644 index 00000000..b49f71be --- /dev/null +++ b/Yao/YaoEvalWire.h @@ -0,0 +1,39 @@ +/* + * YaoEvalWire.h + * + */ + +#ifndef YAO_YAOEVALWIRE_H_ +#define YAO_YAOEVALWIRE_H_ + +#include "BMR/Key.h" +#include "BMR/Gate.h" +#include "BMR/Register.h" +#include "GC/Processor.h" + +class YaoEvalWire : public Phase +{ +public: + static string name() { return "YaoEvalWire"; } + + typedef ostream& out_type; + static ostream& out; + + bool external; + Key key; + + static YaoEvalWire new_reg() { return {}; } + + static void andrs(GC::Processor>& processor, + const vector& args); + + void set(const Key& key); + + void random(); + void public_input(bool value); + void op(const YaoEvalWire& left, const YaoEvalWire& right, Function func); + void XOR(const YaoEvalWire& left, const YaoEvalWire& right); + bool get_output(); +}; + +#endif /* YAO_YAOEVALWIRE_H_ */ diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp new file mode 100644 index 00000000..7bc01120 --- /dev/null +++ b/Yao/YaoEvaluator.cpp @@ -0,0 +1,61 @@ +/* + * YaoEvaluator.cpp + * + */ + +#include 
"YaoEvaluator.h" + +YaoEvaluator* YaoEvaluator::singleton = 0; + +YaoEvaluator::YaoEvaluator(string progname) : machine(MD), processor(machine) +{ + counter = 0; + + program.parse(progname + "-0"); + processor.reset(program); + + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; +} + +void YaoEvaluator::run() +{ + while(GC::DONE_BREAK != program.execute(processor, -1)) + ; +} + +void YaoEvaluator::run(Player& P) +{ + do + receive(P); + while(GC::DONE_BREAK != program.execute(processor, -1)); +} + +void YaoEvaluator::run_from_store() +{ + machine.reset_timer(); + do + { + gates_store.pop(gates); + output_masks_store.pop(output_masks); + } + while(GC::DONE_BREAK != program.execute(processor, -1)); +} + +void YaoEvaluator::receive(Player& P) +{ + P.receive_player(0, gates); + P.receive_player(0, output_masks); +} + +void YaoEvaluator::receive_to_store(Player& P) +{ + while (P.peek_long(0) != -1) + { + receive(P); + gates_store.push(gates); + output_masks_store.push(output_masks); + } + P.receive_long(0); +} diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h new file mode 100644 index 00000000..559c50d9 --- /dev/null +++ b/Yao/YaoEvaluator.h @@ -0,0 +1,63 @@ +/* + * YaoEvaluator.h + * + */ + +#ifndef YAO_YAOEVALUATOR_H_ +#define YAO_YAOEVALUATOR_H_ + +#include "YaoGate.h" +#include "GC/Secret.h" +#include "GC/Program.h" +#include "GC/Machine.h" +#include "GC/Processor.h" +#include "GC/Memory.h" +#include "Tools/MMO.h" + +class YaoEvaluator +{ +protected: + static YaoEvaluator* singleton; + + ReceivedMsg gates; + ReceivedMsgStore gates_store; + + GC::Program< GC::Secret > program; + GC::Machine< GC::Secret > machine; + GC::Processor< GC::Secret > processor; + GC::Memory::DynamicType> MD; + +public: + ReceivedMsg output_masks; + ReceivedMsgStore output_masks_store; + + MMO mmo; + + long counter; + + static YaoEvaluator& s(); + + YaoEvaluator(string progname); + void run(); + void run(Player& P); + void run_from_store(); + void 
receive(Player& P); + void receive_to_store(Player& P); + + void load_gate(YaoGate& gate); +}; + +inline void YaoEvaluator::load_gate(YaoGate& gate) +{ + gates.unserialize(gate); +} + +inline YaoEvaluator& YaoEvaluator::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("singleton unavailable"); +} + +#endif /* YAO_YAOEVALUATOR_H_ */ diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp new file mode 100644 index 00000000..de239c2d --- /dev/null +++ b/Yao/YaoGarbleWire.cpp @@ -0,0 +1,204 @@ +/* + * YaoWire.cpp + * + */ + +#include "YaoGarbleWire.h" +#include "YaoGate.h" +#include "YaoGarbler.h" + +void YaoGarbleWire::randomize(PRNG& prng) +{ + key = prng.get_doubleword(); + mask = prng.get_bit(); +#ifdef DEBUG + key = YaoGarbler::s().counter << 1; +#endif + key.set_signal(0); +} + +void YaoGarbleWire::random() +{ + mask = YaoGarbler::s().prng.get_bit(); + key = 0; +} + +void YaoGarbleWire::public_input(bool value) +{ + mask = value; + key = 0; +} + +void YaoGarbleWire::andrs(GC::Processor >& processor, + const vector& args) +{ +#ifdef YAO_TIMINGS + auto& garbler = YaoGarbler::s(); + TimeScope ts(garbler.and_timer), ts2(garbler.and_proc_timer), + ts3(garbler.and_main_thread_timer); +#endif + andrs_multithread(processor, args); +} + +void YaoGarbleWire::andrs_multithread(GC::Processor >& processor, + const vector& args) +{ + YaoGarbler& party = YaoGarbler::s(); + int total = processor.check_args(args, 4); + if (total < party.get_threshold()) + { + // run in single thread + andrs_singlethread(processor, args); + return; + } + + party.and_prepare_timer.start(); + processor.complexity += total; + SendBuffer& gates = party.gates; + gates.allocate(total * sizeof(YaoGate)); + int max_gates_per_thread = max(party.get_threshold() / 2, + (total + party.get_n_threads() - 1) / party.get_n_threads()); + int i_thread = 0, i_gate = 0, start = 0; + for (size_t j = 0; j < args.size(); j += 4) + { + i_gate += args[j]; + size_t end = j + 4; + if (i_gate >= 
max_gates_per_thread or end >= args.size()) + { + YaoGate* gate = (YaoGate*)gates.end(); + gates.skip(i_gate * sizeof(YaoGate)); + party.timers["Dispatch"].start(); + party.and_jobs[i_thread++].dispatch(processor.S, args, start, end, + i_gate, gate, party.counter); + party.timers["Dispatch"].stop(); + party.counter += i_gate; + i_gate = 0; + start = end; + } + } + party.and_prepare_timer.stop(); + party.and_wait_timer.start(); + for (int i = 0; i < i_thread; i++) + party.and_jobs[i].worker.done(); + party.and_wait_timer.stop(); +} + +void YaoGarbleWire::andrs_singlethread(GC::Processor >& processor, + const vector& args) +{ + int total_ands = processor.check_args(args, 4); + if (total_ands < 10) + return processor.andrs(args); + processor.complexity += total_ands; + size_t n_args = args.size(); + auto& garbler = YaoGarbler::s(); + SendBuffer& gates = garbler.gates; + YaoGate* gate = (YaoGate*)gates.allocate_and_skip(total_ands * sizeof(YaoGate)); + andrs(processor.S, args, 0, n_args, total_ands, gate, garbler.counter, + garbler.prng, garbler.timers); +} + +void YaoGarbleWire::andrs(GC::Memory >& S, + const vector& args, size_t start, size_t end, size_t total_ands, + YaoGate* gate, long& counter, PRNG& prng, map& timers) +{ + (void)timers; + Key* labels; + Key* hashes; + vector label_vec, hash_vec; + size_t n_hashes = 4 * total_ands; + Key label_arr[400], hash_arr[400]; + if (total_ands < 100) + { + labels = label_arr; + hashes = hash_arr; + } + else + { + label_vec.resize(n_hashes); + hash_vec.resize(n_hashes); + labels = label_vec.data(); + hashes = hash_vec.data(); + } + //timers["Hash input"].start(); + auto& garbler = YaoGarbler::s(); + const Key& delta = garbler.get_delta(); + size_t i_label = 0; + for (size_t i = start; i < end; i += 4) + { + const Key& right_key = S[args[i + 3]].get_reg(0).key; + for (auto& left_wire : S[args[i + 2]].get_regs()) + { + counter++; + for (int i = 0; i < 2; i++) + for (int j = 0; j < 2; j++) + labels[i_label++] = 
YaoGate::E_input( + left_wire.key ^ (i ? delta : 0), + right_key ^ (j ? delta : 0), counter); + } + } + //timers["Hash input"].stop(); + //timers["Hashing"].start(); + MMO& mmo = garbler.mmo; + size_t i; + for (i = 0; i + 8 <= n_hashes; i += 8) + mmo.hash<8>(&hashes[i], &labels[i]); + for (; i < n_hashes; i++) + hashes[i] = mmo.hash(labels[i]); + //timers["Hashing"].stop(); + //timers["Garbling"].start(); + size_t i_hash = 0; + for (size_t i = start; i < end; i += 4) + { + //timers["Outer ref"].start(); + YaoGarbleWire& right_wire = S[args[i + 3]].get_reg(0); + auto& out = S[args[i + 1]]; + //timers["Outer ref"].stop(); + //timers["Resizing"].start(); + out.resize_regs(args[i]); + //timers["Resizing"].stop(); + int n = args[i]; + for (int k = 0; k < n; k++) + { + //timers["Inner ref"].start(); + auto& left_wire = S[args[i + 2]].get_reg(k); + //timers["Inner ref"].stop(); + //timers["Randomizing"].start(); + out.get_reg(k).randomize(prng); + //timers["Randomizing"].stop(); + //timers["Gate computation"].start(); + (gate++)->garble(out.get_reg(k), &hashes[i_hash], left_wire.mask, + right_wire.mask, 0x0001); + //timers["Gate computation"].stop(); + i_hash += 4; + } + } + //timers["Garbling"].stop(); +} + +inline void YaoGarbler::store_gate(const YaoGate& gate) +{ + gates.serialize(gate); +} + +void YaoGarbleWire::op(const YaoGarbleWire& left, const YaoGarbleWire& right, + Function func) +{ + auto& garbler = YaoGarbler::s(); + randomize(garbler.prng); + YaoGarbler::s().counter++; + YaoGate gate(*this, left, right, func); + YaoGarbler::s().store_gate(gate); +} + +void YaoGarbleWire::XOR(const YaoGarbleWire& left, const YaoGarbleWire& right) +{ + mask = left.mask ^ right.mask; + key = left.key ^ right.key; +} + +char YaoGarbleWire::get_output() +{ + YaoGarbler::s().output_masks.push_back(mask); + return -1; +} diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h new file mode 100644 index 00000000..07382173 --- /dev/null +++ b/Yao/YaoGarbleWire.h @@ -0,0 +1,47 @@ +/* + 
* YaoWire.h + * + */ + +#ifndef YAO_YAOGARBLEWIRE_H_ +#define YAO_YAOGARBLEWIRE_H_ + +#include "BMR/Key.h" +#include "BMR/Register.h" +#include "GC/Processor.h" + +class YaoGate; + +class YaoGarbleWire : public Phase +{ +public: + static string name() { return "YaoGarbleWire"; } + + Key key; + bool mask; + + static YaoGarbleWire new_reg() { return {}; } + + static void andrs(GC::Processor>& processor, + const vector& args); + static void andrs_multithread( + GC::Processor>& processor, + const vector& args); + static void andrs_singlethread( + GC::Processor>& processor, + const vector& args); + static void andrs(GC::Memory>& S, + const vector& args, size_t start, size_t end, + size_t total_ands, YaoGate* gate, long& counter, PRNG& prng, + map& timers); + + void randomize(PRNG& prng); + + void random(); + void public_input(bool value); + void op(const YaoGarbleWire& left, const YaoGarbleWire& right, Function func); + void XOR(const YaoGarbleWire& left, const YaoGarbleWire& right); + char get_output(); +}; + +#endif /* YAO_YAOGARBLEWIRE_H_ */ diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp new file mode 100644 index 00000000..d1a27004 --- /dev/null +++ b/Yao/YaoGarbler.cpp @@ -0,0 +1,70 @@ +/* + * YaoGarbler.cpp + * + */ + +#include "YaoGarbler.h" +#include "YaoGate.h" + +YaoGarbler* YaoGarbler::singleton = 0; + +YaoGarbler::YaoGarbler(string progname, int threshold) : + machine(MD), processor(machine), threshold(threshold), + and_proc_timer(CLOCK_PROCESS_CPUTIME_ID), + and_main_thread_timer(CLOCK_THREAD_CPUTIME_ID) +{ + prng.ReSeed(); + delta = prng.get_doubleword(); + delta.set_signal(1); + counter = 0; +#ifdef DEBUG_DELTA + delta = 1; +#endif + + program.parse(progname + "-0"); + processor.reset(program); + + and_jobs = new YaoAndJob[get_n_threads()]; + + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; +} + +YaoGarbler::~YaoGarbler() +{ + delete[] and_jobs; +#ifdef YAO_TIMINGS + cout << "AND time: " << and_timer.elapsed() << 
endl;
+    cout << "AND process timer: " << and_proc_timer.elapsed() << endl;
+    cout << "AND main thread timer: " << and_main_thread_timer.elapsed() << endl;
+    cout << "AND prepare timer: " << and_prepare_timer.elapsed() << endl;
+    cout << "AND wait timer: " << and_wait_timer.elapsed() << endl;
+    for (auto& x : timers)
+        cout << x.first << " time:" << x.second.elapsed() << endl;
+#endif
+}
+// Local run: call program.execute() repeatedly until it reports DONE_BREAK; nothing is sent.
+void YaoGarbler::run()
+{
+    while(GC::DONE_BREAK != program.execute(processor, -1))
+        ;
+}
+// Networked run: after every execute() call send the buffers to P and clear them; stops once DONE_BREAK was returned.
+void YaoGarbler::run(Player& P)
+{
+    GC::BreakType b = GC::TIME_BREAK; // any value other than DONE_BREAK so the loop is entered
+    while(GC::DONE_BREAK != b)
+    {
+        b = program.execute(processor, -1);
+        send(P); // the final batch is sent too, because the check happens at the top of the loop
+        gates.clear();
+        output_masks.clear();
+    }
+}
+// Send the gates buffer, then the output_masks buffer, to player 1.
+void YaoGarbler::send(Player& P)
+{
+    P.send_to(1, gates, true);
+    P.send_to(1, output_masks, true);
+}
diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h
new file mode 100644
index 00000000..cf31cc23
--- /dev/null
+++ b/Yao/YaoGarbler.h
@@ -0,0 +1,76 @@
+/*
+ * YaoGarbler.h
+ *
+ */
+
+#ifndef YAO_YAOGARBLER_H_
+#define YAO_YAOGARBLER_H_
+
+#include "YaoGarbleWire.h"
+#include "YaoAndJob.h"
+#include "Tools/random.h"
+#include "Tools/MMO.h"
+#include "GC/Secret.h"
+#include "GC/Program.h"
+#include "Networking/Player.h"
+#include "sys/sysinfo.h"
+
+class YaoGate;
+// Garbling party: executes the program while collecting garbled gates and output masks for player 1.
+class YaoGarbler
+{
+    friend class YaoGarbleWire;
+
+protected:
+    static YaoGarbler* singleton; // the registered instance returned by s()
+
+    Key delta; // global offset; YaoGate uses key ^ delta as a wire's second label
+    SendBuffer gates; // filled by store_gate(), sent and cleared in run(Player&)
+
+    GC::Program< GC::Secret > program;
+    GC::Machine< GC::Secret > machine;
+    GC::Processor< GC::Secret > processor;
+    GC::Memory::DynamicType> MD; // NOTE(review): declared after machine, yet the constructor initializes machine(MD) - confirm Machine only stores the reference
+
+    int threshold; // returned by get_threshold()
+
+    Timer and_timer;
+    Timer and_proc_timer;
+    Timer and_main_thread_timer;
+    DoubleTimer and_prepare_timer;
+    DoubleTimer and_wait_timer;
+
+public:
+    PRNG prng;
+    SendBuffer output_masks; // masks pushed by YaoGarbleWire::get_output()
+    long counter; // incremented per gate in YaoGarbleWire::op and passed to YaoGate::E_input
+    MMO mmo;
+
+    YaoAndJob* and_jobs; // array of get_n_threads() jobs: new[] in the constructor, delete[] in the destructor
+
+    map timers;
+
+    static YaoGarbler& s();
+
+    YaoGarbler(string progname, int threshold = 1024);
+    ~YaoGarbler();
+    void run();
+    void run(Player& P);
+    void send(Player& P);
+
+    const Key&
get_delta() { return delta; }
+    void store_gate(const YaoGate& gate);
+
+    int get_n_threads() { return get_nprocs(); } // one per processor reported by get_nprocs()
+    int get_threshold() { return threshold; }
+};
+// Access the registered YaoGarbler instance; throws runtime_error when none exists.
+inline YaoGarbler& YaoGarbler::s()
+{
+    if (singleton)
+        return *singleton;
+    else
+        throw runtime_error("singleton unavailable");
+}
+
+#endif /* YAO_YAOGARBLER_H_ */
diff --git a/Yao/YaoGate.cpp b/Yao/YaoGate.cpp
new file mode 100644
index 00000000..282c6bda
--- /dev/null
+++ b/Yao/YaoGate.cpp
@@ -0,0 +1,42 @@
+/*
+ * YaoGate.cpp
+ *
+ */
+
+#include "YaoGate.h"
+#include "YaoGarbler.h"
+#include "YaoEvaluator.h"
+#include "BMR/prf.h"
+#include "Tools/MMO.h"
+// Garble a single gate: hash E_input for all four label combinations of the inputs, then fill the table via garble().
+YaoGate::YaoGate(const YaoGarbleWire& out, const YaoGarbleWire& left,
+        const YaoGarbleWire& right, Function func)
+{
+    const Key& delta = YaoGarbler::s().get_delta();
+    MMO& mmo = YaoGarbler::s().mmo;
+    Key hashes[4]; // index 2 * i + j, where i / j select key (0) or key ^ delta (1) for left / right
+    for (int i = 0; i < 2; i++)
+        for (int j = 0; j < 2; j++)
+            hashes[2 * i + j] = mmo.hash(
+                    E_input(left.key ^ (i ? delta : 0),
+                            right.key ^ (j ?
delta : 0),
+                            YaoGarbler::s().counter));
+    garble(out, hashes, left.mask, right.mask, func);
+#ifdef DEBUG
+    cout << "left " << left.mask << " " << left.key << " " << (left.key ^ delta) << endl;
+    cout << "right " << right.mask << " " << right.key << " " << (right.key ^ delta) << endl;
+    cout << "out " << out.mask << " " << out.key << " " << (out.key ^ delta) << endl;
+#endif
+}
+// Evaluate this gate: hash E_input of the two held keys and apply it to the entry selected by the inputs' external values.
+void YaoGate::eval(YaoEvalWire& out, const YaoEvalWire& left, const YaoEvalWire& right)
+{
+    MMO& mmo = YaoEvaluator::s().mmo;
+    Key key = E_input(left.key, right.key, YaoEvaluator::s().counter);
+    eval(out, mmo.hash(key), get_entry(left.external, right.external));
+#ifdef DEBUG
+    cout << "external " << left.external << " " << right.external << endl;
+    cout << "entry " << get_entry(left.external, right.external) << endl;
+    cout << "out " << out.key << endl;
+#endif
+}
diff --git a/Yao/YaoGate.h b/Yao/YaoGate.h
new file mode 100644
index 00000000..75198d88
--- /dev/null
+++ b/Yao/YaoGate.h
@@ -0,0 +1,78 @@
+/*
+ * YaoGate.h
+ *
+ */
+
+#ifndef YAO_YAOGATE_H_
+#define YAO_YAOGATE_H_
+
+#include "BMR/Key.h"
+#include "YaoGarbleWire.h"
+#include "YaoEvalWire.h"
+#include "YaoGarbler.h"
+// One garbled gate: a 2x2 table of keys, written by garble() and read by eval().
+class YaoGate
+{
+    Key entries[2][2]; // indexed [left][right], see get_entry()
+public:
+    static Key E_input(const Key& left, const Key& right, long T);
+
+    YaoGate() {}
+    YaoGate(const YaoGarbleWire& out, const YaoGarbleWire& left,
+            const YaoGarbleWire& right, Function func);
+    void garble(const YaoGarbleWire& out, const Key* hashes, bool left_mask,
+            bool right_mask, Function func);
+    void eval(YaoEvalWire& out, const YaoEvalWire& left, const YaoEvalWire& right);
+    void eval(YaoEvalWire& out, const Key& hash,
+            const Key& entry);
+    const Key& get_entry(bool left, bool right) { return entries[left][right]; }
+};
+// Combine both input keys and the counter T into the value to be hashed: left.doubling(1) ^ right.doubling(2) ^ T.
+inline Key YaoGate::E_input(const Key& left, const Key& right, long T)
+{
+    Key res = left.doubling(1) ^ right.doubling(2) ^ T;
+#ifdef DEBUG
+    cout << "E " << res << ": " << left.doubling(1) << " " << right.doubling(2)
+            << " " << T
<< endl;
+#endif
+    return res;
+}
+// Fill the table: hashes holds four keys indexed 2 * left + right, using the same masked indices as entries.
+inline void YaoGate::garble(const YaoGarbleWire& out, const Key* hashes,
+        bool left_mask, bool right_mask, Function func)
+{
+    const Key& delta = YaoGarbler::s().get_delta();
+    for (int left_value = 0; left_value < 2; left_value++)
+        for (int right_value = 0; right_value < 2; right_value++)
+        {
+            Key key = out.key;
+            if (func.call(left_value, right_value) ^ out.mask) // func output XOR out.mask decides whether delta is added
+                key += delta;
+#ifdef DEBUG
+            cout << "start key " << key << endl;
+#endif
+            key += hashes[2 * (left_value ^ left_mask) + (right_value ^ right_mask)];
+#ifdef DEBUG
+            cout << "after left " << key << endl;
+#endif
+            entries[left_value ^ left_mask][right_value ^ right_mask] = key; // row/column are the input values XOR the input masks
+        }
+#ifdef DEBUG
+    cout << "counter " << YaoGarbler::s().counter << endl;
+    for (int i = 0; i < 2; i++)
+        for (int j = 0; j < 2; j++)
+            cout << "entry " << i << " " << j << " " << entries[i][j] << endl;
+#endif
+}
+// Recover the output key as entry - hash (garble() added the hash) and store it in out.
+inline void YaoGate::eval(YaoEvalWire& out, const Key& hash, const Key& entry)
+{
+    Key key = entry;
+    key -= hash;
+#ifdef DEBUG
+    cout << "after left " << key << endl;
+#endif
+    out.set(key);
+}
+
+#endif /* YAO_YAOGATE_H_ */
diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp
new file mode 100644
index 00000000..0eea3406
--- /dev/null
+++ b/Yao/YaoPlayer.cpp
@@ -0,0 +1,110 @@
+/*
+ * YaoPlayer.cpp
+ *
+ */
+
+#include "YaoPlayer.h"
+#include "YaoGarbler.h"
+#include "YaoEvaluator.h"
+#include "Tools/ezOptionParser.h"
+// Parse the command line, connect both players, then run as garbler (player 0) or evaluator (any other number).
+YaoPlayer::YaoPlayer(int argc, const char** argv)
+{
+    ez::ezOptionParser opt;
+    opt.add(
+            "", // Default.
+            1, // Required?
+            1, // Number of args expected.
+            0, // Delimiter if expecting multiple args.
+            "This player's number, 0 for garbling, 1 for evaluating.", // Help description.
+            "-p", // Flag token.
+            "--player" // Flag token.
+    );
+    opt.add(
+            "localhost", // Default.
+            0, // Required?
+            1, // Number of args expected.
+            0, // Delimiter if expecting multiple args.
+            "Host where party 0 is running (default: localhost)", // Help description.
+ "-h", // Flag token. + "--hostname" // Flag token. + ); + opt.add( + "5000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Base port number (default: 5000).", // Help description. + "-pn", // Flag token. + "--portnum" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Evaluate while garbling (default: false).", // Help description. + "-C", // Flag token. + "--continuous" // Flag token. + ); + opt.add( + "1024", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Minimum number of gates for multithreading (default: 1024).", // Help description. + "-t", // Flag token. + "--threshold" // Flag token. + ); + opt.parse(argc, argv); + opt.syntax = "./yao-player.x [OPTIONS] "; + if (opt.lastArgs.size() == 1) + { + progname = *opt.lastArgs[0]; + } + else + { + string usage; + opt.getUsage(usage); + cerr << usage; + exit(1); + } + + int my_num; + int pnb; + string hostname; + int threshold; + opt.get("-p")->getInt(my_num); + opt.get("-pn")->getInt(pnb); + opt.get("-h")->getString(hostname); + bool continuous = opt.get("-C")->isSet; + opt.get("-t")->getInt(threshold); + + server = Server::start_networking(N, my_num, 2, hostname, pnb); + Player P(N); + + if (my_num == 0) + { + YaoGarbler garbler(progname, threshold); + garbler.run(P); + if (not continuous) + P.send_long(1, -1); + } + else + { + YaoEvaluator evaluator(progname); + if (continuous) + evaluator.run(P); + else + { + evaluator.receive_to_store(P); + evaluator.run_from_store(); + } + } +} + +YaoPlayer::~YaoPlayer() +{ + if (server) + delete server; +} diff --git a/Yao/YaoPlayer.h b/Yao/YaoPlayer.h new file mode 100644 index 00000000..6404d107 --- /dev/null +++ b/Yao/YaoPlayer.h @@ -0,0 +1,23 @@ +/* + * YaoPlayer.h + * + */ + +#ifndef YAO_YAOPLAYER_H_ +#define YAO_YAOPLAYER_H_ + +#include 
"Networking/Player.h" +#include "Networking/Server.h" + +class YaoPlayer +{ + string progname; + Names N; + Server* server; + +public: + YaoPlayer(int argc, const char** argv); + ~YaoPlayer(); +}; + +#endif /* YAO_YAOPLAYER_H_ */ diff --git a/Yao/YaoSimulator.cpp b/Yao/YaoSimulator.cpp new file mode 100644 index 00000000..4f9cabd6 --- /dev/null +++ b/Yao/YaoSimulator.cpp @@ -0,0 +1,18 @@ +/* + * YaoSimulator.cpp + * + */ + +#include "YaoSimulator.h" + +YaoSimulator::YaoSimulator(string progname) : YaoEvaluator(progname), YaoGarbler(progname) +{ +} + +void YaoSimulator::run() +{ + YaoGarbler::run(); + YaoEvaluator::output_masks = YaoGarbler::output_masks; + YaoEvaluator::gates = YaoGarbler::gates; + YaoEvaluator::run(); +} diff --git a/Yao/YaoSimulator.h b/Yao/YaoSimulator.h new file mode 100644 index 00000000..7c864d2d --- /dev/null +++ b/Yao/YaoSimulator.h @@ -0,0 +1,19 @@ +/* + * YaoSimulator.h + * + */ + +#ifndef YAO_YAOSIMULATOR_H_ +#define YAO_YAOSIMULATOR_H_ + +#include "YaoEvaluator.h" +#include "YaoGarbler.h" + +class YaoSimulator : public YaoEvaluator, public YaoGarbler +{ +public: + YaoSimulator(string progname); + void run(); +}; + +#endif /* YAO_YAOSIMULATOR_H_ */ diff --git a/yao-player.cpp b/yao-player.cpp new file mode 100644 index 00000000..021cc27b --- /dev/null +++ b/yao-player.cpp @@ -0,0 +1,11 @@ +/* + * yao-player.cpp + * + */ + +#include "Yao/YaoPlayer.h" + +int main(int argc, const char** argv) +{ + YaoPlayer(argc, argv); +} diff --git a/yao-simulate.cpp b/yao-simulate.cpp new file mode 100644 index 00000000..2fb37d88 --- /dev/null +++ b/yao-simulate.cpp @@ -0,0 +1,13 @@ +/* + * yao-simulate.cpp + * + */ + +#include "Yao/YaoSimulator.h" + +int main(int argc, char** argv) +{ + if (argc < 1) + throw exception(); + YaoSimulator(argv[1]).run(); +}