From 0603e43375f9a893b8270373b3e8875583ca94e9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 4 Nov 2021 17:37:02 +1100 Subject: [PATCH 001/265] Fix deprecated Sphinx interface. --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index a57f08fa..57f730ad 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -188,4 +188,4 @@ import subprocess subprocess.call('doxygen', shell=True) def setup(app): - app.add_stylesheet('custom.css') + app.add_css_file('custom.css') From ab637517881270b9fdf5e82bfb49ff69b7170f7f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 20 Nov 2021 18:01:26 +1100 Subject: [PATCH 002/265] Instruction output functionality. --- Processor/Instruction.cpp | 32 ++++++++++++ Processor/Instruction.hpp | 20 ++------ Processor/instructions.h | 104 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 16 deletions(-) diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index db03d6d0..68acda3b 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -7,6 +7,7 @@ #include "instructions.h" #include "Processor.h" #include "Math/gf2n.h" +#include "GC/instructions.h" #include @@ -89,6 +90,37 @@ void Instruction::bitdecint(ArithmeticProcessor& Proc) const } } +ostream& operator<<(ostream& s, const Instruction& instr) +{ + switch (instr.get_opcode()) + { +#define X(NAME, PRE, CODE) \ + case NAME: s << #NAME; break; + ALL_INSTRUCTIONS +#undef X +#define X(NAME, CODE) \ + case NAME: s << #NAME; break; + COMBI_INSTRUCTIONS + } + + s << " size=" << instr.get_size(); + s << " n=" << instr.get_n(); + s << " r=("; + for (int i = 0; i < 3; i++) + s << instr.get_r(i) << ", "; + s << instr.get_r(3); + s << ")"; + if (not instr.get_start().empty()) + { + s << " args=("; + for (unsigned i = 0; i < instr.get_start().size() - 1; i++) + s << instr.get_start()[i] << ", "; + s << instr.get_start().back(); + s << ")"; + } + return s; +} + template void Instruction::execute_clear_gf2n(vector& registers, vector& memory, ArithmeticProcessor& Proc) const; template void Instruction::execute_clear_gf2n(vector& registers, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 72184d9e..27bec2b3 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -805,22 +805,6 @@ bool BaseInstruction::is_direct_memory_access() const } -inline -ostream& operator<<(ostream& s,const Instruction& instr) -{ - s << instr.opcode << " : "; - for (int i=0; i<3; i++) - { s << instr.r[i] << " "; } - s << " : " << instr.n; - if (instr.start.size()!=0) - { s << " : " << instr.start.size() << " : "; - for (unsigned int i=0; i inline void Instruction::execute(Processor& Proc) const { @@ -1287,6 +1271,10 @@ void Program::execute(Processor& Proc) const Proc.stats[p[Proc.PC].get_opcode()]++; #endif +#ifdef OUTPUT_INSTRUCTIONS + cerr << instruction << endl; +#endif + Proc.PC++; switch(instruction.get_opcode()) diff --git a/Processor/instructions.h b/Processor/instructions.h index 49901822..5928fdab 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -280,4 +280,108 @@ X(GRAWOUTPUT, auto source = &C2[r[0]], \ (*source++).output(Proc.public_output, false)) \ +#define REMAINING_INSTRUCTIONS \ + X(CONVMODP, throw not_implemented(),) \ + X(LDMC, throw not_implemented(),) \ + X(LDMCI, throw not_implemented(),) \ + X(STMC, throw not_implemented(),) \ + X(STMCI, throw not_implemented(),) \ + X(MOVC, throw not_implemented(),) \ + X(DIVC, throw not_implemented(),) \ + X(GDIVC, throw not_implemented(),) \ + X(FLOORDIVC, throw not_implemented(),) \ + X(MODC, throw not_implemented(),) \ + X(LEGENDREC, throw not_implemented(),) \ + X(DIGESTC, throw not_implemented(),) \ + X(DIVCI, throw not_implemented(),) \ + X(GDIVCI, throw not_implemented(),) \ + X(INV2M, throw not_implemented(),) \ + X(MODCI, throw not_implemented(),) \ + X(SQUARE, throw not_implemented(),) \ + X(GSQUARE, throw not_implemented(),) \ + X(INV, throw not_implemented(),) \ + X(GINV, throw not_implemented(),) \ + X(RANDOMS, throw not_implemented(),) \ + X(INPUTMASKREG, throw not_implemented(),) \ + X(INPUTMASK, throw not_implemented(),) \ + X(GINPUTMASK, throw not_implemented(),) \ + X(INPUT, throw not_implemented(),) \ + X(GINPUT, throw not_implemented(),) \ + X(INPUTFIX, throw not_implemented(),) \ + X(INPUTFLOAT, throw not_implemented(),) \ + X(INPUTMIXED, throw not_implemented(),) \ + X(INPUTMIXEDREG, throw not_implemented(),) \ + X(RAWINPUT, throw not_implemented(),) \ + X(GRAWINPUT, throw not_implemented(),) \ + X(INPUTPERSONAL, throw not_implemented(),) \ + X(NOTC, throw not_implemented(),) \ + X(SHRSI, throw not_implemented(),) \ + X(OPEN, throw not_implemented(),) \ + X(GOPEN, throw not_implemented(),) \ + X(MULS, throw not_implemented(),) \ + X(GMULS, throw not_implemented(),) \ + X(MULRS, throw not_implemented(),) \ + X(GMULRS, throw not_implemented(),) \ + X(DOTPRODS, throw not_implemented(),) \ + X(GDOTPRODS, throw not_implemented(),) \ + X(MATMULS, throw not_implemented(),) \ + X(MATMULSM, throw not_implemented(),) \ + X(CONV2DS, throw not_implemented(),) \ + X(TRUNC_PR, throw not_implemented(),) \ + X(CHECK, throw not_implemented(),) \ + X(JMP, throw not_implemented(),) \ + X(JMPI, throw not_implemented(),) \ + X(JMPNZ, throw not_implemented(),) \ + X(JMPEQZ, throw not_implemented(),) \ + X(PRINTREG, throw not_implemented(),) \ + X(PRINTREGPLAIN, throw not_implemented(),) \ + X(CONDPRINTPLAIN, throw not_implemented(),) \ + X(PRINTFLOATPLAIN, throw not_implemented(),) \ + X(CONDPRINTSTR, throw not_implemented(),) \ + X(REQBL, throw not_implemented(),) \ + X(GREQBL, throw not_implemented(),) \ + X(USE, throw not_implemented(),) \ + X(USE_INP, throw not_implemented(),) \ + X(USE_EDABIT, throw not_implemented(),) \ + X(USE_MATMUL, throw not_implemented(),) \ + X(USE_PREP, throw not_implemented(),) \ + X(GUSE_PREP, throw not_implemented(),) \ + X(TIME, throw not_implemented(),) \ + X(START, throw not_implemented(),) \ + X(STOP, throw not_implemented(),) \ + X(RUN_TAPE, throw not_implemented(),) \ + X(JOIN_TAPE, throw not_implemented(),) \ + X(CRASH, throw not_implemented(),) \ + X(STARTGRIND, throw not_implemented(),) \ + X(STOPGRIND, throw not_implemented(),) \ + X(NPLAYERS, throw not_implemented(),) \ + X(THRESHOLD, throw not_implemented(),) \ + X(PLAYERID, throw not_implemented(),) \ + X(LISTEN, throw not_implemented(),) \ + X(ACCEPTCLIENTCONNECTION, throw not_implemented(),) \ + X(CLOSECLIENTCONNECTION, throw not_implemented(),) \ + X(READSOCKETINT, throw not_implemented(),) \ + X(READSOCKETC, throw not_implemented(),) \ + X(READSOCKETS, throw not_implemented(),) \ + X(WRITESOCKETINT, throw not_implemented(),) \ + X(WRITESOCKETC, throw not_implemented(),) \ + X(WRITESOCKETS, throw not_implemented(),) \ + X(WRITESOCKETSHARE, throw not_implemented(),) \ + X(WRITEFILESHARE, throw not_implemented(),) \ + X(READFILESHARE, throw not_implemented(),) \ + X(PUBINPUT, throw not_implemented(),) \ + X(RAWOUTPUT, throw not_implemented(),) \ + X(INTOUTPUT, throw not_implemented(),) \ + X(FLOATOUTPUT, throw not_implemented(),) \ + X(STARTPRIVATEOUTPUT, throw not_implemented(),) \ + X(GSTARTPRIVATEOUTPUT, throw not_implemented(),) \ + X(STOPPRIVATEOUTPUT, throw not_implemented(),) \ + X(GSTOPPRIVATEOUTPUT, throw not_implemented(),) \ + X(PREP, throw not_implemented(),) \ + X(GPREP, throw not_implemented(),) \ + X(CISC, throw not_implemented(),) \ + +#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ + CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS + #endif /* PROCESSOR_INSTRUCTIONS_H_ */ From eac6456ec85213c97d0d6004478b3c70d0edf489 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 22 Nov 2021 17:50:58 +1100 Subject: [PATCH 003/265] Allow preprocessing information to be supplied via named pipes. --- Processor/Data_Files.hpp | 5 ++- Processor/OnlineOptions.cpp | 16 +++++++ Processor/OnlineOptions.h | 1 + Processor/PrepBase.cpp | 10 ++++- Programs/Source/test_thread_mul.mpc | 11 +++++ Scripts/test_streaming.sh | 17 ++++++++ Tools/Buffer.cpp | 20 ++++++++- Tools/Buffer.h | 1 + Utils/stream-fake-mascot-triples.cpp | 65 ++++++++++++++++++++++++++++ 9 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 Programs/Source/test_thread_mul.mpc create mode 100755 Scripts/test_streaming.sh create mode 100644 Utils/stream-fake-mascot-triples.cpp diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 6951ed2c..3635dc0a 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -26,7 +26,7 @@ Preprocessing* Preprocessing::get_new( return get_live_prep(proc, usage); else return new Sub_Data_Files(machine.get_N(), - machine.template prep_dir_prefix(), usage); + machine.template prep_dir_prefix(), usage, BaseMachine::thread_num); } template @@ -185,6 +185,9 @@ Sub_Data_Files::~Sub_Data_Files() template void Sub_Data_Files::seekg(DataPositions& pos) { + if (OnlineOptions::singleton.file_prep_per_thread) + return; + if (T::LivePrep::use_part) { get_part().seekg(pos); diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 03fa2379..41308603 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -28,6 +28,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) bucket_size = 4; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; + file_prep_per_thread = false; #ifdef VERBOSE verbose = true; #else @@ -170,6 +171,16 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--live-preprocessing" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Preprocessing from files by thread (use with pipes)", // Help description. + "-f", // Flag token. + "--file-prep-per-thread" // Flag token. + ); + opt.add( to_string(default_batch_size).c_str(), // Default. 0, // Required? @@ -224,6 +235,11 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, live_prep = not opt.get("-F")->isSet; else live_prep = opt.get("-L")->isSet; + if (opt.isSet("-f")) + { + live_prep = false; + file_prep_per_thread = true; + } opt.get("-b")->getInt(batch_size); opt.get("--memory")->getString(memtype); bits_from_squares = opt.isSet("-Q"); diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 32c80fc2..de8f1e72 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -29,6 +29,7 @@ public: std::string cmd_private_input_file; std::string cmd_private_output_file; bool verbose; + bool file_prep_per_thread; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index b99d091e..5c44b908 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -6,11 +6,17 @@ #include "PrepBase.h" #include "Data_Files.h" +#include "OnlineOptions.h" string PrepBase::get_suffix(int thread_num) { - (void) thread_num; - return ""; + if (OnlineOptions::singleton.file_prep_per_thread) + { + assert(thread_num >= 0); + return "-T" + to_string(thread_num); + } + else + return ""; } string PrepBase::get_filename(const string& prep_data_dir, diff --git a/Programs/Source/test_thread_mul.mpc b/Programs/Source/test_thread_mul.mpc new file mode 100644 index 00000000..42097fd1 --- /dev/null +++ b/Programs/Source/test_thread_mul.mpc @@ -0,0 +1,11 @@ +n = 1000000 +x = sint.Array(n) +x.assign_vector(regint.inc(n)) + +@multithread(2, n) +def _(base, size): + x.assign_vector(x.get_vector(base, size) ** 2, base) + +print_ln('%s', x[2].reveal()) +crash(x[2].reveal() != 4) +crash(x[n - 1].reveal() != (n - 1) ** 2) diff --git a/Scripts/test_streaming.sh b/Scripts/test_streaming.sh new file mode 100755 index 00000000..0ff2fb33 --- /dev/null +++ b/Scripts/test_streaming.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +make stream-fake-mascot-triples.x +./compile.py test_thread_mul || exit 1 + +rm Player-Data/2-p-128/Triples-p-P?-T? +mkdir Player-Data/2-p-128 + +for i in 0 1; do + for j in 0 1 2; do + mknod Player-Data/2-p-128/Triples-p-P$i-T$j p || exit 1 + done +done + +./stream-fake-mascot-triples.x & + +Scripts/mascot.sh test_thread_mul -f || exit 1 diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 75cb8b6e..c669081f 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -7,6 +7,8 @@ #include "Processor/BaseMachine.h" #include +#include +#include bool BufferBase::rewind = false; @@ -21,8 +23,19 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, this->filename = filename; } +bool BufferBase::is_pipe() +{ + struct stat buf; + if (stat(filename.c_str(), &buf)) + return S_ISFIFO(buf.st_mode); + else + return false; +} + void BufferBase::seekg(int pos) { + assert(not is_pipe()); + #ifdef DEBUG_BUFFER if (pos != 0) printf("seek %d %s thread %d\n", pos, filename.c_str(), @@ -52,6 +65,8 @@ void BufferBase::seekg(int pos) void BufferBase::try_rewind() { + assert(not is_pipe()); + #ifndef INSECURE string type; if (field_type.size() and data_type.size()) @@ -70,6 +85,9 @@ void BufferBase::try_rewind() void BufferBase::prune() { + if (is_pipe()) + return; + if (file and (not file->good() or file->peek() == EOF)) purge(); else if (file and file->tellg() != header_length) @@ -99,7 +117,7 @@ void BufferBase::prune() void BufferBase::purge() { - if (file) + if (file and not is_pipe()) { #ifdef VERBOSE cerr << "Removing " << filename << endl; diff --git a/Tools/Buffer.h b/Tools/Buffer.h index a95dee0d..941ec425 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -44,6 +44,7 @@ public: const char* type = "", const string& field = {}); void seekg(int pos); bool is_up() { return file != 0; } + bool is_pipe(); void try_rewind(); void prune(); void purge(); diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp new file mode 100644 index 00000000..5aa85a05 --- /dev/null +++ b/Utils/stream-fake-mascot-triples.cpp @@ -0,0 +1,65 @@ +/* + * stream-fake-mascot-triples.cpp + * + */ + +#include "Protocols/Share.h" +#include "Math/gfpvar.h" +#include "Tools/benchmarking.h" + +#include "Math/Setup.hpp" +#include "Protocols/fake-stuff.hpp" + +class Info +{ +public: + int thread_num; + int nplayers; + gfpvar key; + pthread_t thread; +}; + +void* run(void* arg) +{ + auto& info = *(Info*) arg; + Files> files(info.nplayers, info.key, PREP_DIR, DATA_TRIPLE, info.thread_num); + SeededPRNG G; + int count = 0; + while (true) + { + gfpvar triple[3]; + for (int i = 0; i < 2; i++) + triple[i].randomize(G); + triple[2] = triple[0] * triple[1]; + for (int i = 0; i < 3; i++) + files.output_shares(triple[i]); + count++; + } + cerr << "failed after " << count << endl; + return 0; +} + +int main() +{ + insecure("preprocessing"); + typedef Share T; + int nplayers = 2; + int lgp = 128; + string prep_data_prefix = PREP_DIR; + gfpvar::generate_setup(prep_data_prefix, nplayers, lgp); + T::mac_key_type keyp; + generate_mac_keys(keyp, nplayers, prep_data_prefix); + + int nthreads = 3; + OnlineOptions::singleton.file_prep_per_thread = true; + vector infos(3); + for (int i = 0; i < nthreads; i++) + { + auto& info = infos[i]; + info.thread_num = i; + info.nplayers = nplayers; + info.key = keyp; + pthread_create(&info.thread, 0, run, &info); + } + pthread_join(infos[0].thread, 0); +} From 10f43e281e7a13f153dece6b27b015da38548b9f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 23 Nov 2021 21:19:52 +1100 Subject: [PATCH 004/265] Fix cleartext comparisons with larger primes. --- Compiler/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 151c805e..ec062b5c 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1011,8 +1011,9 @@ class cint(_clear, _int): return regint(self) < regint(other) else: diff = self - other + diff += (1 << (bit_length - 1)) shifted = diff >> (bit_length - 1) - res = regint(shifted & 1) + res = 1 - regint(shifted & 1) return res def __lt__(self, other): From 40431cd52a2dfb179ed958ef058e27bfbedc64f0 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 25 Nov 2021 11:15:54 +1100 Subject: [PATCH 005/265] Fix bug in early abort. --- Compiler/ml.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index a8be2d53..5ff1a375 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -2011,7 +2011,8 @@ class Optimizer: i.iadd(1) res = True if self.tol > 0: - res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() + res *= (1 - (loss_sum >= 0) * \ + (loss_sum < self.tol * n_per_epoch)).reveal() return res def reveal_correctness(self, data, truth, batch_size): From 00a2bcba45ea77fe490b30f5b23cb1b6cb58f875 Mon Sep 17 00:00:00 2001 From: rtaiello <41542771+rtaiello@users.noreply.github.com> Date: Mon, 29 Nov 2021 14:11:38 +0100 Subject: [PATCH 006/265] Fix problem with Scripts/run-common.sh execution --- ExternalIO/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 02b4e1e8..36649f5c 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -14,7 +14,7 @@ make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh & +Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 From ae99bed1927dbefd36258971001339587fc3410f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 2 Dec 2021 18:12:11 +1100 Subject: [PATCH 007/265] Access clear bit memory by run-time indices. --- Compiler/GC/instructions.py | 22 ++++++++++++++++++++++ Compiler/GC/types.py | 4 ++-- GC/Instruction.h | 2 ++ GC/instructions.h | 2 ++ Processor/Instruction.hpp | 2 ++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index f1b5ad23..fc64ae2d 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -64,6 +64,8 @@ opcodes = dict( MULCBI = 0x21c, SHRCBI = 0x21d, SHLCBI = 0x21e, + LDMCBI = 0x258, + STMCBI = 0x259, CONVCINTVEC = 0x21f, PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, @@ -360,6 +362,26 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): code = opcodes['STMSBI'] arg_format = ['sb','ci'] +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): + """ Copy clear bit memory cell with run-time address to clear bit + register. + + :param: destination (cbit) + :param: memory address (regint) + """ + code = opcodes['LDMCBI'] + arg_format = ['cbw','ci'] + +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): + """ Copy clear bit register to clear bit memory cell with run-time + address. + + :param: source (cbit) + :param: memory address (regint) + """ + code = opcodes['STMCBI'] + arg_format = ['cb','ci'] + class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] arg_format = tools.cycle(['sbw','cb','int']) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 6a5e39f1..5a65e73a 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -228,8 +228,8 @@ class cbits(bits): max_length = 64 reg_type = 'cb' is_clear = True - load_inst = (None, inst.ldmcb) - store_inst = (None, inst.stmcb) + load_inst = (inst.ldmcbi, inst.ldmcb) + store_inst = (inst.stmcbi, inst.stmcb) bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) conv_cint_vec = inst.convcintvec diff --git a/GC/Instruction.h b/GC/Instruction.h index f8872f12..e990f954 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -81,6 +81,8 @@ enum SHRCBI = 0x21d, SHLCBI = 0x21e, CONVCINTVEC = 0x21f, + LDMCBI = 0x258, + STMCBI = 0x259, // don't write PRINTREGSIGNED = 0x220, PRINTREGB = 0x221, diff --git a/GC/instructions.h b/GC/instructions.h index f94da799..fb44e4e0 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -61,6 +61,8 @@ X(STMCB, PROC.mem_op(SIZE, MMC, PROC.C, IMM, R0)) \ X(LDMSBI, PROC.mem_op(SIZE, PROC.S, MMS, R0, Ci[REG1])) \ X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \ + X(LDMCBI, PROC.mem_op(SIZE, PROC.C, MMC, R0, Ci[REG1])) \ + X(STMCBI, PROC.mem_op(SIZE, MMC, PROC.C, Ci[REG1], R0)) \ X(MOVSB, S0 = PS1) \ X(TRANS, T::trans(PROC, IMM, EXTRA)) \ X(BITB, PROC.random_bit(S0)) \ diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 27bec2b3..e516fdf3 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -101,6 +101,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STMSI: case LDMSBI: case STMSBI: + case LDMCBI: + case STMCBI: case MOVC: case MOVS: case MOVSB: From b771417e04aa60582a357de14331eaa59a23c101 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 4 Dec 2021 20:38:56 +1100 Subject: [PATCH 008/265] Clear to secret bit conversion with Yao's garbled circuits. --- BMR/Register.h | 5 +++++ GC/FakeSecret.h | 4 ++++ GC/Processor.h | 3 +++ GC/Processor.hpp | 12 ++++++++++++ GC/Secret.h | 5 +++++ GC/ShareSecret.h | 5 +++++ GC/instructions.h | 3 ++- Processor/Processor.hpp | 11 ----------- Yao/YaoEvalWire.cpp | 13 +++++++++++++ Yao/YaoEvalWire.h | 6 ++++++ Yao/YaoGarbleWire.cpp | 15 +++++++++++++++ Yao/YaoGarbleWire.h | 6 ++++++ 12 files changed, 76 insertions(+), 12 deletions(-) diff --git a/BMR/Register.h b/BMR/Register.h index 4d0c1b07..d0a75e93 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -23,6 +23,7 @@ using namespace std; #include "Tools/PointerVector.h" #include "Tools/Bundle.h" #include "Tools/SwitchableOutput.h" +#include "Processor/Instruction.h" //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -289,6 +290,10 @@ public: static void inputbvec(T& processor, ProcessorBase& input_processor, const vector& args); + template + static void convcbit2s(GC::Processor&, const BaseInstruction&) + { throw runtime_error("convcbit2s not implemented"); } + // most BMR phases don't need actual input template static T get_input(GC::Processor& processor, const InputArgs& args) diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 73013efa..55c537de 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -105,6 +105,10 @@ public: template static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; } + template + static void convcbit2s(GC::Processor&, const BaseInstruction&) + { throw runtime_error("convcbit2s not implemented"); } + static FakeSecret input(GC::Processor& processor, const InputArgs& args); static FakeSecret input(int from, word input, int n_bits); diff --git a/GC/Processor.h b/GC/Processor.h index 2dddf8df..3cc0c509 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -100,6 +100,9 @@ public: void reveal(const vector& args); + template + void convcbit2s(const BaseInstruction& instruction); + void print_reg(int reg, int n, int size); void print_reg_plain(Clear& value); void print_reg_signed(unsigned n_bits, Integer value); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 016352d3..663d55fc 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -331,6 +331,18 @@ void Processor::reveal(const vector& args) } } +template +template +void Processor::convcbit2s(const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + auto& share_thread = ShareThread::s(); + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + S[instruction.get_r(0) + i] = T::constant(C[instruction.get_r(1) + i], + share_thread.P->my_num(), share_thread.MC->get_alphai(), + min(unsigned(unit), instruction.get_n() - i * unit)); +} + template void Processor::print_reg(int reg, int n, int size) { diff --git a/GC/Secret.h b/GC/Secret.h index 6b37aa21..14f6638a 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -18,6 +18,7 @@ #include "Math/gf2nlong.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Tools/FixedVector.h" @@ -122,6 +123,10 @@ public: Processor& proc) { T::convcbit(dest, source, proc); } + template + static void convcbit2s(Processor& processor, const BaseInstruction& instruction) + { T::convcbit2s(processor, instruction); } + Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 9ea0d2f6..10cf65c0 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -21,6 +21,7 @@ using namespace std; #include "Protocols/ReplicatedMC.h" #include "Processor/DummyProtocol.h" #include "Processor/ProcessorBase.h" +#include "Processor/Instruction.h" namespace GC { @@ -74,6 +75,10 @@ public: template static void convcbit(Integer& dest, const Clear& source, T&) { dest = source; } + template + static void convcbit2s(Processor& processor, const BaseInstruction& instruction) + { processor.convcbit2s(instruction); } + static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } void check_length(int n, const Integer& x); diff --git a/GC/instructions.h b/GC/instructions.h index fb44e4e0..fc278d44 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -82,7 +82,7 @@ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ X(CONVCINTVEC, Proc.convcintvec(instruction)) \ X(CONVCBITVEC, Proc.convcbitvec(instruction)) \ - X(CONVCBIT2S, Proc.convcbit2s(instruction)) \ + X(CONVCBIT2S, PROC.convcbit2s(instruction)) \ X(DABIT, Proc.dabit(INST)) \ X(EDABIT, Proc.edabit(INST)) \ X(SEDABIT, Proc.edabit(INST, true)) \ @@ -99,6 +99,7 @@ X(CONVSINT, S0.load_clear(IMM, PI1)) \ X(CONVCINT, C0 = PI1) \ X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \ + X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \ X(PRINTCHR, PROC.print_chr(IMM)) \ X(PRINTSTR, PROC.print_str(IMM)) \ X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \ diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index ebbc1c8c..6206e27c 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -219,17 +219,6 @@ void Processor::convcintvec(const Instruction& instruction) } } -template -void Processor::convcbit2s(const Instruction& instruction) -{ - int unit = GC::Clear::N_BITS; - for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) - Procb.S[instruction.get_r(0) + i] = sint::bit_type::constant( - Procb.C[instruction.get_r(1) + i], P.my_num(), - share_thread.MC->get_alphai(), - min(unsigned(unit), instruction.get_n() - i * unit)); -} - template void Processor::split(const Instruction& instruction) { diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index b3240e08..38cdc922 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -243,6 +243,19 @@ void YaoEvalWire::reveal_inst(Processor& processor, const vector& args) } } +void YaoEvalWire::convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto& dest = processor.S[instruction.get_r(0) + i]; + dest.resize_regs(min(unsigned(unit), instruction.get_n() - i * unit)); + for (auto& reg : dest.get_regs()) + reg.set(0); + } +} + template void YaoEvalWire::and_( GC::Processor >& processor, const vector& args); diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index dc5d45a9..796d3561 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -10,6 +10,7 @@ #include "BMR/Gate.h" #include "BMR/Register.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "config.h" #include "YaoWire.h" @@ -19,6 +20,8 @@ class ProcessorBase; class YaoEvalWire : public YaoWire { + typedef GC::Secret whole_type; + public: typedef YaoEvaluator Party; typedef YaoEvalInput Input; @@ -61,6 +64,9 @@ public: GC::Processor>&); static void reveal_inst(Processor& processor, const vector& args); + static void convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction); + void set(const Key& key); void set(Key key, bool external); diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 7e52602e..37931df4 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -230,3 +230,18 @@ void YaoGarbleWire::reveal_inst(Processor& processor, const vector& args) else processor.reveal(args); } + +void YaoGarbleWire::convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction) +{ + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto& dest = processor.S[instruction.get_r(0) + i]; + int n = min(unsigned(unit), instruction.get_n() - i * unit); + dest.resize_regs(n); + for (int j = 0; j < n; j++) + dest.get_reg(i).public_input( + processor.C[instruction.get_r(1) + i].get_bit(j)); + } +} diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 47ebe8e5..cc1ba8ce 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -10,6 +10,7 @@ #include "BMR/Register.h" #include "config.h" #include "YaoWire.h" +#include "Processor/Instruction.h" #include @@ -19,6 +20,8 @@ class ProcessorBase; class YaoGarbleWire : public YaoWire { + typedef GC::Secret whole_type; + public: typedef YaoGarbler Party; typedef YaoGarbleInput Input; @@ -62,6 +65,9 @@ public: GC::Processor>&); static void reveal_inst(Processor& processor, const vector& args); + static void convcbit2s(GC::Processor& processor, + const BaseInstruction& instruction); + void randomize(PRNG& prng); void set(Key key, bool mask); From e76014e2e9f9a0b8d6c3dc629c80ae3410ad1e70 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 7 Dec 2021 16:42:18 +1100 Subject: [PATCH 009/265] More parallelized SSL handshake. --- Networking/CryptoPlayer.cpp | 29 +++++++++++++++++++++++------ Networking/CryptoPlayer.h | 2 ++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index d0b289b3..c2e1403b 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -68,7 +68,20 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : player.sockets.clear(); } - for (int i = 0; i < (int)sockets.size(); i++) + for (int offset = 1; offset <= num_players() / 2; offset++) + { + int others[] = { get_player(offset), get_player(-offset) }; + if (my_num() % (2 * offset) < offset) + swap(others[0], others[1]); + + if (num_players() % 2 == 0 and offset == num_players() / 2) + connect(others[0], plaintext_sockets); + else + for (int i = 0; i < 2; i++) + connect(others[i], plaintext_sockets); + } + + for (int i = 0; i < num_players(); i++) { if (i == my_num()) { @@ -79,16 +92,20 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : continue; } - sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i], - "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); - other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i], - "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); - senders[i] = new Sender(i < my_num() ? sockets[i] : other_sockets[i]); receivers[i] = new Receiver(i < my_num() ? other_sockets[i] : sockets[i]); } } +void CryptoPlayer::connect(int i, vector* plaintext_sockets) +{ + sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i], + "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); + other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i], + "P" + to_string(i), "P" + to_string(my_num()), i < my_num()); + +} + CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : CryptoPlayer(Nms, to_string(id_base)) { diff --git a/Networking/CryptoPlayer.h b/Networking/CryptoPlayer.h index 287f5c66..3ab3ed56 100644 --- a/Networking/CryptoPlayer.h +++ b/Networking/CryptoPlayer.h @@ -28,6 +28,8 @@ class CryptoPlayer : public MultiPlayer vector*> senders; vector*> receivers; + void connect(int other, vector* plaintext_sockets); + public: /** * Start a new set of encrypted connections. From cdb0c0f898f0c79b70d0b101872baeb80bd70ba2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 15 Dec 2021 12:58:54 +1100 Subject: [PATCH 010/265] In-place operations for containers. --- Compiler/types.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index ec062b5c..48bb27a1 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5420,6 +5420,22 @@ class Array(_vectorizable): __radd__ = __add__ __rmul__ = __mul__ + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __neg__(self): return -self.get_vector() @@ -5770,6 +5786,22 @@ class SubMultiArray(_vectorizable): assert self.sizes == other.sizes self.assign_vector(self.get_vector() + other.get_vector()) + def __iadd__(self, other): + self[:] += other.get_vector() + return self + + def __isub__(self, other): + self[:] -= other.get_vector() + return self + + def __imul__(self, other): + self[:] *= other.get_vector() + return self + + def __itruediv__(self, other): + self[:] /= other.get_vector() + return self + def __mul__(self, other): # legacy function return self.mul(other) From e07d9bf2a3231fe95557106371ce25f3da32f5d6 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 11 Jan 2022 16:04:59 +1100 Subject: [PATCH 011/265] Maintenance. --- .gitmodules | 2 +- BMR/Party.cpp | 2 +- BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 8 +- BMR/Register.h | 3 + BMR/TrustedParty.cpp | 6 + BMR/TrustedParty.h | 3 +- CHANGELOG.md | 13 ++ Compiler/GC/types.py | 10 +- Compiler/comparison.py | 11 +- Compiler/compilerLib.py | 2 +- Compiler/exceptions.py | 5 +- Compiler/floatingpoint.py | 14 +- Compiler/instructions.py | 17 +- Compiler/instructions_base.py | 83 ++++++++- Compiler/library.py | 23 ++- Compiler/ml.py | 12 +- Compiler/non_linear.py | 43 +++-- Compiler/program.py | 28 ++- Compiler/types.py | 50 +++-- ECDSA/hm-ecdsa-party.hpp | 4 +- ECDSA/mascot-ecdsa-party.cpp | 2 + ECDSA/ot-ecdsa-party.hpp | 4 +- ECDSA/preprocessing.hpp | 10 +- ECDSA/sign.hpp | 20 +- ExternalIO/Client.h | 25 +++ ExternalIO/README.md | 59 ++---- FHE/FHE_Params.h | 3 - FHE/NTL-Subs.cpp | 7 +- FHE/NTL-Subs.h | 5 +- FHE/NoiseBounds.cpp | 1 - FHE/Ring_Element.cpp | 22 ++- FHE/Ring_Element.h | 1 + FHEOffline/PairwiseGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 2 +- GC/BitAdder.hpp | 2 +- GC/CcdPrep.h | 5 - GC/CcdPrep.hpp | 8 + GC/CcdShare.h | 1 + GC/FakeSecret.h | 3 + GC/Instruction.cpp | 2 +- GC/NoShare.h | 6 +- GC/PostSacriBin.cpp | 16 +- GC/PostSacriBin.h | 5 +- GC/RepPrep.hpp | 5 + GC/Secret.h | 3 + GC/SemiPrep.cpp | 13 +- GC/SemiPrep.h | 4 +- GC/ShareSecret.h | 2 + GC/ShareSecret.hpp | 10 +- GC/ShareThread.h | 3 - GC/ShareThread.hpp | 3 +- GC/Thread.h | 2 - GC/Thread.hpp | 7 - GC/ThreadMaster.hpp | 4 +- GC/TinierSharePrep.h | 2 - GC/TinierSharePrep.hpp | 18 +- GC/TinyPrep.hpp | 2 +- GC/VectorInput.h | 6 + GC/VectorProtocol.h | 7 +- GC/VectorProtocol.hpp | 24 +-- GC/instructions.h | 2 +- License.txt | 2 +- Machines/Atlas.hpp | 16 ++ Machines/Rep.hpp | 1 + Machines/Rep4.hpp | 17 ++ Machines/SPDZ.hpp | 12 +- Machines/SPDZ2k.hpp | 11 +- Machines/Semi.hpp | 1 + Machines/Semi2k.hpp | 15 ++ Machines/ShamirMachine.hpp | 1 + Machines/Tinier.cpp | 23 +++ Machines/atlas-party.cpp | 7 +- Machines/emulate.cpp | 8 +- Machines/hemi-party.cpp | 1 + Machines/no-party.cpp | 1 + Machines/soho-party.cpp | 1 + Makefile | 28 +-- Math/BitVec.h | 20 +- Math/Setup.hpp | 5 +- Math/ValueInterface.h | 1 + Math/Z2k.h | 3 + Math/Zp_Data.cpp | 36 ++++ Math/Zp_Data.h | 33 +--- Math/gfp.h | 2 +- Networking/CryptoPlayer.cpp | 5 + Networking/Player.cpp | 49 +++-- Networking/Player.h | 17 +- Networking/Receiver.cpp | 8 + Networking/Sender.cpp | 12 +- Networking/Server.cpp | 46 +++-- Networking/Server.h | 7 +- Networking/ssl_sockets.h | 13 ++ OT/BaseOT.cpp | 18 ++ OT/NPartyTripleGenerator.h | 13 +- Processor/BaseMachine.cpp | 34 +++- Processor/BaseMachine.h | 15 +- Processor/Binary_File_IO.hpp | 2 +- Processor/Data_Files.h | 37 ++-- Processor/Data_Files.hpp | 18 +- Processor/DummyProtocol.h | 4 +- Processor/FieldMachine.h | 5 +- Processor/FieldMachine.hpp | 3 +- Processor/HonestMajorityMachine.cpp | 2 +- Processor/Input.h | 5 +- Processor/Input.hpp | 10 +- Processor/Instruction.hpp | 2 + Processor/Machine.h | 1 - Processor/Machine.hpp | 16 +- Processor/Memory.h | 3 + Processor/NoFilePrep.h | 22 +++ Processor/OfflineMachine.hpp | 6 +- Processor/Online-Thread.hpp | 15 +- Processor/OnlineOptions.cpp | 14 ++ Processor/OnlineOptions.h | 6 + Processor/OnlineOptions.hpp | 30 +++ Processor/PrepBase.cpp | 20 +- Processor/PrepBase.h | 6 +- Processor/Processor.h | 4 - Processor/Processor.hpp | 50 +---- Processor/RingMachine.h | 2 +- Processor/RingMachine.hpp | 3 +- Processor/ThreadQueue.cpp | 21 +++ Processor/ThreadQueue.h | 5 + Processor/TruncPrTuple.h | 37 +++- Programs/Source/keras_mnist_lenet_predict.mpc | 44 +++++ Protocols/Atlas.h | 11 +- Protocols/Atlas.hpp | 9 +- Protocols/Beaver.h | 11 +- Protocols/Beaver.hpp | 34 ++-- Protocols/BrainShare.h | 2 + Protocols/FakeProtocol.h | 35 ++-- Protocols/FakeShare.h | 3 + Protocols/Hemi.h | 6 +- Protocols/Hemi.hpp | 22 ++- Protocols/HighGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.hpp | 2 +- Protocols/MAC_Check.hpp | 2 + Protocols/MalRepRingPrep.hpp | 41 ----- Protocols/MaliciousRep3Share.h | 1 + Protocols/MaliciousRepPO.h | 8 +- Protocols/MaliciousRepPO.hpp | 18 +- Protocols/MaliciousRepPrep.hpp | 5 +- Protocols/MamaPrep.hpp | 1 + Protocols/MascotPrep.h | 2 - Protocols/MascotPrep.hpp | 12 +- Protocols/NoProtocol.h | 4 +- Protocols/PostSacriRepRingShare.h | 2 + Protocols/PostSacrifice.h | 4 +- Protocols/PostSacrifice.hpp | 7 +- Protocols/ProtocolSet.h | 107 +++++++++++ Protocols/ProtocolSetup.h | 95 ++++++++++ Protocols/Rep3Share.h | 27 +++ Protocols/Rep3Share2k.h | 12 -- Protocols/Rep4.h | 12 +- Protocols/Rep4.hpp | 34 ++-- Protocols/Rep4Prep.hpp | 2 +- Protocols/Replicated.h | 20 +- Protocols/Replicated.hpp | 174 +++++++----------- Protocols/ReplicatedInput.h | 3 +- Protocols/ReplicatedInput.hpp | 2 +- Protocols/ReplicatedPO.h | 24 +++ Protocols/ReplicatedPO.hpp | 21 +++ Protocols/ReplicatedPrep.h | 25 ++- Protocols/ReplicatedPrep.hpp | 148 +++++++++------ Protocols/{Semi2k.h => Semi.h} | 23 ++- Protocols/Semi2kShare.h | 6 +- Protocols/SemiShare.h | 8 +- Protocols/Shamir.h | 16 +- Protocols/Shamir.hpp | 14 +- Protocols/ShuffleSacrifice.hpp | 16 +- Protocols/Spdz2kPrep.h | 3 - Protocols/Spdz2kPrep.hpp | 35 ++-- Protocols/SpdzWise.h | 11 +- Protocols/SpdzWise.hpp | 28 +-- Protocols/SpdzWiseInput.hpp | 3 +- Protocols/SpdzWisePrep.hpp | 13 +- Protocols/SpdzWiseRing.hpp | 2 +- Protocols/SquarePrep.h | 6 +- README.md | 6 +- Scripts/decompile.py | 16 ++ Scripts/memory-usage.py | 29 +++ Scripts/run-common.sh | 31 +--- Scripts/test_streaming.sh | 4 + Scripts/tldr.sh | 3 +- Tools/BitVector.cpp | 9 + Tools/BitVector.h | 9 +- Tools/Buffer.cpp | 13 +- Tools/Bundle.h | 2 +- Tools/TimerWithComm.cpp | 23 +++ Tools/TimerWithComm.h | 23 +++ Tools/benchmarking.cpp | 15 ++ Tools/benchmarking.h | 3 + Tools/octetStream.h | 4 +- Tools/random.cpp | 4 +- Tools/random.h | 2 + Utils/Fake-Offline.cpp | 2 +- Utils/binary-example.cpp | 140 ++++++++++++++ Utils/mixed-example.cpp | 137 ++++++++++++++ Utils/paper-example.cpp | 49 ++--- Utils/stream-fake-mascot-triples.cpp | 21 ++- Yao/YaoEvaluator.h | 3 - Yao/YaoGarbler.cpp | 5 - Yao/YaoGarbler.h | 2 - Yao/YaoWire.h | 4 + Yao/YaoWire.hpp | 20 ++ doc/Doxyfile | 2 +- doc/conf.py | 5 +- doc/index.rst | 14 +- doc/io.rst | 10 + doc/low-level.rst | 142 +++++--------- doc/networking.rst | 6 +- doc/non-linear.rst | 4 +- doc/preprocessing.rst | 64 ++++++- doc/troubleshooting.rst | 21 ++- 216 files changed, 2410 insertions(+), 1117 deletions(-) create mode 100644 Machines/Atlas.hpp create mode 100644 Machines/Rep4.hpp create mode 100644 Machines/Semi2k.hpp create mode 100644 Machines/Tinier.cpp create mode 100644 Processor/NoFilePrep.h create mode 100644 Processor/OnlineOptions.hpp create mode 100644 Programs/Source/keras_mnist_lenet_predict.mpc create mode 100644 Protocols/ProtocolSet.h create mode 100644 Protocols/ProtocolSetup.h create mode 100644 Protocols/ReplicatedPO.h create mode 100644 Protocols/ReplicatedPO.hpp rename Protocols/{Semi2k.h => Semi.h} (75%) create mode 100755 Scripts/decompile.py create mode 100755 Scripts/memory-usage.py create mode 100644 Tools/TimerWithComm.cpp create mode 100644 Tools/TimerWithComm.h create mode 100644 Tools/benchmarking.cpp create mode 100644 Utils/binary-example.cpp create mode 100644 Utils/mixed-example.cpp diff --git a/.gitmodules b/.gitmodules index 455a5514..32dca28b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,7 @@ url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] path = mpir - url = git://github.com/wbhart/mpir.git + url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 5ca1360a..84ba909b 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,7 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl; + cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 55adcbfb..760a20b8 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -175,7 +175,7 @@ void GarbleInputter::exchange() assert(party.P != 0); assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; - protocol.init_mul(party.shared_proc); + protocol.init_mul(); for (auto& tuple : tuples) protocol.prepare_mul(tuple.first->mask, T::constant(1, party.P->my_num(), party.mac_key) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 0c97f9bd..8e16c307 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -155,7 +155,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); - data_sent = P->comm_stats.total_data() + prep->data_sent(); + data_sent = P->total_comm().sent; this->machine.write_memory(this->N.my_num()); } @@ -173,7 +173,8 @@ void RealProgramParty::garble() garble_jobs.clear(); garble_inputter->reset_all(*P); auto& protocol = *garble_protocol; - protocol.init_mul(shared_proc); + protocol.init(*prep, shared_proc->MC); + protocol.init_mul(); next = this->first_phase(program, garble_processor, this->garble_machine); @@ -181,7 +182,8 @@ void RealProgramParty::garble() protocol.exchange(); typename T::Protocol second_protocol(*P); - second_protocol.init_mul(shared_proc); + second_protocol.init(*prep, shared_proc->MC); + second_protocol.init_mul(); for (auto& job : garble_jobs) job.middle_round(*this, second_protocol); diff --git a/BMR/Register.h b/BMR/Register.h index d0a75e93..f348f7b7 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -293,6 +293,9 @@ public: template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } // most BMR phases don't need actual input template diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 6bd1ba26..439bcfc7 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty() _received_gc_received = 0; n_received = 0; randomfd = open("/dev/urandom", O_RDONLY); + done_filling = false; +} + +BaseTrustedParty::~BaseTrustedParty() +{ + close(randomfd); } TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 24e8120d..260e7a51 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -20,7 +20,7 @@ public: vector msg_input_masks; BaseTrustedParty(); - virtual ~BaseTrustedParty() {} + virtual ~BaseTrustedParty(); /* From NodeUpdatable class */ virtual void NodeReady(); @@ -104,7 +104,6 @@ private: void add_all_keys(const Register& reg, bool external); }; - inline void BaseTrustedParty::add_keys(const Register& reg) { for(int p = 0; p < get_n_parties(); p++) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c9be9e5..2b75d24f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.9 (Jan 11, 2021) + +- Disassembler +- Run-time parameter for probabilistic truncation error +- Probabilistic truncation for some protocols computing modulo a prime +- Simplified C++ interface +- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757) +- More general scalar-vector multiplication +- Complete memory support for clear bits +- Extended clear bit functionality with Yao's garbled circuits +- Allow preprocessing information to be supplied via named pipes +- In-place operations for containers + ## 0.2.8 (Nov 4, 2021) - Tested on Apple laptop with ARM chip diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5a65e73a..53da15ba 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -112,10 +112,16 @@ class bits(Tape.Register, _structure, _bit): return cls.load_dynamic_mem(address) else: for i in range(res.size): - cls.load_inst[util.is_constant(address)](res[i], address + i) + cls.mem_op(cls.load_inst, res[i], address + i) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, int)](self, address) + self.mem_op(self.store_inst, self, address) + @staticmethod + def mem_op(inst, reg, address): + direct = isinstance(address, int) + if not direct: + address = regint.conv(address) + inst[direct](reg, address) @classmethod def new(cls, value=None, n=None): if util.is_constant(value): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index f4cf89ad..2f7ca81f 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa): k: bit length of a """ + movs(s, program.non_linear.ltz(a, k, kappa)) + +def LtzRing(a, k): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): summands = a.split_to_two_summands(k) carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] - movs(s, sint.conv(msb)) + return sint.conv(msb) return elif program.options.ring: from . import floatingpoint @@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa): a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) - movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))) - return - t = sint() - Trunc(t, a, k, k - 1, kappa, True) - subsfi(s, t, 0) + return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)) def LessThanZero(a, k, kappa): from . import types diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 64e76434..b2898e21 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -82,7 +82,7 @@ def run(args, options): prog.finalize() if prog.req_num: - print('Program requires:') + print('Program requires at most:') for x in prog.req_num.pretty(): print(x) diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index fd026563..c68ecd31 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -12,4 +12,7 @@ class ArgumentError(CompilerError): """ Exception raised for errors in instruction argument parsing. """ def __init__(self, arg, msg): self.arg = arg - self.msg = msg \ No newline at end of file + self.msg = msg + +class VectorMismatch(CompilerError): + pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index a15a62dd..c596240b 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) - lts(d, c_dprime, r_prime, l, kappa) + d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa) if compute_modulo: b = c_dprime - r_prime + pow2m * d return b, pow2m @@ -629,12 +629,14 @@ def BITLT(a, b, bit_length): # - From the paper # Multiparty Computation for Interval, Equality, and Comparison without # Bit-Decomposition Protocol -def BitDecFull(a, maybe_mixed=False): +def BitDecFull(a, n_bits=None, maybe_mixed=False): from .library import get_program, do_while, if_, break_point from .types import sint, regint, longint, cint p = get_program().prime assert p bit_length = p.bit_length() + n_bits = n_bits or bit_length + assert n_bits <= bit_length logp = int(round(math.log(p, 2))) if abs(p - 2 ** logp) / p < 2 ** -get_program().security: # inspired by Rabbit (https://eprint.iacr.org/2021/119) @@ -677,12 +679,12 @@ def BitDecFull(a, maybe_mixed=False): czero = (c==0) q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t) fbar = [bbits[0].clear_type.conv(cint(x)) - for x in ((1< 0. + """ Crash runtime if the value in the register is not zero. :param: Crash condition (regint)""" code = base.opcodes['CRASH'] @@ -1275,7 +1275,7 @@ class prep(base.Instruction): field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, self.args[0]), 1) + req_node.increment((self.field_type, self.args[0]), self.get_size()) def has_var_args(self): return True @@ -2407,19 +2407,6 @@ class sqrs(base.CISC): subml(self.args[0], s[5], c[1]) -@base.gf2n -@base.vectorize -class lts(base.CISC): - """ Secret comparison $s_i = (s_j < s_k)$. """ - __slots__ = [] - arg_format = ['sw', 's', 's', 'int', 'int'] - - def expand(self): - from .types import sint - a = sint() - subs(a, self.args[1], self.args[2]) - comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) - # placeholder for documentation class cisc: """ Meta instruction for emulation. This instruction is only generated diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 38fd97d2..fb2a67b8 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -4,6 +4,8 @@ import time import inspect import functools import copy +import sys +import struct from Compiler.exceptions import * from Compiler.config import * from Compiler import util @@ -299,11 +301,12 @@ def vectorize(instruction, global_dict=None): vectorized_name = 'v' + instruction.__name__ Vectorized_Instruction.__name__ = vectorized_name global_dict[vectorized_name] = Vectorized_Instruction + + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction - instruction.__doc__ = '' - # exclude GF(2^n) instructions from documentation - if instruction.code and instruction.code >> 8 == 1: - maybe_vectorized_instruction.__doc__ = '' + maybe_vectorized_instruction.arg_format = instruction.arg_format return maybe_vectorized_instruction @@ -389,8 +392,11 @@ def gf2n(instruction): else: global_dict[GF2N_Instruction.__name__] = GF2N_Instruction + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction_cls - instruction_cls.__doc__ = '' + maybe_gf2n_instruction.arg_format = instruction.arg_format return maybe_gf2n_instruction #return instruction @@ -661,6 +667,12 @@ class RegisterArgFormat(ArgFormat): assert arg.i >= 0 return int_to_bytes(arg.i) + def __init__(self, f): + self.i = struct.unpack('>I', f.read(4))[0] + + def __str__(self): + return self.reg_type + str(self.i) + class ClearModpAF(RegisterArgFormat): reg_type = RegType.ClearModp @@ -686,6 +698,12 @@ class IntArgFormat(ArgFormat): def encode(cls, arg): return int_to_bytes(arg) + def __init__(self, f): + self.i = struct.unpack('>i', f.read(4))[0] + + def __str__(self): + return str(self.i) + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): @@ -722,6 +740,13 @@ class String(ArgFormat): def encode(cls, arg): return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg)) + def __init__(self, f): + tmp = f.read(16) + self.str = str(tmp[0:tmp.find(b'\0')], 'ascii') + + def __str__(self): + return self.str + ArgFormats = { 'c': ClearModpAF, 's': SecretModpAF, @@ -890,6 +915,54 @@ class Instruction(object): def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' +class ParsedInstruction: + reverse_opcodes = {} + + def __init__(self, f): + cls = type(self) + from Compiler import instructions + from Compiler.GC import instructions as gc_inst + if not cls.reverse_opcodes: + for module in instructions, gc_inst: + for x, y in inspect.getmodule(module).__dict__.items(): + if inspect.isclass(y) and y.__name__[0] != 'v': + try: + cls.reverse_opcodes[y.code] = y + except AttributeError: + pass + read = lambda: struct.unpack('>I', f.read(4))[0] + full_code = read() + code = full_code % (1 << Instruction.code_length) + self.size = full_code >> Instruction.code_length + self.type = cls.reverse_opcodes[code] + t = self.type + name = t.__name__ + try: + n_args = len(t.arg_format) + self.var_args = False + except: + n_args = read() + self.var_args = True + try: + arg_format = iter(t.arg_format) + except: + if name == 'cisc': + arg_format = itertools.chain(['str'], itertools.repeat('int')) + else: + arg_format = itertools.repeat('int') + self.args = [ArgFormats[next(arg_format)](f) + for i in range(n_args)] + + def __str__(self): + name = self.type.__name__ + res = name + ' ' + if self.size > 1: + res = 'v' + res + str(self.size) + ', ' + if self.var_args: + res += str(len(self.args)) + ', ' + res += ', '.join(str(arg) for arg in self.args) + return res + class VarArgsInstruction(Instruction): def has_var_args(self): return True diff --git a/Compiler/library.py b/Compiler/library.py index 529608dc..7bab1951 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -219,6 +219,9 @@ def crash(condition=None): :param condition: crash if true (default: true) """ + if isinstance(condition, localint): + # allow crash on local values + condition = condition._v if condition == None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -1347,6 +1350,8 @@ def while_loop(loop_body, condition, arg, g=None): arg = regint(arg) def loop_fn(): result = loop_body(arg) + if isinstance(result, MemValue): + result = result.read() result.link(arg) cont = condition(result) return cont @@ -1531,6 +1536,8 @@ def if_(condition): def if_e(condition): """ Conditional execution with else block. + Use :py:class:`~Compiler.types.MemValue` to assign values that + live beyond. :param condition: regint/cint/int @@ -1538,12 +1545,13 @@ def if_e(condition): .. code:: + y = MemValue(0) @if_e(x > 0) def _(): - ... + y.write(1) @else_ def _(): - ... + y.write(0) """ try: condition = bool(condition) @@ -1647,11 +1655,18 @@ def get_player_id(): return res def listen_for_clients(port): - """ Listen for clients on specific port. """ + """ Listen for clients on specific port base. + + :param port: port base (int/regint/cint) + """ instructions.listen(regint.conv(port)) def accept_client_connection(port): - """ Listen for clients on specific port. """ + """ Accept client connection on specific port base. + + :param port: port base (int/regint/cint) + :returns: client id + """ res = regint() instructions.acceptclientconnection(res, regint.conv(port)) return res diff --git a/Compiler/ml.py b/Compiler/ml.py index 5ff1a375..7e53a78f 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1810,6 +1810,7 @@ class Optimizer: self.print_loss_reduction = False self.i_epoch = MemValue(0) self.stopped_on_loss = MemValue(0) + self.stopped_on_low_loss = MemValue(0) @property def layers(self): @@ -1932,6 +1933,7 @@ class Optimizer: """ Run training. :param batch_size: batch size (defaults to example size of first layer) + :param stop_on_loss: stop when loss falls below this (default: 0) """ if self.n_epochs == 0: return @@ -2013,6 +2015,7 @@ class Optimizer: if self.tol > 0: res *= (1 - (loss_sum >= 0) * \ (loss_sum < self.tol * n_per_epoch)).reveal() + self.stopped_on_low_loss.write(1 - res) return res def reveal_correctness(self, data, truth, batch_size): @@ -2138,6 +2141,7 @@ class Optimizer: if depreciation: self.gamma.imul(depreciation) print_ln('reducing learning rate to %s', self.gamma) + return 1 - self.stopped_on_low_loss if 'model_output' in program.args: self.output_weights() @@ -2386,6 +2390,7 @@ class keras: return list(self.opt.thetas) def build(self, input_shape, batch_size=128): + data_input_shape = input_shape if self.opt != None and \ input_shape == self.opt.layers[0].X.sizes and \ batch_size <= self.batch_size and \ @@ -2458,9 +2463,10 @@ class keras: else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: - layers.append(Output(input_shape[0])) + layers.append(Output(data_input_shape[0])) else: - layers.append(MultiOutput(input_shape[0], layers[-1].d_out)) + layers.append( + MultiOutput(data_input_shape[0], layers[-1].d_out)) if self.optimizer[1]: raise Exception('use keyword arguments for optimizer') opt = self.optimizer[0] @@ -2504,7 +2510,7 @@ class keras: if x.total_size() != self.opt.layers[0].X.total_size(): raise Exception('sample data size mismatch') if y.total_size() != self.opt.layers[-1].Y.total_size(): - print (y, layers[-1].Y) + print (y, self.opt.layers[-1].Y) raise Exception('label size mismatch') if validation_data == None: validation_data = None, None diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 43e10c2e..01cb4db5 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -1,7 +1,7 @@ from .comparison import * from .floatingpoint import * from .types import * -from . import comparison +from . import comparison, program class NonLinear: kappa = None @@ -30,6 +30,15 @@ class NonLinear: def trunc_pr(self, a, k, m, signed=True): if isinstance(a, types.cint): return shift_two(a, m) + prog = program.Program.prog + if prog.use_trunc_pr: + if signed and prog.use_trunc_pr != -1: + a += (1 << (k - 1)) + res = sint() + trunc_pr(res, a, k, m) + if signed and prog.use_trunc_pr != -1: + res -= (1 << (k - m - 1)) + return res return self._trunc_pr(a, k, m, signed) def trunc_round_nearest(self, a, k, m, signed): @@ -44,6 +53,9 @@ class NonLinear: return a return self._trunc(a, k, m, signed) + def ltz(self, a, k, kappa=None): + return -self.trunc(a, k, k - 1, kappa, True) + class Masking(NonLinear): def eqz(self, a, k): c, r = self._mask(a, k) @@ -100,42 +112,44 @@ class KnownPrime(NonLinear): def _mod2m(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) - return sint.bit_compose(self.bit_dec(a, k, k, True)[:m]) + return sint.bit_compose(self.bit_dec(a, k, m, True)) def _trunc_pr(self, a, k, m, signed): # nearest truncation return self.trunc_round_nearest(a, k, m, signed) def _trunc(self, a, k, m, signed=None): - if signed: - a += cint(1) << (k - 1) - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) - if signed: - res -= cint(1) << (k - 1 - m) - return res + return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed) def trunc_round_nearest(self, a, k, m, signed): a += cint(1) << (m - 1) if signed: a += cint(1) << (k - 1) k += 1 - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) + res = self._trunc(a, k, m, False) if signed: res -= cint(1) << (k - m - 2) return res def bit_dec(self, a, k, m, maybe_mixed=False): assert k < self.prime.bit_length() - bits = BitDecFull(a, maybe_mixed=maybe_mixed) - if len(bits) < m: - raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) - return bits[:m] + bits = BitDecFull(a, m, maybe_mixed=maybe_mixed) + assert len(bits) == m + return bits def eqz(self, a, k): # always signed a += two_power(k) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) + def ltz(self, a, k, kappa=None): + if k + 1 < self.prime.bit_length(): + # https://dl.acm.org/doi/10.1145/3474123.3486757 + # "negative" values wrap around when doubling, thus becoming odd + return self.mod2m(2 * a, k + 1, 1, False) + else: + return super(KnownPrime, self).ltz(a, k, kappa) + class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. """ @@ -172,3 +186,6 @@ class Ring(Masking): return TruncRing(None, tmp + 1, k - m + 1, 1, signed) else: return super(Ring, self).trunc_round_nearest(a, k, m, signed) + + def ltz(self, a, k, kappa=None): + return LtzRing(a, k) diff --git a/Compiler/program.py b/Compiler/program.py index 19ce5248..5dad8e51 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -578,6 +578,15 @@ class Program(object): self.warn_about_mem.append(False) self.curr_block.warn_about_mem = False + @staticmethod + def read_tapes(schedule): + if not os.path.exists(schedule): + schedule = 'Programs/Schedules/%s.sch' % schedule + + lines = open(schedule).readlines() + for tapename in lines[2].split(' '): + yield tapename.strip() + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -1109,7 +1118,20 @@ class Tape: else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) - class Register(object): + @staticmethod + def read_instructions(tapename): + tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + while tape.peek(): + yield inst_base.ParsedInstruction(tape) + + class _no_truth(object): + __slots__ = [] + + def __bool__(self): + raise CompilerError('Cannot derive truth value from register, ' + "consider using 'compile.py -l'") + + class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. @@ -1233,10 +1255,6 @@ class Tape: self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.ClearInt - def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") - def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/types.py b/Compiler/types.py index 48bb27a1..33df2e37 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,7 +127,7 @@ def vectorize(operation): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise CompilerError('Different vector sizes of operands: %d/%d' + raise VectorMismatch('Different vector sizes of operands: %d/%d' % (self.size, args[0].size)) set_global_vector_size(self.size) try: @@ -221,7 +221,7 @@ def inputmixed(*args): else: instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),))) -class _number(object): +class _number(Tape._no_truth): """ Number functionality. """ def square(self): @@ -246,7 +246,11 @@ class _number(object): elif is_one(other): return self else: - return self.mul(other) + try: + return self.mul(other) + except VectorMismatch: + # try reverse multiplication + return NotImplemented __radd__ = __add__ __rmul__ = __mul__ @@ -320,7 +324,7 @@ class _number(object): def popcnt_bits(bits): return sum(bits) -class _int(object): +class _int(Tape._no_truth): """ Integer functionality. """ @staticmethod @@ -408,7 +412,7 @@ class _int(object): def long_one(): return 1 -class _bit(object): +class _bit(Tape._no_truth): """ Binary functionality. """ def bit_xor(self, other): @@ -474,7 +478,7 @@ class _gf2n(_bit): def bit_not(self): return self ^ 1 -class _structure(object): +class _structure(Tape._no_truth): """ Interface for type-dependent container types. """ MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value))) @@ -591,7 +595,7 @@ class _secret_structure(_structure): res.input_from(player) return res -class _vec(object): +class _vec(Tape._no_truth): def link(self, other): assert len(self.v) == len(other.v) for x, y in zip(self.v, other.v): @@ -726,7 +730,7 @@ class _register(Tape.Register, _number, _structure): assert self.size == 1 res = type(self)(size=size) for i in range(size): - movs(res[i], self) + self.mov(res[i], self) return res class _clear(_register): @@ -1010,9 +1014,10 @@ class cint(_clear, _int): if bit_length <= 64: return regint(self) < regint(other) else: + sint.require_bit_length(bit_length + 1) diff = self - other - diff += (1 << (bit_length - 1)) - shifted = diff >> (bit_length - 1) + diff += 1 << bit_length + shifted = diff >> bit_length res = 1 - regint(shifted & 1) return res @@ -1646,7 +1651,7 @@ class regint(_register, _int): player = -1 intoutput(player, self) -class localint(object): +class localint(Tape._no_truth): """ Local integer that must prevented from leaking into the secure computation. Uses regint internally. @@ -1669,7 +1674,7 @@ class localint(object): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) -class personal(object): +class personal(Tape._no_truth): def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -2003,9 +2008,11 @@ class _secret(_register, _secret_structure): size or one size 1 for a value-vector multiplication. :param other: any compatible type """ - if isinstance(other, _secret) and (1 in (self.size, other.size)) \ + if isinstance(other, _register) and (1 in (self.size, other.size)) \ and (self.size, other.size) != (1, 1): x, y = (other, self) if self.size < other.size else (self, other) + if not isinstance(other, _secret): + return y.expand_to_vector(x.size) * x res = type(self)(size=x.size) mulrs(res, x, y) return res @@ -2221,11 +2228,13 @@ class sint(_secret, _int): @vectorized_classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of values input by a client. + This uses the triple-based input protocol introduced by + `Damgård et al. `_ :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) - + :returns: list of sint """ # send shares of a triple to client triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) @@ -2910,7 +2919,7 @@ for t in (sint, sgf2n): sint.bit_type = sintbit sgf2n.bit_type = sgf2n -class _bitint(object): +class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False @@ -3521,6 +3530,7 @@ class cfix(_number, _structure): @classmethod def _new(cls, other, k=None, f=None): + assert not isinstance(other, (list, tuple)) res = cls(k=k, f=f) res.v = cint.conv(other) return res @@ -3567,6 +3577,8 @@ class cfix(_number, _structure): return len(self.v) def __getitem__(self, index): + if isinstance(index, slice): + return [self._new(x, k=self.k, f=self.f) for x in self.v[index]] return self._new(self.v[index], k=self.k, f=self.f) @vectorize @@ -3608,7 +3620,6 @@ class cfix(_number, _structure): else: return NotImplemented - @vectorize def mul(self, other): """ Clear fixed-point multiplication. @@ -4045,7 +4056,8 @@ class _fix(_single): 'for fixed-point computation') cls.round_nearest = True if adapt_ring and program.options.ring \ - and 'fix_ring' not in program.args: + and 'fix_ring' not in program.args \ + and 2 * cls.k > int(program.options.ring): need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) if need != int(program.options.ring): print('Changing computation modulus to 2^%d' % need) @@ -4489,7 +4501,7 @@ class squant(_single): def __neg__(self): return self._new(-self.v + 2 * util.expand(self.Z, self.v.size)) -class _unreduced_squant(object): +class _unreduced_squant(Tape._no_truth): def __init__(self, v, params, res_params=None, n_summands=1): self.v = v self.params = params @@ -5011,7 +5023,7 @@ class sfloat(_number, _secret_structure): :return: cfloat """ return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal()) -class cfloat(object): +class cfloat(Tape._no_truth): """ Helper class for printing revealed sfloats. """ __slots__ = ['v', 'p', 'z', 's', 'nan'] diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index a68f8e83..fc19e989 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -52,10 +52,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index 87573593..920397ce 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -5,6 +5,8 @@ #define NO_MIXED_CIRCUITS +#define NO_SECURITY_CHECK + #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 58f35d4b..569aa791 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -113,10 +113,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::TriplePrep prep(0, usage); diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 334d5d1b..0a5e0ab9 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -41,8 +41,8 @@ void preprocessing(vector>& tuples, int buffer_size, timer.start(); Player& P = proc.P; auto& prep = proc.DataF; - size_t start = P.sent + prep.data_sent(); - auto stats = P.comm_stats + prep.comm_stats(); + size_t start = P.total_comm().sent; + auto stats = P.total_comm(); auto& extra_player = P; auto& protocol = proc.protocol; @@ -77,7 +77,7 @@ void preprocessing(vector>& tuples, int buffer_size, MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); protocol.start_exchange(); @@ -106,9 +106,9 @@ void preprocessing(vector>& tuples, int buffer_size, timer.stop(); cout << "Generated " << buffer_size << " tuples in " << timer.elapsed() << " seconds, throughput " << buffer_size / timer.elapsed() << ", " - << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size + << 1e-3 * (P.total_comm().sent - start) / buffer_size << " kbytes per tuple" << endl; - (P.comm_stats + prep.comm_stats() - stats).print(true); + (P.total_comm() - stats).print(true); } template class T> diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index 10991276..5686349e 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length, (void) pk; Timer timer; timer.start(); - size_t start = P.sent; - auto stats = P.comm_stats; + auto stats = P.total_comm(); EcSignature signature; vector opened_R; if (opts.R_after_msg) @@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length, auto& protocol = proc->protocol; if (proc) { - protocol.init_mul(proc); + protocol.init_mul(); protocol.prepare_mul(sk, tuple.a); protocol.start_exchange(); } @@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length, auto rx = tuple.R.x(); signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); + auto diff = (P.total_comm() - stats); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " - << (P.sent - start) << " bytes" << endl; - auto diff = (P.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(true); return signature; } @@ -139,11 +138,11 @@ void sign_benchmark(vector>& tuples, T sk, P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { @@ -154,13 +153,12 @@ void sign_benchmark(vector>& tuples, T sk, Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); - auto stats = check_player.comm_stats; - auto start = check_player.sent; + auto stats = check_player.total_comm(); MCp.Check(P); MCc.Check(P); + auto diff = (check_player.total_comm() - stats); cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending " - << (check_player.sent - start) << " bytes" << endl; - auto diff = (check_player.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(); } } diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 12ba1c93..5f8e76fd 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,6 +8,9 @@ #include "Networking/ssl_sockets.h" +/** + * Client-side interface + */ class Client { vector plain_sockets; @@ -15,15 +18,37 @@ class Client ssl_service io_service; public: + /** + * Sockets for cleartext communication + */ vector sockets; + + /** + * Specification of computation domain + */ octetStream specification; + /** + * Start a new set of connections to computing parties. + * @param hostnames location of computing parties + * @param port_base port base + * @param my_client_id client identifier + */ Client(const vector& hostnames, int port_base, int my_client_id); ~Client(); + /** + * Securely input private values. + * @param values vector of integer-like values + */ template void send_private_inputs(const vector& values); + /** + * Securely receive output values. + * @param n number of values + * @returns vector of integer-like values + */ template vector receive_outputs(int n); }; diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 36649f5c..d4f99288 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -19,6 +19,8 @@ Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 ``` +`` can be any arithmetic protocol (e.g., `mascot`) but not a +binary protocol (e.g., `yao`). This should output that the winning id is 1. Note that the ids have to be incremental, and the client with the highest id has to input 1 as the last argument while the others have to input 0 there. Furthermore, @@ -32,54 +34,21 @@ different hosts, you will have to distribute the `*.pem` files. ### Connection Setup -**listen**(*int port_num*) - -Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. - -*port_num* - the port number to listen on. - -**acceptclientconnection**(*regint client_socket_id*, *int port_num*) - -Picks the first available client socket connection. Blocks if none available. - -*client_socket_id* - an identifier used to refer to the client socket. - -*port_num* - the port number identifies the socket server to accept connections on. +1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients) +2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection) +3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection) ### Data Exchange -Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). +Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types). -*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) +1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket) +2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client) +3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients) -Read a share of an input from a client, blocking on the client send. +## Client-Side Interface -*client_socket_id* - an identifier used to refer to the client socket. - -*number_of_inputs* - the number of inputs expected - -*[inputs]* - returned list of shares of private input. - -**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) - -Write shares of values including macs to an external client. - -*client_socket_id* - an identifier used to refer to the client socket. - -*[values]* - list of shares of values to send to client. - -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. - -See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. - -*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) - -Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) - -*number_of_inputs* - the number of inputs expected - -*client_socket_id* - an identifier used to refer to the client socket. - -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. - -*[inputs]* - returned list of shares of private input. +The example uses the `Client` class implemented in +`ExternalIO/Client.hpp` to handle the communication, see +https://mp-spdz.readthedocs.io/en/latest/io.html#reference for +documentation. diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index ac56668a..8ac40083 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -33,9 +33,6 @@ class FHE_Params int n_mults() const { return FFTData.size() - 1; } - // Rely on default copy assignment/constructor (not that they should - // ever be needed) - void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index cb5daa38..c6e294a6 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -178,12 +178,6 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, return extra_slack; } - - - -/****************************************************************************** - * Here onwards needs NTL - ******************************************************************************/ @@ -345,6 +339,7 @@ ZZX Cyclotomic(int N) return F; } #else +// simplified version powers of two int phi_N(int N) { if (((N - 1) & N) != 0) diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index ab150d27..c0a2ecfe 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -1,8 +1,6 @@ #ifndef _NTL_Subs #define _NTL_Subs -/* All these routines use NTL on the inside */ - #include "FHE/Ring.h" #include "FHE/FFT_Data.h" #include "FHE/P2Data.h" @@ -47,7 +45,7 @@ public: }; -// Main setup routine (need NTL if online_only is false) +// Main setup routine void generate_setup(int nparties, int lgp, int lg2, int sec, bool skip_2 = false, int slack = 0, bool round_up = false); @@ -60,7 +58,6 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, bool round_up); -// Everything else needs NTL void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index ae52fc62..7ab8e517 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -114,7 +114,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, cout << "n: " << n << endl; cout << "sec: " << sec << endl; cout << "sigma: " << this->sigma << endl; - cout << "h: " << h << endl; cout << "B_clean size: " << numBits(B_clean) << endl; cout << "B_scale size: " << numBits(B_scale) << endl; cout << "B_KS size: " << numBits(B_KS) << endl; diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 9c2545ed..812560a3 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -401,19 +401,29 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { - if (element.empty() and a.element.empty()) - return true; - else if (element.empty() or a.element.empty()) - throw not_implemented(); - if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } + + if (is_zero() or a.is_zero()) + return is_zero() and a.is_zero(); + for (int i=0; i<(*FFTD).phi_m(); i++) { if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } } return true; } +bool Ring_Element::is_zero() const +{ + if (element.empty()) + return true; + for (auto& x : element) + if (not ::isZero(x, FFTD->get_prD())) + return false; + return true; +} + + ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) @@ -560,6 +570,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const { if (&FFTD != this->FFTD) throw params_mismatch(); + if (is_zero()) + throw runtime_error("element is zero"); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 5cc93ca9..5982bbe3 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -95,6 +95,7 @@ class Ring_Element void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; + bool is_zero() const; // This is a NOP in cases where we cannot do a FFT void change_rep(RepType r); diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index ed5fb303..dcbd29b5 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -175,7 +175,7 @@ size_t PairwiseGenerator::report_size(ReportType type) template size_t PairwiseGenerator::report_sent() { - return P.sent; + return P.total_comm().sent; } template diff --git a/FHEOffline/SimpleGenerator.h b/FHEOffline/SimpleGenerator.h index 9cacad69..d5ee933a 100644 --- a/FHEOffline/SimpleGenerator.h +++ b/FHEOffline/SimpleGenerator.h @@ -71,7 +71,7 @@ public: void run(bool exhaust); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); - size_t report_sent() { return P.sent; } + size_t report_sent() { return P.total_comm().sent; } }; #endif /* FHEOFFLINE_SIMPLEGENERATOR_H_ */ diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index 9f852597..437af179 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -96,7 +96,7 @@ void BitAdder::add(vector >& res, b[j] = summands[i][1][input_begin + j]; } - protocol.init_mul(&proc); + protocol.init_mul(); for (size_t j = 0; j < n_items; j++) { res[begin + j][i] = a[j] + b[j] + carries[j]; diff --git a/GC/CcdPrep.h b/GC/CcdPrep.h index 8d232444..ab02ea80 100644 --- a/GC/CcdPrep.h +++ b/GC/CcdPrep.h @@ -91,11 +91,6 @@ public: (typename T::clear(tmp.get_bit(0)) << i); } } - - NamedCommStats comm_stats() - { - return part_prep.comm_stats(); - } }; } /* namespace GC */ diff --git a/GC/CcdPrep.hpp b/GC/CcdPrep.hpp index f9535350..3124efc4 100644 --- a/GC/CcdPrep.hpp +++ b/GC/CcdPrep.hpp @@ -25,6 +25,14 @@ void CcdPrep::set_protocol(typename T::Protocol& protocol) { auto& thread = ShareThread::s(); assert(thread.MC); + + if (part_proc) + { + assert(&part_proc->MC == &thread.MC->get_part_MC()); + assert(&part_proc->P == &protocol.get_part().P); + return; + } + part_proc = new SubProcessor( thread.MC->get_part_MC(), part_prep, protocol.get_part().P); } diff --git a/GC/CcdShare.h b/GC/CcdShare.h index aececad0..e890ce63 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -27,6 +27,7 @@ public: typedef ShamirInput Input; typedef ShamirMC MAC_Check; + typedef Shamir Protocol; typedef This small_type; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 55c537de..00e6c52c 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -108,6 +108,9 @@ public: template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } static FakeSecret input(GC::Processor& processor, const InputArgs& args); static FakeSecret input(int from, word input, int n_bits); diff --git a/GC/Instruction.cpp b/GC/Instruction.cpp index 3fe0cc58..6be1eb1a 100644 --- a/GC/Instruction.cpp +++ b/GC/Instruction.cpp @@ -84,7 +84,7 @@ void Instruction::parse(istream& s, int pos) ostringstream os; os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl; os << "This virtual machine executes binary circuits only." << endl; - os << "Try compiling with '-B' or use only sbit* types." << endl; + os << "Use 'compile.py -B'." << endl; throw Invalid_Instruction(os.str()); break; } diff --git a/GC/NoShare.h b/GC/NoShare.h index f60eccd7..c435ec3f 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -7,6 +7,7 @@ #define GC_NOSHARE_H_ #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Protocols/ShareInterface.h" class InputArgs; @@ -148,11 +149,14 @@ public: static void trans(Processor&, Integer, const vector&) { fail(); } + static void andm(GC::Processor&, const BaseInstruction&) { fail(); } + static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; } NoShare() {} - NoShare(int) { fail(); } + template + NoShare(T) { fail(); } void load_clear(Integer, Integer) { fail(); } void random_bit() { fail(); } diff --git a/GC/PostSacriBin.cpp b/GC/PostSacriBin.cpp index 81341cf0..74248060 100644 --- a/GC/PostSacriBin.cpp +++ b/GC/PostSacriBin.cpp @@ -9,6 +9,7 @@ #include "Protocols/Replicated.hpp" #include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "ShareSecret.hpp" namespace GC @@ -28,24 +29,19 @@ PostSacriBin::~PostSacriBin() } } -void PostSacriBin::init_mul(SubProcessor* proc) -{ - assert(proc != 0); - init_mul(proc->DataF, proc->MC); -} - -void PostSacriBin::init_mul(Preprocessing&, T::MC&) +void PostSacriBin::init_mul() { if ((int) inputs.size() >= OnlineOptions::singleton.batch_size) check(); honest.init_mul(); } -PostSacriBin::T::clear PostSacriBin::prepare_mul(const T& x, const T& y, int n) +void PostSacriBin::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; honest.prepare_mul(x, y, n); inputs.push_back({{x.mask(n), y.mask(n)}}); - return {}; } void PostSacriBin::exchange() @@ -55,6 +51,8 @@ void PostSacriBin::exchange() PostSacriBin::T PostSacriBin::finalize_mul(int n) { + if (n == -1) + n = T::default_length; auto res = honest.finalize_mul(n); outputs.push_back({res, n}); return res; diff --git a/GC/PostSacriBin.h b/GC/PostSacriBin.h index 50baa9c5..8f1643a7 100644 --- a/GC/PostSacriBin.h +++ b/GC/PostSacriBin.h @@ -38,9 +38,8 @@ public: PostSacriBin(Player& P); ~PostSacriBin(); - void init_mul(Preprocessing&, T::MC&); - void init_mul(SubProcessor* proc); - T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index 1c91fd39..f83fbdaf 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_REPPREP_HPP_ +#define GC_REPPREP_HPP_ + #include "RepPrep.h" #include "ShareThread.h" #include "Processor/OnlineOptions.h" @@ -98,3 +101,5 @@ void RepPrep::buffer_inputs(int player) } } /* namespace GC */ + +#endif diff --git a/GC/Secret.h b/GC/Secret.h index 14f6638a..c4b6e8eb 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -126,6 +126,9 @@ public: template static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { T::convcbit2s(processor, instruction); } + template + static void andm(Processor& processor, const BaseInstruction& instruction) + { T::andm(processor, instruction); } Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 9fc3f491..9eed3b31 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -24,12 +24,15 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) : void SemiPrep::set_protocol(Beaver& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, {}, &protocol.P); triple_generator->multi_threaded = false; @@ -61,12 +64,4 @@ void SemiPrep::buffer_bits() } } -NamedCommStats SemiPrep::comm_stats() -{ - if (triple_generator) - return triple_generator->comm_stats(); - else - return {}; -} - } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 97214c28..737cfb98 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -44,6 +44,8 @@ public: array get_triple_no_count(int n_bits) { + if (n_bits == -1) + n_bits = SemiSecret::default_length; return ShiftableTripleBuffer::get_triple_no_count(n_bits); } @@ -51,8 +53,6 @@ public: { throw not_implemented(); } - - NamedCommStats comm_stats(); }; } /* namespace GC */ diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 10cf65c0..48f75b8f 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -78,6 +78,8 @@ public: template static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { processor.convcbit2s(instruction); } + static void andm(Processor& processor, const BaseInstruction& instruction) + { processor.andm(instruction); } static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 1a508828..23c86cb2 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -47,7 +47,7 @@ void ShareSecret::invert(int n, const U& x) { U ones; ones.load_clear(64, -1); - static_cast(*this) = U(x ^ ones) & get_mask(n); + reinterpret_cast(*this) = U(x + ones) & get_mask(n); } template @@ -92,8 +92,12 @@ template void ShareSecret::store_clear_in_dynamic(Memory& mem, const vector& accesses) { + auto& thread = ShareThread::s(); + assert(thread.P); + assert(thread.MC); for (auto access : accesses) - mem[access.address] = access.value; + mem[access.address] = U::constant(access.value, thread.P->my_num(), + thread.MC->get_alphai()); } template @@ -330,7 +334,7 @@ void ShareSecret::random_bit() template U& GC::ShareSecret::operator=(const U& other) { - U& real_this = static_cast(*this); + U& real_this = reinterpret_cast(*this); real_this = other; return real_this; } diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 5f995e80..42c5e3bd 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -58,9 +58,6 @@ public: void pre_run(); void post_run() { ShareThread::post_run(); } - - NamedCommStats comm_stats() - { return Thread::comm_stats() + this->DataF.comm_stats(); } }; template diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 14d49611..07085040 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -63,6 +63,7 @@ void ShareThread::pre_run(Player& P, typename T::mac_key_type mac_key) protocol = new typename T::Protocol(*this->P); MC = this->new_mc(mac_key); DataF.set_protocol(*this->protocol); + this->protocol->init(DataF, *MC); } template @@ -85,7 +86,7 @@ void ShareThread::and_(Processor& processor, { auto& protocol = this->protocol; processor.check_args(args, 4); - protocol->init_mul(DataF, *this->MC); + protocol->init_mul(); T x_ext, y_ext; for (size_t i = 0; i < args.size(); i += 4) { diff --git a/GC/Thread.h b/GC/Thread.h index 659c070a..6631ad72 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -55,8 +55,6 @@ public: void join_tape(); void finish(); - - virtual NamedCommStats comm_stats(); }; template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 5487c41b..d0b515cb 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -96,13 +96,6 @@ void Thread::finish() pthread_join(thread, 0); } -template -NamedCommStats Thread::comm_stats() -{ - assert(P); - return P->comm_stats; -} - } /* namespace GC */ diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 060e9f11..c6c9dcaa 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -95,11 +95,11 @@ void ThreadMaster::run() post_run(); - NamedCommStats stats = P->comm_stats; + NamedCommStats stats = P->total_comm(); ExecutionStats exe_stats; for (auto thread : threads) { - stats += thread->P->comm_stats; + stats += thread->P->total_comm(); exe_stats += thread->processor.stats; delete thread; } diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h index 34beaf6f..4e316e38 100644 --- a/GC/TinierSharePrep.h +++ b/GC/TinierSharePrep.h @@ -44,8 +44,6 @@ public: ~TinierSharePrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; } diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index 57e759b9..e136ec44 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -8,7 +8,7 @@ #include "TinierSharePrep.h" -#include "PersonalPrep.hpp" +#include "PersonalPrep.h" namespace GC { @@ -39,14 +39,17 @@ template void TinierSharePrep::set_protocol(typename T::Protocol& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } params.generateMACs = true; params.amplify = false; params.check = false; auto& thread = ShareThread::s(); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &protocol.P); triple_generator->multi_threaded = false; @@ -84,17 +87,6 @@ void GC::TinierSharePrep::buffer_bits() BufferPrep::get_random_from_inputs(thread.P->num_players())); } -template -NamedCommStats TinierSharePrep::comm_stats() -{ - NamedCommStats res; - if (triple_generator) - res += triple_generator->comm_stats(); - if (real_triple_generator) - res += real_triple_generator->comm_stats(); - return res; -} - } #endif diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 2b8a11b7..897b3b48 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -16,7 +16,7 @@ void TinierSharePrep::init_real(Player& P) assert(real_triple_generator == 0); auto& thread = ShareThread::s(); real_triple_generator = new typename T::whole_type::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), P.N, -1, + BaseMachine::fresh_ot_setup(P), P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &P); real_triple_generator->multi_threaded = false; diff --git a/GC/VectorInput.h b/GC/VectorInput.h index c17cd93d..44c9591b 100644 --- a/GC/VectorInput.h +++ b/GC/VectorInput.h @@ -36,6 +36,8 @@ public: void add_mine(const typename T::open_type& input, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_mine(input.get_bit(i)); input_lengths.push_back(n_bits); @@ -43,6 +45,8 @@ public: void add_other(int player, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_other(player); } @@ -69,6 +73,8 @@ public: void finalize_other(int player, T& target, octetStream&, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; target.resize_regs(n_bits); for (int i = 0; i < n_bits; i++) part_input.finalize_other(player, target.get_reg(i), diff --git a/GC/VectorProtocol.h b/GC/VectorProtocol.h index 3f7e203c..94ef1989 100644 --- a/GC/VectorProtocol.h +++ b/GC/VectorProtocol.h @@ -21,9 +21,10 @@ public: VectorProtocol(Player& P); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void finalize_mult(T& res, int n = -1); T finalize_mul(int n = -1); diff --git a/GC/VectorProtocol.hpp b/GC/VectorProtocol.hpp index cae46181..e72e0d14 100644 --- a/GC/VectorProtocol.hpp +++ b/GC/VectorProtocol.hpp @@ -18,26 +18,26 @@ VectorProtocol::VectorProtocol(Player& P) : } template -void VectorProtocol::init_mul(SubProcessor* proc) -{ - assert(proc); - init_mul(proc->DataF, proc->MC); -} - -template -void VectorProtocol::init_mul(Preprocessing& prep, +void VectorProtocol::init(Preprocessing& prep, typename T::MAC_Check& MC) { - part_protocol.init_mul(prep.get_part(), MC.get_part_MC()); + part_protocol.init(prep.get_part(), MC.get_part_MC()); } template -typename T::clear VectorProtocol::prepare_mul(const T& x, +void VectorProtocol::init_mul() +{ + part_protocol.init_mul(); +} + +template +void VectorProtocol::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; for (int i = 0; i < n; i++) part_protocol.prepare_mul(x.get_reg(i), y.get_reg(i), 1); - return {}; } template @@ -57,6 +57,8 @@ T VectorProtocol::finalize_mul(int n) template void VectorProtocol::finalize_mult(T& res, int n) { + if (n == -1) + n = T::default_length; res.resize_regs(n); for (int i = 0; i < n; i++) res.get_reg(i) = part_protocol.finalize_mul(1); diff --git a/GC/instructions.h b/GC/instructions.h index fc278d44..66ae46d2 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -46,6 +46,7 @@ X(NOTCB, processor.notcb(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ + X(ANDM, T::andm(PROC, instruction)) \ X(ADDCB, C0 = PC1 + PC2) \ X(ADDCBI, C0 = PC1 + int(IMM)) \ X(MULCBI, C0 = PC1 * int(IMM)) \ @@ -76,7 +77,6 @@ #define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \ X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \ - X(ANDM, processor.andm(instruction)) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ diff --git a/License.txt b/License.txt index 3a9eb2ae..ccaafe01 100644 --- a/License.txt +++ b/License.txt @@ -1,5 +1,5 @@ CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) -Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. +Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material. Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. diff --git a/Machines/Atlas.hpp b/Machines/Atlas.hpp new file mode 100644 index 00000000..045b69b9 --- /dev/null +++ b/Machines/Atlas.hpp @@ -0,0 +1,16 @@ +/* + * Atlas.hpp + * + */ + +#ifndef MACHINES_ATLAS_HPP_ +#define MACHINES_ATLAS_HPP_ + +#include "Protocols/AtlasShare.h" +#include "Protocols/AtlasPrep.h" +#include "GC/AtlasSecret.h" + +#include "ShamirMachine.hpp" +#include "Protocols/Atlas.hpp" + +#endif /* MACHINES_ATLAS_HPP_ */ diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index d37c385c..a480860f 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -4,6 +4,7 @@ */ #include "Protocols/MalRepRingPrep.h" +#include "Protocols/ReplicatedPrep2k.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" diff --git a/Machines/Rep4.hpp b/Machines/Rep4.hpp new file mode 100644 index 00000000..83ad1cff --- /dev/null +++ b/Machines/Rep4.hpp @@ -0,0 +1,17 @@ +/* + * Rep4.hpp + * + */ + +#ifndef MACHINES_REP4_HPP_ +#define MACHINES_REP4_HPP_ + +#include "GC/Rep4Secret.h" +#include "Protocols/Rep4Share2k.h" +#include "Protocols/Rep4Prep.h" +#include "Protocols/Rep4.hpp" +#include "Protocols/Rep4MC.hpp" +#include "Protocols/Rep4Input.hpp" +#include "Protocols/Rep4Prep.hpp" + +#endif /* MACHINES_REP4_HPP_ */ diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index 02ad9b98..a221b087 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -21,13 +21,15 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" +#include "GC/VectorProtocol.h" -#include "GC/ShareParty.hpp" +#include "GC/ShareParty.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/ShareSecret.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/ShareSecret.h" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" #include "Math/gfp.hpp" diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 672a29b4..6cb02779 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -23,9 +23,10 @@ #include "Protocols/MascotPrep.hpp" #include "Protocols/Spdz2kPrep.hpp" -#include "GC/ShareParty.hpp" -#include "GC/ShareSecret.hpp" +#include "GC/ShareParty.h" +#include "GC/ShareSecret.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 36c9d8c5..1a093146 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -18,3 +18,4 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" diff --git a/Machines/Semi2k.hpp b/Machines/Semi2k.hpp new file mode 100644 index 00000000..56f86d9b --- /dev/null +++ b/Machines/Semi2k.hpp @@ -0,0 +1,15 @@ +/* + * Semi2.hpp + * + */ + +#ifndef MACHINES_SEMI2K_HPP_ +#define MACHINES_SEMI2K_HPP_ + +#include "Protocols/Semi2kShare.h" +#include "Protocols/SemiPrep2k.h" + +#include "Semi.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" + +#endif /* MACHINES_SEMI2K_HPP_ */ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 080332ae..7697c512 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -27,6 +27,7 @@ #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" #include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/VectorProtocol.hpp" #include "GC/Secret.hpp" diff --git a/Machines/Tinier.cpp b/Machines/Tinier.cpp new file mode 100644 index 00000000..99ad1c5c --- /dev/null +++ b/Machines/Tinier.cpp @@ -0,0 +1,23 @@ +/* + * Tinier.cpp + * + */ + +#include "GC/TinyMC.h" +#include "GC/TinierSecret.h" +#include "GC/VectorInput.h" + +#include "GC/ShareParty.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/PersonalPrep.hpp" + +//template class GC::ShareParty>; +template class GC::CcdPrep>; +template class Preprocessing>; +template class GC::TinierSharePrep>; +template class GC::ShareSecret>; +template class TripleShuffleSacrifice>; diff --git a/Machines/atlas-party.cpp b/Machines/atlas-party.cpp index 6e754c7f..2df033e6 100644 --- a/Machines/atlas-party.cpp +++ b/Machines/atlas-party.cpp @@ -3,12 +3,7 @@ * */ -#include "Protocols/AtlasShare.h" -#include "Protocols/AtlasPrep.h" -#include "GC/AtlasSecret.h" - -#include "ShamirMachine.hpp" -#include "Protocols/Atlas.hpp" +#include "Atlas.hpp" int main(int argc, const char** argv) { diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 8525b067..f26f5f32 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -10,11 +10,13 @@ #include "Processor/RingOptions.h" #include "Processor/Machine.hpp" +#include "Processor/OnlineOptions.hpp" #include "Math/Z2k.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/ShuffleSacrifice.hpp" #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" +#include "Protocols/MalRepRingPrep.hpp" int main(int argc, const char** argv) { @@ -22,7 +24,7 @@ int main(int argc, const char** argv) Names N; ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); - online_opts = {opt, argc, argv}; + online_opts = {opt, argc, argv, FakeShare>()}; opt.parse(argc, argv); opt.syntax = string(argv[0]) + " "; @@ -44,9 +46,7 @@ int main(int argc, const char** argv) #ifdef ROUND_NEAREST_IN_EMULATION cerr << "Using nearest rounding instead of probabilistic truncation" << endl; #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - cerr << "Using risky truncation" << endl; -#endif + online_opts.set_trunc_error(opt); #endif int R = ring_opts.ring_size_from_opts_or_schedule(progname); diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 471862da..934c15dc 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -24,6 +24,7 @@ #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Hemi.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp index 2120322f..ce542de1 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -8,6 +8,7 @@ #include "Processor/OnlineMachine.hpp" #include "Processor/Machine.hpp" #include "Protocols/Replicated.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "Math/gfp.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index 6f7c70a3..7ecc450d 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -22,6 +22,7 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Makefile b/Makefile index 9d634c0e..e40528b8 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,8 @@ MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VM = $(MINI_OT) $(SHAREDLIB) COMMON = $(SHAREDLIB) +TINIER = Machines/Tinier.o $(OT) +SPDZ = Machines/SPDZ.o $(TINIER) LIB = libSPDZ.a @@ -117,7 +119,7 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC) +$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(TINIER) $(GC) $(AR) -csr $@ $^ CFLAGS += -fPIC @@ -203,16 +205,16 @@ ps-rep-bin-party.x: GC/PostSacriBin.o semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) tinier-party.x: $(OT) -spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) +spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) -cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o -highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o +cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o +highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) @@ -220,10 +222,10 @@ static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o -mascot-party.x: Machines/SPDZ.o $(OT) -static/mascot-party.x: Machines/SPDZ.o -Player-Online.x: Machines/SPDZ.o $(OT) -mama-party.x: $(OT) +mascot-party.x: $(SPDZ) +static/mascot-party.x: $(SPDZ) +Player-Online.x: $(SPDZ) +mama-party.x: $(TINIER) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o @@ -236,8 +238,10 @@ emulate.x: GC/FakeSecret.o semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -mascot-offline.x: $(VM) $(OT) -cowgear-offline.x: $(OT) $(FHEOFFLINE) +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o +mascot-offline.x: $(VM) $(TINIER) +cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) static/mal-rep-bmr-party.x: $(BMR) static/shamir-bmr-party.x: $(BMR) diff --git a/Math/BitVec.h b/Math/BitVec.h index f9e874d1..f0d60a1b 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -26,6 +26,7 @@ public: static const false_type invertible; static const true_type characteristic_two; + static const true_type binary; static char type_char() { return 'B'; } static string type_short() { return "B"; } @@ -64,8 +65,21 @@ public: void pack(octetStream& os) const { os.store_int(this->a); } void unpack(octetStream& os) { this->a = os.get_int(); } - void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } - void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); } + void pack(octetStream& os, int n) const + { + if (n == -1) + pack(os); + else + os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); + } + + void unpack(octetStream& os, int n) + { + if (n == -1) + unpack(os); + else + this->a = os.get_int(DIV_CEIL(n, 8)); + } static BitVec_ unpack_new(octetStream& os, int n = n_bits) { @@ -81,5 +95,7 @@ template const false_type BitVec_::invertible; template const true_type BitVec_::characteristic_two; +template +const true_type BitVec_::binary; #endif /* MATH_BITVEC_H_ */ diff --git a/Math/Setup.hpp b/Math/Setup.hpp index 6545d67e..91cafaea 100644 --- a/Math/Setup.hpp +++ b/Math/Setup.hpp @@ -36,8 +36,9 @@ void read_setup(const string& dir_prefix, int lgp = -1) { if (lgp > 0) { - cerr << "No modulus found in " << filename << ", generating " << lgp - << "-bit prime" << endl; + if (OnlineOptions::singleton.verbose) + cerr << "No modulus found in " << filename << ", generating " + << lgp << "-bit prime" << endl; T::init_default(lgp); } else diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index d15af24c..07807cb2 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -20,6 +20,7 @@ public: static const false_type characteristic_two; static const false_type prime_field; static const false_type invertible; + static const false_type binary; template static void init(bool mont = true) { (void) mont; } diff --git a/Math/Z2k.h b/Math/Z2k.h index 3e653044..ad32cbf1 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -47,6 +47,7 @@ public: static int size_in_limbs() { return N_WORDS; } static int size_in_bits() { return size() * 8; } static int length() { return size_in_bits(); } + static int n_bits() { return N_BITS; } static int t() { return 0; } static char type_char() { return 'R'; } @@ -100,6 +101,8 @@ public: int bit_length() const; + Z2 mask(int) const { return *this; } + Z2 operator+(const Z2& other) const; Z2 operator-(const Z2& other) const; diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 63c279a2..17fcdf24 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -86,6 +86,42 @@ void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t { inline_mpn_copyi(z,ans+t,t); } } +void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x, + const mp_limb_t* y) const +{ + switch (t) + { +#ifdef __BMI2__ +#define CASE(N) \ + case N: \ + Mont_Mult_(z, x, y); \ + break; + CASE(1) + CASE(2) +#if MAX_MOD_SZ >= 4 + CASE(3) + CASE(4) +#endif +#if MAX_MOD_SZ >= 5 + CASE(5) +#endif +#if MAX_MOD_SZ >= 6 + CASE(6) +#endif +#if MAX_MOD_SZ >= 10 + CASE(7) + CASE(8) + CASE(9) + CASE(10) +#endif +#undef CASE +#endif + default: + Mont_Mult_variable(z, x, y); + break; + } +} + ostream& operator<<(ostream& s,const Zp_Data& ZpD) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 96deb795..f30e7103 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -40,6 +40,7 @@ class Zp_Data template void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; + void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const; void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const { Mont_Mult(z, x, y, t); } @@ -242,37 +243,11 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* { if (not cpu_has_bmi2()) return Mont_Mult_variable(z, x, y); - switch (t) - { #ifdef __BMI2__ -#define CASE(N) \ - case N: \ - Mont_Mult_(z, x, y); \ - break; - CASE(1) - CASE(2) -#if MAX_MOD_SZ >= 4 - CASE(3) - CASE(4) + return Mont_Mult_switch(z, x, y); +#else + return Mont_Mult_variable(z, x, y); #endif -#if MAX_MOD_SZ >= 5 - CASE(5) -#endif -#if MAX_MOD_SZ >= 6 - CASE(6) -#endif -#if MAX_MOD_SZ >= 10 - CASE(7) - CASE(8) - CASE(9) - CASE(10) -#endif -#undef CASE -#endif - default: - Mont_Mult_variable(z, x, y); - break; - } } inline void Zp_Data::Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, diff --git a/Math/gfp.h b/Math/gfp.h index 7b257b5f..bde43025 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -11,7 +11,6 @@ using namespace std; #include "Math/Bit.h" #include "Math/Setup.h" #include "Tools/random.h" -#include "GC/NoShare.h" #include "Processor/OnlineOptions.h" #include "Math/modp.hpp" @@ -101,6 +100,7 @@ class gfp_ : public ValueInterface static int size() { return t() * sizeof(mp_limb_t); } static int size_in_bits() { return 8 * size(); } static int length() { return ZpD.pr_bit_length; } + static int n_bits() { return length() - 1; } static void reqbl(int n); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index c2e1403b..9d8da651 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -5,6 +5,7 @@ #include "CryptoPlayer.h" #include "Math/Setup.h" +#include "Tools/Bundle.h" void check_ssl_file(string filename) { @@ -124,12 +125,14 @@ CryptoPlayer::~CryptoPlayer() void CryptoPlayer::send_to_no_stats(int other, const octetStream& o) const { + assert(other != my_num()); senders[other]->request(o); senders[other]->wait(o); } void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const { + assert(other != my_num()); receivers[other]->request(o); receivers[other]->wait(o); } @@ -137,6 +140,7 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, octetStream& to_receive) const { + assert(other != my_num()); if (&to_send == &to_receive) { MultiPlayer::exchange_no_stats(other, to_send, to_receive); @@ -153,6 +157,7 @@ void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, void CryptoPlayer::pass_around_no_stats(const octetStream& to_send, octetStream& to_receive, int offset) const { + assert(get_player(offset) != my_num()); if (&to_send == &to_receive) { MultiPlayer::pass_around_no_stats(to_send, to_receive, offset); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 61c8fd65..cd92df54 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -14,12 +14,14 @@ using namespace std; -void Names::init(int player,int pnb,int my_port,const char* servername) +void Names::init(int player, int pnb, int my_port, const char* servername, + bool setup_socket) { player_no=player; portnum_base=pnb; setup_names(servername, my_port); - setup_server(); + if (setup_socket) + setup_server(); } Names::Names(int player, int nplayers, const string& servername, int pnb, @@ -124,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port) my_port = default_port(player_no); int socket_num; - int pn = portnum_base - 1; + int pn = portnum_base; set_up_client_socket(socket_num, servername, pn); octetStream("P" + to_string(player_no)).Send(socket_num); #ifdef DEBUG_NETWORKING @@ -132,15 +134,11 @@ void Names::setup_names(const char *servername, int my_port) #endif // Send my name - octet my_name[512]; - memset(my_name,0,512*sizeof(octet)); sockaddr_in address; socklen_t size = sizeof address; getsockname(socket_num, (sockaddr*)&address, &size); - char* name = inet_ntoa(address.sin_addr); - // max length of IP address with ending 0 - strncpy((char*)my_name, name, 16); - send(socket_num,my_name,512); + char* my_name = inet_ntoa(address.sin_addr); + octetStream(my_name).Send(socket_num); send(socket_num,(octet*)&my_port,4); #ifdef DEBUG_NETWORKING fprintf(stderr, "My Name = %s\n",my_name); @@ -158,9 +156,10 @@ void Names::setup_names(const char *servername, int my_port) names.resize(nplayers); ports.resize(nplayers); for (i=0; iinit(); } +void Names::set_server(ServerSocket* socket) +{ + assert(not server); + server = socket; +} + Names::Names(const Names& other) { @@ -201,6 +206,7 @@ Player::Player(const Names& Nms) : { nplayers=Nms.nplayers; player_no=Nms.player_no; + thread_stats.resize(nplayers); } @@ -243,6 +249,10 @@ MultiPlayer::~MultiPlayer() Player::~Player() { +#ifdef VERBOSE + for (auto& x : thread_stats) + x.print(); +#endif } PlayerBase::~PlayerBase() @@ -685,7 +695,7 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const { TimeScope ts(comm_stats["Sending one-to-one"].add(o)); P.send_to_no_stats(other_player, o); - sent += o.get_length(); + comm_stats.sent += o.get_length(); } void RealTwoPartyPlayer::receive(octetStream& o) const @@ -729,12 +739,13 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const void VirtualTwoPartyPlayer::send_receive_player(vector& o) const { TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0])); - sent += o[0].get_length(); + comm_stats.sent += o[0].get_length(); P.exchange_no_stats(other_player, o[0], o[1]); } VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) : - TwoPartyPlayer(P.my_num()), P(P), other_player(other_player) + TwoPartyPlayer(P.my_num()), P(P), other_player(other_player), comm_stats( + P.thread_stats.at(other_player)) { } @@ -814,5 +825,13 @@ void NamedCommStats::print(bool newline) cerr << endl; } +NamedCommStats Player::total_comm() const +{ + auto res = comm_stats; + for (auto& x : thread_stats) + res += x; + return res; +} + template class MultiPlayer; template class MultiPlayer ; diff --git a/Networking/Player.h b/Networking/Player.h index 033aa3bd..9c90dbd1 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -35,6 +35,7 @@ class Names friend class Player; friend class PlainPlayer; friend class RealTwoPartyPlayer; + friend class Server; vector names; vector ports; @@ -51,6 +52,8 @@ class Names void setup_server(); + void set_server(ServerSocket* socket); + public: static const int DEFAULT_PORT = -1; @@ -62,8 +65,10 @@ class Names * @param my_port my port number (`DEFAULT_PORT` for default, * which is base port number plus player number) * @param servername location of server + * @param setup_socket whether to start listening */ - void init(int player,int pnb,int my_port,const char* servername); + void init(int player, int pnb, int my_port, const char* servername, + bool setup_socket = true); Names(int player,int pnb,int my_port,const char* servername) : Names() { init(player,pnb,my_port,servername); } @@ -172,11 +177,12 @@ class PlayerBase protected: int player_no; -public: size_t& sent; - mutable Timer timer; mutable NamedCommStats comm_stats; +public: + mutable Timer timer; + PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {} virtual ~PlayerBase(); @@ -205,6 +211,8 @@ protected: public: const Names& N; + mutable vector thread_stats; + Player(const Names& Nms); virtual ~Player(); @@ -358,6 +366,8 @@ public: virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; } virtual void wait_receive(int i, octetStream& o) const { receive_player(i, o); } + + NamedCommStats total_comm() const; }; /** @@ -500,6 +510,7 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer { Player& P; int other_player; + NamedCommStats& comm_stats; public: VirtualTwoPartyPlayer(Player& P, int other_player); diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index e93f47c4..7e8c93fe 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -51,9 +51,17 @@ void Receiver::run() while (in.pop(os)) { os->reset_write_head(); +#ifdef VERBOSE_SSL timer.start(); + RunningTimer mytimer; +#endif os->Receive(socket); +#ifdef VERBOSE_SSL + cout << "receiving " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 51d5f471..4e4b9881 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -47,9 +47,17 @@ void Sender::run() const octetStream* os = 0; while (in.pop(os)) { -// timer.start(); +#ifdef VERBOSE_SSL + timer.start(); + RunningTimer mytimer; +#endif os->Send(socket); -// timer.stop(); +#ifdef VERBOSE_SSL + cout << "sending " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; + timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Server.cpp b/Networking/Server.cpp index d9a056dd..facda0a2 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -28,9 +28,7 @@ void Server::get_ip(int num) inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr); } - names[num]=new octet[512]; - memset(names[num], 0, 512); - strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN); + names[num] = ipstr; #ifdef DEBUG_NETWORKING cerr << "Client IP address: " << names[num] << endl; @@ -45,11 +43,11 @@ void Server::get_name(int num) #endif // Receive name sent by client (legacy) - not used here - octet my_name[512]; - receive(socket_num[num],my_name,512); + octetStream os; + os.Receive(socket_num[num]); receive(socket_num[num],(octet*)&ports[num],4); #ifdef DEBUG_NETWORKING - cerr << "Player " << num << " sent (IP for info only) " << my_name << ":" + cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":" << ports[num] << endl; #endif @@ -66,7 +64,7 @@ void Server::send_names(int num) send(socket_num[num],nmachines,4); for (int i=0; i= 0); assert(my_num < nplayers); @@ -172,12 +175,19 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, { pthread_create(&thread, 0, Server::start_in_thread, server = new Server(nplayers, portnum)); - } - N.init(my_num, portnum, my_port, hostname.c_str()); - if (my_num == 0) - { + N.init(my_num, portnum, my_port, hostname.c_str(), false); pthread_join(thread, 0); + N.set_server(server->get_socket()); delete server; } + else + N.init(my_num, portnum, my_port, hostname.c_str()); return 0; } + +ServerSocket* Server::get_socket() +{ + auto res = server_socket; + server_socket = 0; + return res; +} diff --git a/Networking/Server.h b/Networking/Server.h index a5e833ad..ad6d5fd5 100644 --- a/Networking/Server.h +++ b/Networking/Server.h @@ -14,10 +14,11 @@ using namespace std; class Server { vector socket_num; - vector names; + vector names; vector ports; int nmachines; int PortnumBase; + ServerSocket* server_socket; void get_ip(int num); void get_name(int num); @@ -31,7 +32,11 @@ public: Server(int argc, char** argv); Server(int nmachines, int PortnumBase); + ~Server(); + void start(); + + ServerSocket* get_socket(); }; #endif /* NETWORKING_SERVER_H_ */ diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 8989a0a1..79cb3522 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -7,6 +7,7 @@ #define CRYPTO_SSL_SOCKETS_H_ #include "Tools/int.h" +#include "Tools/time-func.h" #include "sockets.h" #include "Math/Setup.h" @@ -46,6 +47,10 @@ public: string me, bool client) : parent(io_service, ctx) { +#ifdef DEBUG_NETWORKING + cerr << me << " setting up SSL to " << other << " as " << + (client ? "client" : "server") << endl; +#endif lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket); set_verify_mode(boost::asio::ssl::verify_peer); set_verify_callback(boost::asio::ssl::rfc2818_verification(other)); @@ -82,8 +87,16 @@ template<> inline void send(ssl_socket* socket, octet* data, size_t length) { size_t sent = 0; +#ifdef VERBOSE_SSL + RunningTimer timer; +#endif while (sent < length) + { sent += send_non_blocking(socket, data + sent, length - sent); +#ifdef VERBOSE_SSL + cout << "sent " << sent * 1e-6 << " MB at " << timer.elapsed() << endl; +#endif + } } template<> diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 8847728e..98856585 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -1,6 +1,7 @@ #include "OT/BaseOT.h" #include "Tools/random.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include #include @@ -78,6 +79,23 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol void BaseOT::exec_base(bool new_receiver_inputs) { + Bundle bundle(*P); +#ifdef NO_AVX_OT + bundle.mine = string("OT without AVX"); +#else + bundle.mine = string("OT with AVX"); +#endif + try + { + bundle.compare(*P); + } + catch (mismatch_among_parties&) + { + cerr << "Parties compiled with different base OT algorithms" << endl; + cerr << "Set \"AVX_OT\" to the same value on all parties" << endl; + exit(1); + } + #ifdef NO_AVX_OT #ifdef USE_RISTRETTO typedef CurveElement Element; diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index d5981e71..8a84ca0a 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -116,7 +116,7 @@ public: mac_key_type get_mac_key() const { return mac_key; } - NamedCommStats comm_stats(); + Player& get_player() { return globalPlayer; } }; template @@ -209,15 +209,4 @@ public: void generateTriples(); }; -template -NamedCommStats OTTripleGenerator::comm_stats() -{ - NamedCommStats res; - if (parentPlayer != &globalPlayer) - res = globalPlayer.comm_stats; - for (auto& player : players) - res += player->comm_stats; - return res; -} - #endif diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index bc36a860..019fc6f2 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -110,22 +110,31 @@ void BaseMachine::time() void BaseMachine::start(int n) { cout << "Starting timer " << n << " at " << timer[n].elapsed() + << " (" << timer[n].mb_sent() << " MB)" << " after " << timer[n].idle() << endl; - timer[n].start(); + timer[n].start(total_comm()); } void BaseMachine::stop(int n) { - timer[n].stop(); - cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl; + timer[n].stop(total_comm()); + cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" + << timer[n].mb_sent() << " MB)" << endl; } void BaseMachine::print_timers() { + cerr << "The following timing is "; + if (OnlineOptions::singleton.live_prep) + cerr << "in"; + else + cerr << "ex"; + cerr << "clusive preprocessing." << endl; cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; timer.erase(0); - for (map::iterator it = timer.begin(); it != timer.end(); it++) - cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + for (auto it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" + << it->second.mb_sent() << " MB)" << endl; } string BaseMachine::memory_filename(const string& type_short, int my_number) @@ -170,3 +179,18 @@ bigint BaseMachine::prime_from_schedule(string progname) else return 0; } + +NamedCommStats BaseMachine::total_comm() +{ + NamedCommStats res; + for (auto& queue : queues) + res += queue->get_comm_stats(); + return res; +} + +void BaseMachine::set_thread_comm(const NamedCommStats& stats) +{ + auto queue = queues.at(BaseMachine::thread_num); + assert(queue); + queue->set_comm_stats(stats); +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 0e08549e..035a0cfe 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -7,6 +7,7 @@ #define PROCESSOR_BASEMACHINE_H_ #include "Tools/time-func.h" +#include "Tools/TimerWithComm.h" #include "OT/OTTripleSetup.h" #include "ThreadJob.h" #include "ThreadQueues.h" @@ -22,7 +23,7 @@ class BaseMachine protected: static BaseMachine* singleton; - std::map timer; + std::map timer; string compiler; string domain; @@ -66,12 +67,18 @@ public: virtual void reqbl(int) {} - OTTripleSetup fresh_ot_setup(); + static OTTripleSetup fresh_ot_setup(Player& P); + + NamedCommStats total_comm(); + void set_thread_comm(const NamedCommStats& stats); }; -inline OTTripleSetup BaseMachine::fresh_ot_setup() +inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) { - return ot_setups.at(thread_num).get_fresh(); + if (singleton and size_t(thread_num) < s().ot_setups.size()) + return s().ot_setups.at(thread_num).get_fresh(); + else + return OTTripleSetup(P, true); } #endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index be1fb8fd..9878f4a6 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -38,7 +38,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, int size_in_bytes = T::size() * buffer.size(); int n_read = 0; - char * read_buffer = new char[size_in_bytes]; + char read_buffer[size_in_bytes]; inf.seekg(start_posn * T::size()); do { diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 8f44ed25..8d05747e 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -89,6 +89,7 @@ template class Processor; template class Data_Files; template class Machine; template class SubProcessor; +template class NoFilePrep; /** * Abstract base class for preprocessing @@ -125,6 +126,7 @@ public: template static Preprocessing* get_new(Machine& machine, DataPositions& usage, SubProcessor* proc); + template static Preprocessing* get_new(bool live_prep, const Names& N, DataPositions& usage); static Preprocessing* get_live_prep(SubProcessor* proc, @@ -133,22 +135,21 @@ public: Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {} virtual ~Preprocessing() {} - virtual void set_protocol(typename T::Protocol& protocol) = 0; + virtual void set_protocol(typename T::Protocol&) {}; virtual void set_proc(SubProcessor* proc) { (void) proc; } virtual void seekg(DataPositions& pos) { (void) pos; } virtual void prune() {} virtual void purge() {} - virtual size_t data_sent() { return comm_stats().sent; } - virtual NamedCommStats comm_stats() { return {}; } - - virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0; - virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0; - virtual void get_one_no_count(Dtype dtype, T& a) = 0; - virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0; - virtual void get_no_count(vector& S, DataTag tag, const vector& regs, - int vector_size) = 0; + virtual void get_three_no_count(Dtype, T&, T&, T&) + { throw not_implemented(); } + virtual void get_two_no_count(Dtype, T&, T&) { throw not_implemented(); } + virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); } + virtual void get_input_no_count(T&, typename T::open_type&, int) + { throw not_implemented() ; } + virtual void get_no_count(vector&, DataTag, const vector&, int) + { throw not_implemented(); } void get(Dtype dtype, T* a); void get_three(Dtype dtype, T& a, T& b, T& c); @@ -191,6 +192,9 @@ class Sub_Data_Files : public Preprocessing { template friend class Sub_Data_Files; + typedef typename conditional, NoFilePrep>::type part_type; + static int tuple_length(int dtype); BufferOwner buffers[N_DTYPE]; @@ -205,7 +209,7 @@ class Sub_Data_Files : public Preprocessing const string prep_data_dir; int thread_num; - Sub_Data_Files* part; + part_type* part; void buffer_edabits_with_queues(bool strict, int n_bits) { buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); } @@ -274,7 +278,7 @@ public: void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); void get_dabit_no_count(T& a, typename T::bit_type& b); - Preprocessing& get_part(); + part_type& get_part(); }; template @@ -307,8 +311,6 @@ class Data_Files } void reset_usage() { usage.reset(); skipped.reset(); } - - NamedCommStats comm_stats(); }; template inline @@ -418,6 +420,7 @@ T Preprocessing::get_bit() template T Preprocessing::get_random() { + assert(not usage.inputs.empty()); return get_random_from_inputs(usage.inputs.size()); } @@ -429,10 +432,4 @@ inline void Data_Files::purge() DataFb.purge(); } -template -NamedCommStats Data_Files::comm_stats() -{ - return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats(); -} - #endif diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 3635dc0a..359ff620 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -3,6 +3,7 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" +#include "Processor/NoFilePrep.h" #include "Protocols/dabit.h" #include "Math/Setup.h" #include "GC/BitPrepFiles.h" @@ -30,6 +31,7 @@ Preprocessing* Preprocessing::get_new( } template +template Preprocessing* Preprocessing::get_new( bool live_prep, const Names& N, DataPositions& usage) @@ -156,17 +158,7 @@ Data_Files::Data_Files(const Names& N) : template Data_Files::~Data_Files() { -#ifdef VERBOSE - if (DataFp.data_sent()) - cerr << "Sent for " << sint::type_string() << " preprocessing threads: " << - DataFp.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataFp; -#ifdef VERBOSE - if (DataF2.data_sent()) - cerr << "Sent for " << sgf2n::type_string() << " preprocessing threads: " << - DataF2.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataF2; delete &DataFb; } @@ -264,6 +256,8 @@ void Sub_Data_Files::purge() for (auto it : extended) it.second.purge(); dabit_buffer.purge(); + if (part != 0) + part->purge(); } template @@ -329,10 +323,10 @@ void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, } template -Preprocessing& Sub_Data_Files::get_part() +typename Sub_Data_Files::part_type& Sub_Data_Files::get_part() { if (part == 0) - part = new Sub_Data_Files(my_num, num_players, + part = new part_type(my_num, num_players, get_prep_sub_dir(num_players), this->usage, thread_num); return *part; diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 95bcd029..b3ed5bc5 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -87,10 +87,10 @@ public: { } - void init_mul(SubProcessor* = 0) + void init_mul() { } - typename T::clear prepare_mul(const T&, const T&, int = 0) + void prepare_mul(const T&, const T&, int = 0) { throw not_implemented(); } diff --git a/Processor/FieldMachine.h b/Processor/FieldMachine.h index c544fb96..859c64a1 100644 --- a/Processor/FieldMachine.h +++ b/Processor/FieldMachine.h @@ -9,6 +9,9 @@ #include "RingMachine.h" #include "HonestMajorityMachine.h" #include "Tools/ezOptionParser.h" +#include "Math/gfp.h" + +#include "OnlineOptions.hpp" template class U, class V = HonestMajorityMachine> class HonestMajorityFieldMachine @@ -36,7 +39,7 @@ public: ez::ezOptionParser& opt, bool live_prep_default = true) { OnlineOptions& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, 1000, live_prep_default, true}; + online_opts = {opt, argc, argv, T(), live_prep_default}; FieldMachine(argc, argv, opt, online_opts); } diff --git a/Processor/FieldMachine.hpp b/Processor/FieldMachine.hpp index f93517d9..89ec66e1 100644 --- a/Processor/FieldMachine.hpp +++ b/Processor/FieldMachine.hpp @@ -10,6 +10,7 @@ #include "HonestMajorityMachine.h" #include "Math/gfp.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class T, class V> @@ -24,7 +25,7 @@ template class T, class V> HonestMajorityFieldMachine::HonestMajorityFieldMachine(int argc, const char **argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv, 0, true, true); + OnlineOptions online_opts(opt, argc, argv, T()); FieldMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/HonestMajorityMachine.cpp b/Processor/HonestMajorityMachine.cpp index 3a756bc8..295ef5fa 100644 --- a/Processor/HonestMajorityMachine.cpp +++ b/Processor/HonestMajorityMachine.cpp @@ -18,7 +18,6 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) : OnlineMachine(argc, argv, opt, online_opts, nplayers) { - OnlineOptions::singleton = online_opts; opt.add( "", // Default. 0, // Required? @@ -29,6 +28,7 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, "--unencrypted" // Flag token. ); online_opts.finalize(opt, argc, argv); + OnlineOptions::singleton = online_opts; use_encryption = not opt.get("-u")->isSet; diff --git a/Processor/Input.h b/Processor/Input.h index 9816c357..98c6c83b 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -14,6 +14,8 @@ using namespace std; #include "Tools/PointerVector.h" class ArithmeticProcessor; +template class SubProcessor; +template class Preprocessing; /** * Abstract base for input protocols @@ -25,6 +27,7 @@ class InputBase protected: Player* P; + int my_num; Buffer buffer; Timer timer; @@ -58,7 +61,7 @@ public: /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input); + void add_from_all(const clear& input, int n_bits = -1); /// Send my inputs virtual void send_mine() = 0; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 9272535b..b9f7a77a 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -19,6 +19,7 @@ template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) { + my_num = -1; if (proc) buffer.setup(&proc->private_input, -1, proc->private_input_filename); } @@ -83,6 +84,7 @@ template void InputBase::reset_all(Player& P) { this->P = &P; + my_num = P.my_num(); os.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) reset(i); @@ -111,13 +113,13 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input) +void InputBase::add_from_all(const clear& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) - add_mine(input); + add_mine(input, n_bits); else - add_other(i); + add_other(i, n_bits); } template @@ -202,7 +204,7 @@ void Input::finalize_other(int player, T& target, template T InputBase::finalize(int player, int n_bits) { - if (player == P->my_num()) + if (player == my_num) return finalize_mine(); else { diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e516fdf3..e45a8504 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -1091,9 +1091,11 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.time(); break; case START: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.start(n); break; case STOP: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.stop(n); break; case RUN_TAPE: diff --git a/Processor/Machine.h b/Processor/Machine.h index 3f23dc9f..331a9a22 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -69,7 +69,6 @@ class Machine : public BaseMachine OnlineOptions opts; - NamedCommStats comm_stats; ExecutionStats stats; Machine(int my_number, Names& playerNames, const string& progname, diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 804dc51a..d7d1a3ec 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -142,6 +142,8 @@ template Machine::~Machine() { delete P; + for (auto& queue : queues) + delete queue; } template @@ -324,7 +326,7 @@ void Machine::run() { Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); proc_timer.start(); - timer[0].start(); + timer[0].start({}); // run main tape run_tape(0, 0, 0, N.num_players()); @@ -352,7 +354,6 @@ void Machine::run() queues[i]->schedule({}); pos.increase(queues[i]->result().pos); pthread_join(threads[i],NULL); - delete queues[i]; } finish_timer.stop(); @@ -372,6 +373,8 @@ void Machine::run() cerr << "Finish timer: " << finish_timer.elapsed() << endl; #endif + NamedCommStats comm_stats = total_comm(); + if (opts.verbose) { cerr << "Communication details " @@ -457,9 +460,12 @@ void Machine::run() } #ifndef INSECURE - Data_Files df(*this); - df.seekg(pos); - df.prune(); + if (not opts.file_prep_per_thread) + { + Data_Files df(*this); + df.seekg(pos); + df.prune(); + } #endif sint::LivePrep::teardown(); diff --git a/Processor/Memory.h b/Processor/Memory.h index 2c4a3d2e..9ec02d2b 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -43,8 +43,11 @@ class Memory template static void check_index(const vector& M, size_t i) { + (void) M, (void) i; +#ifdef NO_CHECK_INDEX if (i >= M.size()) throw overflow("memory", i, M.size()); +#endif } const typename T::clear& read_C(size_t i) const diff --git a/Processor/NoFilePrep.h b/Processor/NoFilePrep.h new file mode 100644 index 00000000..fbb44912 --- /dev/null +++ b/Processor/NoFilePrep.h @@ -0,0 +1,22 @@ +/* + * NoFilePrep.h + * + */ + +#ifndef PROCESSOR_NOFILEPREP_H_ +#define PROCESSOR_NOFILEPREP_H_ + +#include "Data_Files.h" + +template +class NoFilePrep : public Preprocessing +{ +public: + NoFilePrep(int, int, const string&, DataPositions& usage, int = -1) : + Preprocessing(usage) + { + throw runtime_error("don't call this"); + } +}; + +#endif /* PROCESSOR_NOFILEPREP_H_ */ diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index cffaded4..dcfafe55 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -71,7 +71,7 @@ void OfflineMachine::generate() auto my_usage = domain_usage[i]; Dtype dtype = Dtype(i); string filename = Sub_Data_Files::get_filename(playerNames, dtype, - T::clear::field_type() == DATA_GF2 ? 0 : -1); + 0); if (my_usage > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -106,7 +106,7 @@ void OfflineMachine::generate() for (int i = 0; i < P.num_players(); i++) { auto n_inputs = usage.inputs[i][T::clear::field_type()]; - string filename = Sub_Data_Files::get_input_filename(playerNames, i); + string filename = Sub_Data_Files::get_input_filename(playerNames, i, 0); if (n_inputs > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -137,7 +137,7 @@ void OfflineMachine::generate() int total = usage.edabits[{false, n_bits}] + usage.edabits[{true, n_bits}]; string filename = Sub_Data_Files::get_edabit_filename(playerNames, - n_bits); + n_bits, 0); if (total > 0) { ofstream out(filename, ios::binary); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index cb25b426..e98f1a3a 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -279,7 +279,7 @@ void thread_info::Sub_Main_Func() printf("\tSignalling I have finished\n"); #endif wait_timer.start(); - queues->finished(job); + queues->finished(job, P.total_comm()); wait_timer.stop(); } } @@ -287,6 +287,11 @@ void thread_info::Sub_Main_Func() // final check Proc.check(); +#ifndef INSECURE + if (machine.opts.file_prep_per_thread) + Proc.DataF.prune(); +#endif + wait_timer.start(); queues->next(); wait_timer.stop(); @@ -314,16 +319,10 @@ void thread_info::Sub_Main_Func() #endif // wind down thread by thread - auto prep_stats = Proc.DataF.comm_stats(); - prep_stats += Proc.share_thread.DataF.comm_stats(); - prep_stats += Proc.Procp.bit_prep.comm_stats(); - for (auto& x : Proc.Procp.personal_bit_preps) - prep_stats += x->comm_stats(); machine.stats += Proc.stats; delete processor; - machine.comm_stats += P.comm_stats + prep_stats; - queues->finished(actual_usage); + queues->finished(actual_usage, P.total_comm()); delete MC2; delete MCp; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 41308603..2a5e090b 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -29,6 +29,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; + trunc_error = 40; #ifdef VERBOSE verbose = true; #else @@ -326,6 +327,19 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, #endif lgp = max(lgp, gfp0::MAX_N_BITS); } + + set_trunc_error(opt); +} + +void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) +{ + if (opt.get("-E")) + { + opt.get("-E")->getInt(trunc_error); +#ifdef VERBOSE + cerr << "Truncation error probability 2^-" << trunc_error << endl; +#endif + } } int OnlineOptions::prime_length() diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index de8f1e72..4b2fe4f8 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -30,6 +30,7 @@ public: std::string cmd_private_output_file; bool verbose; bool file_prep_per_thread; + int trunc_error; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -37,10 +38,15 @@ public: OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, bool variable_prime_length = false); + template + OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, + bool default_live_prep = true); ~OnlineOptions() {} void finalize(ez::ezOptionParser& opt, int argc, const char** argv); + void set_trunc_error(ez::ezOptionParser& opt); + int prime_length(); int prime_limbs(); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp new file mode 100644 index 00000000..8961853e --- /dev/null +++ b/Processor/OnlineOptions.hpp @@ -0,0 +1,30 @@ +/* + * OnlineOptions.hpp + * + */ + +#ifndef PROCESSOR_ONLINEOPTIONS_HPP_ +#define PROCESSOR_ONLINEOPTIONS_HPP_ + +#include "OnlineOptions.h" + +template +OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, + const char** argv, T, bool default_live_prep) : + OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0, + default_live_prep, T::clear::prime_field) +{ + if (T::has_trunc_pr) + opt.add( + to_string(trunc_error).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Probabilistic truncation error " + "(2^-x, default: 40)", // Help description. + "-E", // Flag token. + "--trunc-error" // Flag token. + ); +} + +#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 5c44b908..4ca77daa 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -40,21 +40,33 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, + to_string(my_num) + get_suffix(thread_num); } -void PrepBase::print_left(const char* name, size_t n, const string& type_string) +void PrepBase::print_left(const char* name, size_t n, const string& type_string, + size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) cerr << "\t" << n << " " << name << " of " << type_string << " left" << endl; + + if (n > used / 10) + cerr << "Significant amount of unused " << name << " of " << type_string + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b." << endl; } void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits) + int n_bits, size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) { cerr << "\t~" << n * n_batch; if (not strict) cerr << " loose"; cerr << " edaBits of size " << n_bits << " left" << endl; } + + if (n > used / 10) + cerr << "Significant amount of unused edaBits of size " << n_bits + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b " + << "or increasing the bucket size with -B." << endl; } diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index bedba629..ccc2f4b4 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -24,8 +24,10 @@ public: static string get_edabit_filename(const string& prep_data_dir, int n_bits, int my_num, int thread_num = 0); - static void print_left(const char* name, size_t n, const string& type_string); - static void print_left_edabits(size_t n, size_t n_batch, bool strict, int n_bits); + static void print_left(const char* name, size_t n, + const string& type_string, size_t used); + static void print_left_edabits(size_t n, size_t n_batch, bool strict, + int n_bits, size_t used); }; #endif /* PROCESSOR_PREPBASE_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index d9141855..a78058cd 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -243,10 +243,6 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); - // Print the processor state - template - friend ostream& operator<<(ostream& s,const Processor& P); - private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 6206e27c..caea1e67 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -28,8 +28,8 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, bit_prep(bit_usage) { DataF.set_proc(this); + protocol.init(DataF, MC); DataF.set_protocol(protocol); - protocol.init_mul(this); bit_usage.set_num_players(P.num_players()); personal_bit_preps.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) @@ -39,22 +39,12 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, template SubProcessor::~SubProcessor() { - protocol.check(); - for (size_t i = 0; i < personal_bit_preps.size(); i++) { auto& x = personal_bit_preps[i]; -#ifdef VERBOSE - if (x->data_sent()) - cerr << "Sent for personal bit preprocessing threads of player " << i << ": " << - x->data_sent() * 1e-6 << " MB" << endl; -#endif delete x; } #ifdef VERBOSE - if (bit_prep.data_sent()) - cerr << "Sent for global bit preprocessing threads: " << - bit_prep.data_sent() * 1e-6 << " MB" << endl; if (not bit_usage.empty()) { cerr << "Mixed-circuit preprocessing cost:" << endl; @@ -423,7 +413,7 @@ void SubProcessor::muls(const vector& reg, int size) int n = reg.size() / 3; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < size; j++) { @@ -448,7 +438,7 @@ void SubProcessor::mulrs(const vector& reg) int n = reg.size() / 4; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < reg[4 * i]; j++) { @@ -470,7 +460,7 @@ void SubProcessor::mulrs(const vector& reg) template void SubProcessor::dotprods(const vector& reg, int size) { - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < size; i++) { auto it = reg.begin(); @@ -512,7 +502,7 @@ void SubProcessor::matmuls(const vector& source, assert(B + dim[1] * dim[2] <= source.end()); assert(C + dim[0] * dim[2] <= S.end()); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) for (int j = 0; j < dim[2]; j++) { @@ -536,7 +526,7 @@ void SubProcessor::matmulsm(const CheckVector& source, assert(C + dim[0] * dim[2] <= S.end()); assert(Proc); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) { auto ii = Proc->get_Ci().at(dim[3] + i); @@ -562,7 +552,7 @@ void SubProcessor::matmulsm(const CheckVector& source, template void SubProcessor::conv2ds(const Instruction& instruction) { - protocol.init_dotprod(this); + protocol.init_dotprod(); auto& args = instruction.get_start(); int output_h = args[0], output_w = args[1]; int inputs_h = args[2], inputs_w = args[3]; @@ -670,30 +660,4 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } -template -ostream& operator<<(ostream& s,const Processor& P) -{ - s << "Processor State" << endl; - s << "Char 2 Registers" << endl; - s << "Val\tClearReg\tSharedReg" << endl; - for (int i=0; i(), live_prep_default}; RingMachine(argc, argv, opt, online_opts); } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index add3f43c..e422e0aa 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -12,6 +12,7 @@ #include "Tools/ezOptionParser.h" #include "Math/gf2n.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class U, template class V> @@ -25,7 +26,7 @@ template class U, template class V> HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char** argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv); + OnlineOptions online_opts(opt, argc, argv, U<64>()); RingMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/ThreadQueue.cpp b/Processor/ThreadQueue.cpp index 3f5b1c76..6358e4a4 100644 --- a/Processor/ThreadQueue.cpp +++ b/Processor/ThreadQueue.cpp @@ -27,6 +27,19 @@ void ThreadQueue::finished(const ThreadJob& job) out.push(job); } +void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats) +{ + finished(job); + set_comm_stats(new_comm_stats); +} + +void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats) +{ + lock.lock(); + comm_stats = new_comm_stats; + lock.unlock(); +} + ThreadJob ThreadQueue::result() { auto res = out.pop(); @@ -38,3 +51,11 @@ ThreadJob ThreadQueue::result() lock.unlock(); return res; } + +NamedCommStats ThreadQueue::get_comm_stats() +{ + lock.lock(); + auto res = comm_stats; + lock.unlock(); + return res; +} diff --git a/Processor/ThreadQueue.h b/Processor/ThreadQueue.h index 2e994b3a..f49722ab 100644 --- a/Processor/ThreadQueue.h +++ b/Processor/ThreadQueue.h @@ -13,6 +13,7 @@ class ThreadQueue WaitQueue in, out; Lock lock; int left; + NamedCommStats comm_stats; public: ThreadQueue() : @@ -28,7 +29,11 @@ public: void schedule(const ThreadJob& job); ThreadJob next(); void finished(const ThreadJob& job); + void finished(const ThreadJob& job, const NamedCommStats& comm_stats); ThreadJob result(); + + void set_comm_stats(const NamedCommStats& new_comm_stats); + NamedCommStats get_comm_stats(); }; #endif /* PROCESSOR_THREADQUEUE_H_ */ diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h index 06a96845..267acae4 100644 --- a/Processor/TruncPrTuple.h +++ b/Processor/TruncPrTuple.h @@ -10,26 +10,35 @@ #include using namespace std; +#include "OnlineOptions.h" + template class TruncPrTuple { public: + const static int n = 4; + int dest_base; int source_base; int k; int m; int n_shift; - TruncPrTuple(const vector& regs, size_t base) + TruncPrTuple(const vector& regs, size_t base) : + TruncPrTuple(regs.begin() + base) { - dest_base = regs[base]; - source_base = regs[base + 1]; - k = regs[base + 2]; - m = regs[base + 3]; + } + + TruncPrTuple(vector::const_iterator it) + { + dest_base = *it++; + source_base = *it++; + k = *it++; + m = *it++; n_shift = T::N_BITS - 1 - k; assert(m < k); assert(0 < k); - assert(m < T::N_BITS); + assert(m < T::n_bits()); } T upper(T mask) @@ -49,10 +58,17 @@ class TruncPrTupleWithGap : public TruncPrTuple { public: TruncPrTupleWithGap(const vector& regs, size_t base) : - TruncPrTuple(regs, base) + TruncPrTupleWithGap(regs.begin() + base) { } + TruncPrTupleWithGap(vector::const_iterator it) : + TruncPrTuple(it) + { + if (T::prime_field and small_gap()) + throw runtime_error("domain too small for chosen truncation error"); + } + T upper(T mask) { if (big_gap()) @@ -69,7 +85,12 @@ public: bool big_gap() { - return this->k <= T::N_BITS - 40; + return this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error; + } + + bool small_gap() + { + return not big_gap(); } }; diff --git a/Programs/Source/keras_mnist_lenet_predict.mpc b/Programs/Source/keras_mnist_lenet_predict.mpc new file mode 100644 index 00000000..8b55de56 --- /dev/null +++ b/Programs/Source/keras_mnist_lenet_predict.mpc @@ -0,0 +1,44 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +# training_samples = MultiArray([60000, 28, 28], sfix) +# training_labels = MultiArray([60000, 10], sint) + +test_samples = MultiArray([1, 28, 28], sfix) +test_labels = MultiArray([1, 10], sint) + +# training_labels.input_from(0) +# training_samples.input_from(0) + +# test_labels.input_from(0) +# test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +model.build(test_samples.sizes) + +start = 0 +for var in model.trainable_variables: + var.assign_all(0) +# start = var.read_from_file(start) + +guesses = model.predict(test_samples, batch_size=1) + +print_ln('guess %s', guesses.reveal_nested()[:3]) +print_ln('truth %s', test_labels.reveal_nested()[:3]) diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 3dd34d17..c99d911a 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -53,18 +53,13 @@ public: return shamir.get_n_relevant_players(); } - void init_mul(Preprocessing&, typename T::MAC_Check&) - { - init_mul(); - } - - void init_mul(SubProcessor* proc = 0); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void prepare(const typename T::open_type& product); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index bb6f18bf..c3a919b3 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -38,7 +38,7 @@ array Atlas::get_double_sharing() } template -void Atlas::init_mul(SubProcessor*) +void Atlas::init_mul() { oss.reset(); oss2.reset(); @@ -47,10 +47,9 @@ void Atlas::init_mul(SubProcessor*) } template -typename T::clear Atlas::prepare_mul(const T& x, const T& y, int) +void Atlas::prepare_mul(const T& x, const T& y, int) { prepare(x * y); - return {}; } template @@ -98,9 +97,9 @@ T Atlas::finalize_mul(int) } template -void Atlas::init_dotprod(SubProcessor* proc) +void Atlas::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index e0c24e49..2d28127c 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -38,14 +38,17 @@ public: Beaver(Player& P) : prep(0), MC(0), P(P) {} - Player& branch(); + typename T::Protocol branch(); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); + void check(); + void start_exchange(); void stop_exchange(); diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 63993005..dc981487 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -13,30 +13,34 @@ #include template -Player& Beaver::branch() +typename T::Protocol Beaver::branch() { - return P; + typename T::Protocol res(P); + res.prep = prep; + res.MC = MC; + res.init_mul(); + return res; } template -void Beaver::init_mul(SubProcessor* proc) -{ - assert(proc != 0); - init_mul(proc->DataF, proc->MC); -} - -template -void Beaver::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) +void Beaver::init(Preprocessing& prep, typename T::MAC_Check& MC) { this->prep = &prep; this->MC = &MC; +} + +template +void Beaver::init_mul() +{ + assert(this->prep); + assert(this->MC); shares.clear(); opened.clear(); triples.clear(); } template -typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) +void Beaver::prepare_mul(const T& x, const T& y, int n) { (void) n; triples.push_back({{}}); @@ -44,7 +48,6 @@ typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); - return 0; } template @@ -86,4 +89,11 @@ T Beaver::finalize_mul(int n) return tmp; } +template +void Beaver::check() +{ + assert(MC); + MC->Check(P); +} + #endif diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index 301ed9b0..77f2e35f 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -38,6 +38,8 @@ public: const static int N_MASK_BITS = clear::N_BITS + S; const static int Z_BITS = 2 * (N_MASK_BITS) + 5 + S; + static const bool has_trunc_pr = false; + BrainShare() { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 85378245..fb55f0cf 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -9,6 +9,7 @@ #include "Replicated.h" #include "Math/Z2k.h" #include "Processor/Instruction.h" +#include "Processor/TruncPrTuple.h" #include @@ -75,15 +76,14 @@ public: return P; } - void init_mul(SubProcessor*) + void init_mul() { results.clear(); } - typename T::clear prepare_mul(const T& x, const T& y, int = -1) + void prepare_mul(const T& x, const T& y, int = -1) { results.push_back(x * y); - return {}; } void exchange() @@ -95,9 +95,9 @@ public: return results.next(); } - void init_dotprod(SubProcessor* proc) + void init_dotprod() { - init_mul(proc); + init_mul(); dot_prod = {}; } @@ -177,19 +177,22 @@ public: res += overflow; } #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - T r; - r.randomize(G); + if (TruncPrTupleWithGap(regs, i).big_gap()) + { + T r; + r.randomize(G); - if (source.negative()) - res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + if (source.negative()) + res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + else + res = ((source + r) >> n_shift) - (r >> n_shift); + } else - res = ((source + r) >> n_shift) - (r >> n_shift); -#else - T r; - r.randomize_part(G, n_shift); - res = (source + r) >> n_shift; -#endif + { + T r; + r.randomize_part(G, n_shift); + res = (source + r) >> n_shift; + } #endif } } diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index f36a7b75..569c136e 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -32,6 +32,9 @@ public: typedef GC::FakeSecret bit_type; + static const bool has_trunc_pr = true; + static const bool dishonest_majority = false; + static string type_short() { return "emul"; diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 1e802146..8a00c793 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -6,14 +6,14 @@ #ifndef PROTOCOLS_HEMI_H_ #define PROTOCOLS_HEMI_H_ -#include "SPDZ.h" +#include "Semi.h" #include "HemiMatrixPrep.h" /** * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public SPDZ +class Hemi : public Semi { map, HemiMatrixPrep*> matrix_preps; @@ -22,7 +22,7 @@ class Hemi : public SPDZ public: Hemi(Player& P) : - SPDZ(P) + Semi(P) { } ~Hemi(); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index dc285c14..e67b28a9 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -51,19 +51,20 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { - auto ii = Proc->get_Ci().at(dim[3] + i); + for (int i = 0; i < dim[0]; i++) + { + auto kk = Proc->get_Ci().at(dim[4] + k); + auto ii = Proc->get_Ci().at(dim[3] + i); + A[{i, k}] = source.at(a + ii * dim[7] + kk); + } + for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ll = Proc->get_Ci().at(dim[5] + k); - A[{i, k}] = source.at(a + ii * dim[7] + kk); - B[{k, j}] = source.at(b + ll * dim[8] + jj); - } + auto ll = Proc->get_Ci().at(dim[5] + k); + B[{k, j}] = source.at(b + ll * dim[8] + jj); } } @@ -93,7 +94,8 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); MatrixMC mc; - beaver.init_mul(prep, mc); + beaver.init(prep, mc); + beaver.init_mul(); beaver.prepare_mul(A.from(0, i, subdim.data()), B.from(i, j, subdim.data() + 1)); beaver.exchange(); diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp index 2618feba..1c8f9f74 100644 --- a/Protocols/HighGearKeyGen.cpp +++ b/Protocols/HighGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PartSetup::key_and_mac_generation(Player& P, MachineBase& machine, int, false_type) { - HighGearKeyGen<2, 2>(P, params).run(*this, machine); + HighGearKeyGen<0, 0>(P, params).run(*this, machine); } diff --git a/Protocols/LowGearKeyGen.cpp b/Protocols/LowGearKeyGen.cpp index 2b149bc0..61829b36 100644 --- a/Protocols/LowGearKeyGen.cpp +++ b/Protocols/LowGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PairwiseSetup::key_and_mac_generation(Player& P, PairwiseMachine& machine, int, false_type) { - LowGearKeyGen<2>(P, machine, params).run(*this); + LowGearKeyGen<0>(P, machine, params).run(*this); } diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index a5982040..9ff92fb0 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -126,7 +126,7 @@ typename KeyGenProtocol::vector_type KeyGenProtocol::schur_product( vector_type res; assert(x.size() == y.size()); auto& protocol = proc->protocol; - protocol.init_mul(proc); + protocol.init_mul(); for (size_t i = 0; i < x.size(); i++) protocol.prepare_mul(x[i], y[i]); protocol.exchange(); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index db3f8dc7..85e9c84a 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -50,11 +50,13 @@ Tree_MAC_Check::Tree_MAC_Check(const typename U::mac_key_type::Scalar& ai, in template Tree_MAC_Check::~Tree_MAC_Check() { +#ifndef NO_SECURITY_CHECK if (WaitingForCheck() > 0) { cerr << endl << "SECURITY BUG: insufficient checking" << endl; terminate(); } +#endif } template diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index ce34b64e..96f2c813 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -121,21 +121,6 @@ void shuffle_triple_generation(vector>& triples, Player& P, #endif } -template -void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) -{ - int buffer_size = check_triples.size(); - - // shuffle - GlobalPRNG G(P); - for (int i = 0; i < buffer_size; i++) - { - int remaining = buffer_size - i; - int pos = G.get_uint(remaining); - swap(check_triples[i], check_triples[i + pos]); - } -} - template TripleShuffleSacrifice::TripleShuffleSacrifice() { @@ -251,32 +236,6 @@ void RingOnlyBitsFromSquaresPrep::buffer_bits() bits_from_square_in_ring(this->bits, this->buffer_size, &prep); } -template -void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, - ThreadQueues* queues) -{ - RunningTimer timer; -#ifndef NONPERSONAL_EDA - this->buffer_edabits_from_personal(strict, n_bits, queues); -#else - assert(this->proc != 0); - ShuffleSacrifice shuffle_sacrifice; - typedef typename T::bit_type::part_type bit_type; - vector> bits; - vector sums; - this->buffer_edabits_without_check(n_bits, sums, bits, - shuffle_sacrifice.minimum_n_inputs(), queues); - vector>& checked = this->edabits[{strict, n_bits}]; - shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, - n_bits, *this->proc, strict, -1, queues); - if (strict) - this->sanitize(checked, n_bits, -1, queues); -#endif -#ifdef VERBOSE_EDA - cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; -#endif -} - template void MalRepRingPrep::buffer_inputs(int player) { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 1967994d..f98e9797 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,6 +42,7 @@ public: typedef GC::MaliciousRepSecret bit_type; const static bool expensive = true; + static const bool has_trunc_pr = false; static string type_short() { diff --git a/Protocols/MaliciousRepPO.h b/Protocols/MaliciousRepPO.h index 62d4b178..7b58970f 100644 --- a/Protocols/MaliciousRepPO.h +++ b/Protocols/MaliciousRepPO.h @@ -11,17 +11,21 @@ template class MaliciousRepPO { +protected: Player& P; octetStream to_send; octetStream to_receive[2]; + PointerVector secrets; public: MaliciousRepPO(Player& P); + virtual ~MaliciousRepPO() {} void prepare_sending(const T& secret, int player); - void send(int player); - void receive(); + virtual void send(int player); + virtual void receive(); typename T::clear finalize(const T& secret); + typename T::clear finalize(); }; #endif /* PROTOCOLS_MALICIOUSREPPO_H_ */ diff --git a/Protocols/MaliciousRepPO.hpp b/Protocols/MaliciousRepPO.hpp index 38a3a274..bae23564 100644 --- a/Protocols/MaliciousRepPO.hpp +++ b/Protocols/MaliciousRepPO.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPPO_HPP_ +#define PROTOCOLS_MALICIOUSREPPO_HPP_ + #include "MaliciousRepPO.h" #include @@ -16,7 +19,10 @@ MaliciousRepPO::MaliciousRepPO(Player& P) : P(P) template void MaliciousRepPO::prepare_sending(const T& secret, int player) { - secret[2 - P.get_offset(player)].pack(to_send); + if (player == P.my_num()) + secrets.push_back(secret); + else + secret[2 - P.get_offset(player)].pack(to_send); } template @@ -24,7 +30,7 @@ void MaliciousRepPO::send(int player) { if (P.get_offset(player) == 2) P.send_to(player, to_send); - else + else if (P.my_num() != player) P.send_to(player, to_send.hash()); } @@ -42,3 +48,11 @@ typename T::clear MaliciousRepPO::finalize(const T& secret) { return secret.sum() + to_receive[0].template get(); } + +template +typename T::clear MaliciousRepPO::finalize() +{ + return finalize(secrets.next()); +} + +#endif diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 4be3fc63..8ffbff7b 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -61,8 +61,9 @@ void MaliciousBitOnlyRepPrep::set_protocol(typename T::Protocol& protocol) template void MaliciousBitOnlyRepPrep::init_honest(Player& P) { - honest_proc = new SubProcessor(honest_mc, honest_prep, - P); + if (not honest_proc) + honest_proc = new SubProcessor(honest_mc, + honest_prep, P); } template diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index ef61ec7b..c9eb63cf 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -6,6 +6,7 @@ #include "MamaPrep.h" #include "SemiMC.hpp" +#include "MalRepRingPrep.hpp" template MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 734453d3..5cfa82b8 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -21,8 +21,6 @@ public: ~OTPrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; /** diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index cef603a2..1393bb46 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -40,8 +40,9 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) // make sure not to use Montgomery multiplication T::open_type::next::template init(false); + assert(not triple_generator); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(proc->P), proc->P.N, -1, OnlineOptions::singleton.batch_size, 1, params, proc->MC.get_alphai(), &proc->P); @@ -121,13 +122,4 @@ T Preprocessing::get_random_from_inputs(int nplayers) return res; } -template -NamedCommStats OTPrep::comm_stats() -{ - auto res = BitPrep::comm_stats(); - if (triple_generator) - res += triple_generator->comm_stats(); - return res; -} - #endif diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h index b99ce4e3..d8259eb0 100644 --- a/Protocols/NoProtocol.h +++ b/Protocols/NoProtocol.h @@ -45,12 +45,12 @@ public: } // prepare next round of multiplications - void init_mul(SubProcessor*) + void init_mul() { } // schedule multiplication - typename T::clear prepare_mul(const T&, const T&, int = -1) + void prepare_mul(const T&, const T&, int = -1) { throw runtime_error("no multiplication preparation"); } diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 70371744..d4f2ab0f 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -22,6 +22,8 @@ public: static const int BIT_LENGTH = K; static const int SECURITY = S; + static const bool has_trunc_pr = false; + typedef SignedZ2 clear; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index 73ec766e..54b178a7 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -30,8 +30,8 @@ public: Player& branch(); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange() { internal.exchange(); } T finalize_mul(int n = -1); diff --git a/Protocols/PostSacrifice.hpp b/Protocols/PostSacrifice.hpp index 4db3b73b..0f72f4e8 100644 --- a/Protocols/PostSacrifice.hpp +++ b/Protocols/PostSacrifice.hpp @@ -25,9 +25,8 @@ Player& PostSacrifice::branch() } template -void PostSacrifice::init_mul(SubProcessor* proc) +void PostSacrifice::init_mul() { - (void) proc; // throw away unused operands operands.resize(results.size()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) @@ -36,11 +35,11 @@ void PostSacrifice::init_mul(SubProcessor* proc) } template -typename T::clear PostSacrifice::prepare_mul(const T& x, const T& y, int n) +void PostSacrifice::prepare_mul(const T& x, const T& y, int n) { (void) n; operands.push_back({{x, y}}); - return internal.prepare_mul(x, y); + internal.prepare_mul(x, y); } template diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h new file mode 100644 index 00000000..e6a8eb52 --- /dev/null +++ b/Protocols/ProtocolSet.h @@ -0,0 +1,107 @@ +/* + * ProtocolSet.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSET_H_ +#define PROTOCOLS_PROTOCOLSET_H_ + +#include "Processor/Processor.h" +#include "GC/ShareThread.h" +#include "ProtocolSetup.h" + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type + */ +template +class ProtocolSet +{ + DataPositions usage; + +public: + typename T::MAC_Check output; + typename T::LivePrep preprocessing; + SubProcessor processor; + typename T::Protocol& protocol; + typename T::Input& input; + + ProtocolSet(Player& P, typename T::mac_key_type mac_key) : + usage(P.num_players()), output(mac_key), preprocessing(0, usage), processor( + output, preprocessing, P), protocol(processor.protocol), input( + processor.input) + { + } + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + ProtocolSet(Player& P, const ProtocolSetup& setup) : + ProtocolSet(P, setup.get_mac_key()) + { + } + + ~ProtocolSet() + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for a binary share type + */ +template +class BinaryProtocolSet +{ + DataPositions usage; + typename T::LivePrep prep; + GC::ShareThread thread; + +public: + typename T::MAC_Check& output; + typename T::Protocol& protocol; + typename T::Input input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + BinaryProtocolSet(Player& P, const BinaryProtocolSetup& setup) : + usage(P.num_players()), prep(usage), thread(prep, P, + setup.get_mac_key()), output(*thread.MC), protocol( + *thread.protocol), input(output, prep, P) + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSet +{ + ProtocolSet arithmetic; + +public: + BinaryProtocolSet binary; + + typename T::MAC_Check& output; + typename T::LivePrep& preprocessing; + typename T::Protocol& protocol; + typename T::Input& input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + MixedProtocolSet(Player& P, const MixedProtocolSetup& setup) : + arithmetic(P, setup), binary(P, setup.binary), output( + arithmetic.output), preprocessing(arithmetic.preprocessing), protocol( + arithmetic.protocol), input(arithmetic.input) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h new file mode 100644 index 00000000..b6d91b2b --- /dev/null +++ b/Protocols/ProtocolSetup.h @@ -0,0 +1,95 @@ +/* + * ProtocolSetup.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSETUP_H_ +#define PROTOCOLS_PROTOCOLSETUP_H_ + +#include "Networking/Player.h" + +/** + * Global setup for an arithmetic share type + */ +template +class ProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + ProtocolSetup(Player& P, int prime_length = 0, string directory = "") + { + // initialize fields + if (prime_length == 0) + prime_length = T::clear::MAX_N_BITS; + + T::clear::init_default(prime_length); + T::clear::next::init_default(prime_length, false); + + // must initialize MAC key for security of some protocols + T::read_or_generate_mac_key(directory, P, mac_key); + } + + ~ProtocolSetup() + { + T::LivePrep::teardown(); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for a binary share type + */ +template +class BinaryProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param directory location to read MAC if needed + */ + BinaryProtocolSetup(Player& P, string directory = "") + { + T::part_type::open_type::init_field(); + T::mac_key_type::init_field(); + T::part_type::read_or_generate_mac_key(directory, P, mac_key); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSetup : public ProtocolSetup +{ +public: + BinaryProtocolSetup binary; + + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + MixedProtocolSetup(Player& P, int prime_length = 0, string directory = "") : + ProtocolSetup(P, prime_length, directory), binary(P, directory) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSETUP_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index d115b4c5..e85065ac 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -11,6 +11,7 @@ #include "Protocols/Replicated.h" #include "GC/ShareSecret.h" #include "ShareInterface.h" +#include "Processor/Instruction.h" template class ReplicatedPrep; template class ReplicatedRingPrep; @@ -67,6 +68,31 @@ public: assert(full); FixedVec::unpack(os); } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + shrsi(proc, inst, T::invertible); + } + + template + static void shrsi(SubProcessor&, const Instruction&, + true_type) + { + throw runtime_error("shrsi not implemented"); + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst, + false_type) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; template @@ -94,6 +120,7 @@ public: const static bool dishonest_majority = false; const static bool expensive = false; const static bool variable_players = false; + static const bool has_trunc_pr = true; static string type_short() { diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index c7a49452..23f28cf9 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -31,7 +31,6 @@ public: typedef GC::SemiHonestRepSecret bit_type; - static const bool has_trunc_pr = true; static const bool has_split = true; Rep3Share2() @@ -132,17 +131,6 @@ public: } } } - - template - static void shrsi(SubProcessor& proc, const Instruction& inst) - { - for (int i = 0; i < inst.get_size(); i++) - { - auto& dest = proc.get_S_ref(inst.get_r(0) + i); - auto& source = proc.get_S_ref(inst.get_r(1) + i); - dest = source >> inst.get_n(); - } - } }; #endif /* PROTOCOLS_REP3SHARE2K_H_ */ diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index aa0fc7bc..6acfae42 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -60,6 +60,11 @@ class Rep4 : public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type); + template + T finalize_mul(int n_bits, true_type); + template + T finalize_mul(int n_bits, false_type); + public: prngs_type rep_prngs; Player& P; @@ -70,14 +75,13 @@ public: Rep4 branch(); - void init_mul(SubProcessor* proc = 0); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void check(); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index e77b4e6f..a2deab2b 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -59,7 +59,7 @@ Rep4 Rep4::branch() } template -void Rep4::init_mul(SubProcessor*) +void Rep4::init_mul() { for (auto& x : add_shares) x.clear(); @@ -70,12 +70,6 @@ void Rep4::init_mul(SubProcessor*) channels.resize(P.num_players(), vector(P.num_players(), false)); } -template -void Rep4::init_mul(Preprocessing&, typename T::MAC_Check&) -{ - init_mul(); -} - template void Rep4::reset_joint_input(int n_inputs) { @@ -194,13 +188,12 @@ int Rep4::get_player(int offset) } template -typename T::clear Rep4::prepare_mul(const T& x, const T& y, int n_bits) +void Rep4::prepare_mul(const T& x, const T& y, int n_bits) { auto a = get_addshares(x, y); for (int i = 0; i < 5; i++) add_shares[i].push_back(a[i]); bit_lengths.push_back(n_bits); - return {}; } template @@ -215,7 +208,7 @@ array Rep4::get_addshares(const T& x, const T& y) } template -void Rep4::init_dotprod(SubProcessor*) +void Rep4::init_dotprod() { init_mul(); dotprod_shares = {}; @@ -260,10 +253,27 @@ void Rep4::exchange() } template -T Rep4::finalize_mul(int) +T Rep4::finalize_mul(int n_bits) { this->counter++; - return results.next().res; + if (n_bits == -1) + return results.next().res; + else + return finalize_mul(n_bits, T::clear::binary); +} + +template +template +T Rep4::finalize_mul(int n_bits, true_type) +{ + return results.next().res.mask(n_bits); +} + +template +template +T Rep4::finalize_mul(int, false_type) +{ + throw runtime_error("bit-wise multiplication not supported"); } template diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp index 17915e43..e871e82c 100644 --- a/Protocols/Rep4Prep.hpp +++ b/Protocols/Rep4Prep.hpp @@ -54,7 +54,7 @@ template void Rep4RingPrep::buffer_squares() { generate_squares(this->squares, OnlineOptions::singleton.batch_size, - this->protocol, this->proc); + this->protocol); } template diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 3de9bfab..67527a20 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -76,10 +76,13 @@ public: /// Single multiplication T mul(const T& x, const T& y); + /// Initialize protocol if needed (repeated call possible) + virtual void init(Preprocessing&, typename T::MAC_Check&) {} + /// Initialize multiplication round - virtual void init_mul(SubProcessor* proc) = 0; + virtual void init_mul() = 0; /// Schedule multiplication of operand pair - virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; + virtual void prepare_mul(const T& x, const T& y, int n = -1) = 0; /// Run multiplication protocol virtual void exchange() = 0; /// Get next multiplication result @@ -88,7 +91,7 @@ public: virtual void finalize_mult(T& res, int n = -1); /// Initialize dot product round - void init_dotprod(SubProcessor* proc) { init_mul(proc); } + void init_dotprod() { init_mul(); } /// Add operand pair to current dot product void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } /// Finish dot product @@ -132,6 +135,11 @@ class Replicated : public ReplicatedBase, public ProtocolBase PointerVector add_shares; typename T::clear dotprod_share; + template + void trunc_pr(const vector& regs, int size, U& proc, true_type); + template + void trunc_pr(const vector& regs, int size, U& proc, false_type); + public: typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; @@ -149,17 +157,13 @@ public: share[my_num] = value; } - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - void init_mul(); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void prepare_reshare(const typename T::clear& share, int n = -1); - void init_dotprod(SubProcessor*) { init_mul(); } void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 75dc785b..374ed89b 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -11,12 +11,10 @@ #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" -#include "SemiShare.h" -#include "SemiMC.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" -#include "SemiMC.hpp" +#include "ReplicatedPO.hpp" #include "Math/Z2k.hpp" template @@ -99,7 +97,8 @@ void ProtocolBase::multiply(vector& products, BaseMachine::thread_num); #endif - init_mul(&proc); + init(proc.DataF, proc.MC); + init_mul(); for (int i = begin; i < end; i++) prepare_mul(multiplicands[i].first, multiplicands[i].second); exchange(); @@ -110,7 +109,7 @@ void ProtocolBase::multiply(vector& products, template T ProtocolBase::mul(const T& x, const T& y) { - init_mul(0); + init_mul(); prepare_mul(x, y); exchange(); return finalize_mul(); @@ -146,20 +145,6 @@ T ProtocolBase::get_random() return res; } -template -void Replicated::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - -template -void Replicated::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) -{ - (void) prep, (void) MC; - init_mul(); -} - template void Replicated::init_mul() { @@ -169,12 +154,11 @@ void Replicated::init_mul() } template -inline typename T::clear Replicated::prepare_mul(const T& x, +void Replicated::prepare_mul(const T& x, const T& y, int n) { typename T::value_type add_share = x.local_mul(y); prepare_reshare(add_share, n); - return add_share; } template @@ -276,109 +260,89 @@ void Replicated::randoms(T& res, int n_bits) res[i].randomize_part(shared_prngs[i], n_bits); } -template -void trunc_pr(const vector& regs, int size, - SubProcessor>& proc) +template +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + false_type) { assert(regs.size() % 4 == 0); assert(proc.P.num_players() == 3); assert(proc.Proc != 0); - typedef SignedZ2 value_type; - typedef Rep3Share T; - bool generate = proc.P.my_num() == 2; + typedef typename T::clear value_type; + int gen_player = 2; + int comp_player = 1; + bool generate = P.my_num() == gen_player; + bool compute = P.my_num() == comp_player; + ArgList> infos(regs); + auto& S = proc.get_S(); + + octetStream cs; + ReplicatedInput input(P); + if (generate) { - octetStream os[2]; - for (size_t i = 0; i < regs.size(); i += 4) - { - TruncPrTuple info(regs, i); - for (int l = 0; l < size; l++) + SeededPRNG G; + for (auto info : infos) + for (int i = 0; i < size; i++) { - auto& res = proc.get_S_ref(regs[i] + l); - auto& G = proc.Proc->secure_prng; - auto mask = G.template get(); - auto unmask = info.upper(mask); - T shares[4]; - shares[0].randomize_to_sum(mask, G); - shares[1].randomize_to_sum(unmask, G); - shares[2].randomize_to_sum(info.msb(mask), G); - res.randomize(G); - shares[3] = res; - for (int i = 0; i < 2; i++) - { - for (int j = 0; j < 4; j++) - shares[j][i].pack(os[i]); - } + auto r = G.get(); + input.add_mine(info.upper(r)); + if (info.small_gap()) + input.add_mine(info.msb(r)); + (r + S[info.source_base + i][0]).pack(cs); } - } - for (int i = 0; i < 2; i++) - proc.P.send_to(i, os[i]); + P.send_to(comp_player, cs); } else + input.add_other(gen_player); + + if (compute) { - octetStream os; - proc.P.receive_player(2, os); - OffsetPlayer player(proc.P, 1 - 2 * proc.P.my_num()); - typedef SemiShare semi_type; - vector> to_open; - PointerVector> mask_shares[3]; - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) + P.receive_player(gen_player, cs); + for (auto info : infos) + for (int i = 0; i < size; i++) { - SemiShare share; - auto& x = proc.get_S_ref(regs[i + 1] + l); - if (proc.P.my_num() == 0) - share = x.sum(); - else - share = x[0]; - for (auto& mask_share : mask_shares) - mask_share.push_back(os.get()); - to_open.push_back(share + mask_shares[0].next()); - auto& res = proc.get_S_ref(regs[i] + l); - auto& a = res[1 - proc.P.my_num()]; - a.unpack(os); + auto c = cs.get() + S[info.source_base + i].sum(); + input.add_mine(info.upper(c)); + if (info.small_gap()) + input.add_mine(info.msb(c)); } - PointerVector opened; - DirectSemiMC> MC; - MC.POpen_(opened, to_open, player); - os.reset_write_head(); - for (size_t i = 0; i < regs.size(); i += 4) + } + + input.add_other(comp_player); + input.exchange(); + init_mul(); + + for (auto info : infos) + for (int i = 0; i < size; i++) { - int k = regs[i + 2]; - int m = regs[i + 3]; - int n_shift = value_type::N_BITS - 1 - k; - assert(m < k); - assert(0 < k); - assert(m < value_type::N_BITS); - for (int l = 0; l < size; l++) + auto c_prime = input.finalize(comp_player); + auto r_prime = input.finalize(gen_player); + S[info.dest_base + i] = c_prime - r_prime; + + if (info.small_gap()) { - auto& res = proc.get_S_ref(regs[i] + l); - auto masked = opened.next() << n_shift; - auto shifted = (masked << 1) >> (n_shift + m + 1); - auto diff = SemiShare::constant(shifted, - player.my_num()) - mask_shares[1].next(); - auto msb = masked >> (value_type::N_BITS - 1); - auto bit_mask = mask_shares[2].next(); - auto overflow = (bit_mask - + SemiShare::constant(msb, player.my_num()) - - bit_mask * msb * 2); - auto res_share = diff + (overflow << (k - m)); - auto& a = res[1 - proc.P.my_num()]; - auto& b = res[proc.P.my_num()]; - b = res_share - a; - b.pack(os); + auto c_dprime = input.finalize(comp_player); + auto r_msb = input.finalize(gen_player); + S[info.dest_base + i] += ((r_msb + c_dprime) + << (info.k - info.m)); + prepare_mul(r_msb, c_dprime); } } - player.exchange(os); - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) - proc.get_S_ref(regs[i] + l)[proc.P.my_num()] += - os.get(); - } + + exchange(); + + for (auto info : infos) + for (int i = 0; i < size; i++) + if (info.small_gap()) + S[info.dest_base + i] -= finalize_mul() + << (info.k - info.m + 1); } template -void trunc_pr(const vector& regs, int size, SubProcessor& proc) +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + true_type) { (void) regs, (void) size, (void) proc; throw runtime_error("trunc_pr not implemented"); @@ -390,7 +354,7 @@ void Replicated::trunc_pr(const vector& regs, int size, U& proc) { this->trunc_rounds++; - ::trunc_pr(regs, size, proc); + trunc_pr(regs, size, proc, T::clear::characteristic_two); } #endif diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 7d62838a..9bb3c30a 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -72,9 +72,8 @@ public: PrepLessInput(proc), proc(proc), P(P), protocol(P) { assert(T::length == 2); - InputBase::P = &P; - InputBase::os.resize(P.num_players()); expect.resize(P.num_players()); + this->reset_all(P); } void reset(int player); diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 741d2c49..1cfac4a1 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -71,7 +71,7 @@ template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o, int n_bits) { - int offset = player - P.my_num(); + int offset = player - this->my_num; if (offset == 1 or offset == -2) { typename T::value_type t; diff --git a/Protocols/ReplicatedPO.h b/Protocols/ReplicatedPO.h new file mode 100644 index 00000000..a533a5b1 --- /dev/null +++ b/Protocols/ReplicatedPO.h @@ -0,0 +1,24 @@ +/* + * ReplicatedPO.h + * + */ + +#ifndef PROTOCOLS_REPLICATEDPO_H_ +#define PROTOCOLS_REPLICATEDPO_H_ + +#include "MaliciousRepPO.h" + +template +class ReplicatedPO : public MaliciousRepPO +{ +public: + ReplicatedPO(Player& P) : + MaliciousRepPO(P) + { + } + + void send(int player); + void receive(); +}; + +#endif /* PROTOCOLS_REPLICATEDPO_H_ */ diff --git a/Protocols/ReplicatedPO.hpp b/Protocols/ReplicatedPO.hpp new file mode 100644 index 00000000..aecd85b3 --- /dev/null +++ b/Protocols/ReplicatedPO.hpp @@ -0,0 +1,21 @@ +/* + * ReplicatedPO.cpp + * + */ + +#include "ReplicatedPO.h" + +#include "MaliciousRepPO.hpp" + +template +void ReplicatedPO::send(int player) +{ + if (this->P.get_offset(player) == 2) + this->P.send_to(player, this->to_send); +} + +template +void ReplicatedPO::receive() +{ + this->P.receive_relative(1, this->to_receive[0]); +} diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 8c3ed3f1..8a30749c 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -184,6 +184,15 @@ protected: template void sanitize(vector>& edabits, int n_bits); + template + void buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size); + template + void buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end); + public: RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep(); @@ -224,6 +233,13 @@ public: template class SemiHonestRingPrep : public virtual RingPrep { + template + void buffer_bits(false_type, false_type); + template + void buffer_bits(true_type, false_type); + template + void buffer_bits(false_type, true_type); + public: SemiHonestRingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -232,7 +248,7 @@ public: } virtual ~SemiHonestRingPrep() {} - virtual void buffer_bits() { this->buffer_bits_without_check(); } + virtual void buffer_bits(); virtual void buffer_inputs(int player) { this->buffer_inputs_as_usual(player, this->proc); } @@ -358,11 +374,6 @@ template class ReplicatedPrep : public virtual ReplicatedRingPrep, public virtual SemiHonestRingPrep { - template - void buffer_bits(false_type); - template - void buffer_bits(true_type); - public: ReplicatedPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -384,7 +395,7 @@ public: } void buffer_squares() { ReplicatedRingPrep::buffer_squares(); } - void buffer_bits(); + void buffer_bits() { SemiHonestRingPrep::buffer_bits(); } }; #endif /* PROTOCOLS_REPLICATEDPREP_H_ */ diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 2b8aa160..916ee6b8 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,24 +56,23 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif - if (OnlineOptions::singleton.verbose) - { - this->print_left("triples", triples.size() * T::default_length, - type_string); + this->print_left("triples", triples.size() * T::default_length, type_string, + this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) + * T::default_length); -#define X(KIND) \ - this->print_left(#KIND, KIND.size(), type_string); - X(squares) - X(inverses) - X(bits) - X(dabits) +#define X(KIND, TYPE) \ + this->print_left(#KIND, KIND.size(), type_string, \ + this->usage.files.at(T::clear::field_type()).at(TYPE)); + X(squares, DATA_SQUARE) + X(inverses, DATA_INVERSE) + X(bits, DATA_BIT) + X(dabits, DATA_DABIT) #undef X - for (auto& x : this->edabits) - { - this->print_left_edabits(x.second.size(), x.second[0].size(), - x.first.first, x.first.second); - } + for (auto& x : this->edabits) + { + this->print_left_edabits(x.second.size(), x.second[0].size(), + x.first.first, x.first.second, this->usage.edabits[x.first]); } } @@ -100,7 +99,9 @@ RingPrep::~RingPrep() template void BitPrep::set_protocol(typename T::Protocol& protocol) { - this->protocol = new typename T::Protocol(protocol.branch()); + if (not this->protocol) + this->protocol = new typename T::Protocol(protocol.branch()); + this->protocol->init_mul(); auto proc = this->proc; if (proc and proc->Proc) this->base_player = proc->Proc->thread_num; @@ -202,16 +203,16 @@ template void ReplicatedRingPrep::buffer_squares() { generate_squares(this->squares, this->buffer_size, - this->protocol, this->proc); + this->protocol); } template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc) + U* protocol) { assert(protocol != 0); squares.resize(n_squares); - protocol->init_mul(proc); + protocol->init_mul(); for (size_t i = 0; i < squares.size(); i++) { auto& square = squares[i]; @@ -289,7 +290,7 @@ void BufferPrep::get_two_no_count(Dtype dtype, T& a, T& b) template void XOR(vector& res, vector& x, vector& y, - typename T::Protocol& prot, SubProcessor* proc) + typename T::Protocol& prot) { assert(x.size() == y.size()); int buffer_size = x.size(); @@ -302,7 +303,7 @@ void XOR(vector& res, vector& x, vector& y, return; } - prot.init_mul(proc); + prot.init_mul(); for (int i = 0; i < buffer_size; i++) prot.prepare_mul(x[i], y[i]); prot.exchange(); @@ -337,13 +338,14 @@ void buffer_bits_from_squares(RingPrep& prep) template template -void ReplicatedPrep::buffer_bits(true_type) +void SemiHonestRingPrep::buffer_bits(true_type, false_type) { if (this->protocol->get_n_relevant_players() > 10 - or OnlineOptions::singleton.bits_from_squares) + or OnlineOptions::singleton.bits_from_squares + or T::dishonest_majority) buffer_bits_from_squares(*this); else - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template @@ -409,10 +411,9 @@ void MaliciousRingPrep::buffer_personal_dabits_without_check( auto& P = this->proc->P; auto &party = GC::ShareThread::s(); typedef typename T::bit_type::part_type BT; - SubProcessor bit_proc(party.MC->get_part_MC(), + typename BT::Input bit_input(party.MC->get_part_MC(), this->proc->bit_prep, this->proc->P); typename T::Input input(*this->proc, this->proc->MC); - typename BT::Input bit_input(bit_proc, bit_proc.MC); input.reset_all(P); bit_input.reset_all(P); SeededPRNG G; @@ -454,10 +455,24 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, typename BT::Input bit_input(proc, proc.MC); input.reset_all(P); bit_input.reset_all(P); - SeededPRNG G; assert(begin % BT::default_length == 0); int buffer_size = end - begin; + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + input_player, buffer_size); + input.exchange(); + bit_input.exchange(); + buffer_personal_edabits_without_check_post(n_bits, sums, bits, input, + bit_input, input_player, begin, end); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size) +{ int n_chunks = DIV_CEIL(buffer_size, BT::default_length); + SeededPRNG G; if (input_player == P.my_num()) { for (int i = 0; i < n_chunks; i++) @@ -482,8 +497,16 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, for (int i = 0; i < BT::default_length; i++) input.add_other(input_player); } - input.exchange(); - bit_input.exchange(); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end) +{ + int buffer_size = end - begin; + int n_chunks = DIV_CEIL(buffer_size, BT::default_length); for (int i = 0; i < buffer_size; i++) sums[begin + i] = input.finalize(input_player); assert(bits.size() == size_t(n_bits)); @@ -600,18 +623,18 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, assert(proc != 0); int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits; - auto stat = proc->P.comm_stats; + auto stat = proc->P.total_comm(); buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = *protocol; - XOR(bits, player_bits[0], player_bits[1], prot, proc); + XOR(bits, player_bits[0], player_bits[1], prot); for (int i = 2; i < n_relevant_players; i++) - XOR(bits, bits, player_bits[i], prot, proc); + XOR(bits, bits, player_bits[i], prot); this->base_player++; (void) stat; #ifdef VERBOSE_PREP cerr << "bit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -730,9 +753,22 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, vector> player_ints(n_relevant, vector(buffer_size)); vector>> parts(n_relevant, vector>(n_bits, vector(buffer_size / dl))); + InScope in_scope(this->do_count, false); + assert(this->proc != 0); + auto& P = proc->P; + typename T::Input input(*this->proc, this->proc->MC); + typename BT::Input bit_input(bit_proc, bit_proc.MC); + input.reset_all(P); + bit_input.reset_all(P); + assert(begin % BT::default_length == 0); for (int i = 0; i < n_relevant; i++) - buffer_personal_edabits_without_check<0>(n_bits, player_ints[i], parts[i], - bit_proc, i, 0, buffer_size); + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + i, buffer_size); + input.exchange(); + bit_input.exchange(); + for (int i = 0; i < n_relevant; i++) + buffer_personal_edabits_without_check_post(n_bits, player_ints[i], + parts[i], input, bit_input, i, 0, buffer_size); vector>> player_bits(n_bits, vector>(n_relevant)); for (int i = 0; i < n_bits; i++) @@ -754,7 +790,7 @@ template void RingPrep::buffer_edabits_without_check(int n_bits, vector>& edabits, int buffer_size) { - auto stat = this->proc->P.comm_stats; + auto stat = this->proc->P.total_comm(); typedef typename T::bit_type::part_type bit_type; vector> bits; vector sums; @@ -763,7 +799,7 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& (void) stat; #ifdef VERBOSE_PREP cerr << "edaBit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -920,40 +956,38 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) delete &MCB; } -template<> -inline -void SemiHonestRingPrep>::buffer_bits() -{ - assert(protocol != 0); - bits_from_random(bits, *protocol); -} - template -void bits_from_random(vector& bits, typename T::Protocol& protocol) +template +void SemiHonestRingPrep::buffer_bits(false_type, true_type) { - while (bits.size() < (size_t)OnlineOptions::singleton.batch_size) - { - Rep3Share share = protocol.get_random(); - for (int j = 0; j < gf2n::degree(); j++) + assert(this->protocol != 0); + if (not T::dishonest_majority and T::variable_players) + // Shamir + this->buffer_bits_without_check(); + else + while (this->bits.size() < (size_t) OnlineOptions::singleton.batch_size) { - bits.push_back(share & 1); - share >>= 1; + auto share = this->get_random(); + for (int j = 0; j < T::open_type::degree(); j++) + { + this->bits.push_back(share & 1); + share >>= 1; + } } - } } template template -void ReplicatedPrep::buffer_bits(false_type) +void SemiHonestRingPrep::buffer_bits(false_type, false_type) { - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template -void ReplicatedPrep::buffer_bits() +void SemiHonestRingPrep::buffer_bits() { assert(this->protocol != 0); - buffer_bits<0>(T::clear::prime_field); + buffer_bits(T::clear::prime_field, T::clear::characteristic_two); } template diff --git a/Protocols/Semi2k.h b/Protocols/Semi.h similarity index 75% rename from Protocols/Semi2k.h rename to Protocols/Semi.h index 69cf63aa..e290ca0e 100644 --- a/Protocols/Semi2k.h +++ b/Protocols/Semi.h @@ -3,8 +3,8 @@ * */ -#ifndef PROTOCOLS_SEMI2K_H_ -#define PROTOCOLS_SEMI2K_H_ +#ifndef PROTOCOLS_SEMI_H_ +#define PROTOCOLS_SEMI_H_ #include "SPDZ.h" #include "Processor/TruncPrTuple.h" @@ -13,12 +13,12 @@ * Dishonest-majority protocol for computation modulo a power of two */ template -class Semi2k : public SPDZ +class Semi : public SPDZ { SeededPRNG G; public: - Semi2k(Player& P) : + Semi(Player& P) : SPDZ(P) { } @@ -30,6 +30,19 @@ public: void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { + trunc_pr(regs, size, proc, T::clear::characteristic_two); + } + + template + void trunc_pr(const vector&, int, SubProcessor&, true_type) + { + throw not_implemented(); + } + + template + void trunc_pr(const vector& regs, int size, + SubProcessor& proc, false_type) { if (this->P.num_players() > 2) throw runtime_error("probabilistic truncation " @@ -60,4 +73,4 @@ public: } }; -#endif /* PROTOCOLS_SEMI2K_H_ */ +#endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index a9df48b4..ee5e8320 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -7,7 +7,7 @@ #define PROTOCOLS_SEMI2KSHARE_H_ #include "SemiShare.h" -#include "Semi2k.h" +#include "Semi.h" #include "OT/Rectangle.h" #include "GC/SemiSecret.h" #include "GC/square64.h" @@ -27,7 +27,7 @@ public: typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef Semi2k Protocol; + typedef Semi Protocol; typedef SemiPrep2k LivePrep; typedef Semi2kShare prep_type; @@ -35,8 +35,6 @@ public: typedef OTTripleGenerator TripleGenerator; typedef Z2kSquare Rectangle; - typedef GC::SemiSecret bit_type; - static const bool has_split = true; Semi2kShare() diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index ed044c46..c2dd9085 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SEMISHARE_H_ #include "Protocols/Beaver.h" +#include "Protocols/Semi.h" #include "Processor/DummyProtocol.h" #include "ShareInterface.h" @@ -16,7 +17,7 @@ using namespace std; template class Input; template class SemiMC; template class DirectSemiMC; -template class SPDZ; +template class Semi; template class SemiPrep; template class SemiInput; template class PrivateOutput; @@ -59,7 +60,7 @@ public: typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef Semi Protocol; typedef SemiPrep LivePrep; typedef LivePrep TriplePrep; @@ -69,12 +70,15 @@ public: typedef T sacri_type; typedef typename T::Square Rectangle; +#ifndef NO_MIXED_CIRCUITS typedef GC::SemiSecret bit_type; +#endif const static bool needs_ot = true; const static bool dishonest_majority = true; const static bool variable_players = true; const static bool expensive = false; + static const bool has_trunc_pr = true; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 3d2bf469..f722886e 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -62,20 +62,8 @@ public: void reset(); void init_mul(); - void init_mul(SubProcessor* proc); - template - void init_mul(V*) - { - init_mul(); - } - template - void init_mul(const V&, const W&) - { - init_mul(); - } - - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void start_exchange(); @@ -85,7 +73,7 @@ public: T finalize(int n_input_players); - void init_dotprod(SubProcessor* proc = 0); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d387f3b4..9fe10bde 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -80,13 +80,6 @@ void Shamir::reset() resharing->reset(i); } -template -void Shamir::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - template void Shamir::init_mul() { @@ -96,13 +89,12 @@ void Shamir::init_mul() } template -typename T::clear Shamir::prepare_mul(const T& x, const T& y, int n) +void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) resharing->add_mine(add_share); - return {}; } template @@ -157,9 +149,9 @@ T Shamir::finalize(int n_relevant_players) } template -void Shamir::init_dotprod(SubProcessor* proc) +void Shamir::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index fe509321..81e85931 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -10,7 +10,6 @@ #include "Tools/PointerVector.h" #include "GC/BitAdder.h" -#include "MalRepRingPrep.hpp" #include "LimitedPrep.hpp" inline @@ -25,6 +24,21 @@ ShuffleSacrifice::ShuffleSacrifice(int B, int C) : { } +template +void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) +{ + int buffer_size = check_triples.size(); + + // shuffle + GlobalPRNG G(P); + for (int i = 0; i < buffer_size; i++) + { + int remaining = buffer_size - i; + int pos = G.get_uint(remaining); + swap(check_triples[i], check_triples[i + pos]); + } +} + template void TripleShuffleSacrifice::triple_combine(vector >& triples, vector >& to_combine, Player& P, diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 33883c66..03a91ff2 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -26,7 +26,6 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, MascotTriplePrep* bit_prep; SubProcessor* bit_proc; typename BitShare::MAC_Check* bit_MC; - typename BitShare::Protocol* bit_protocol; public: Spdz2kPrep(SubProcessor* proc, DataPositions& usage); @@ -41,8 +40,6 @@ public: #ifdef SPDZ2K_BIT void get_dabit(T& a, GC::TinySecret& b); #endif - - NamedCommStats comm_stats(); }; #endif /* PROTOCOLS_SPDZ2KPREP_H_ */ diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index f5c9cdce..81527761 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -25,7 +25,6 @@ Spdz2kPrep::Spdz2kPrep(SubProcessor* proc, DataPositions& usage) : bit_MC = 0; bit_proc = 0; bit_prep = 0; - bit_protocol = 0; } template @@ -36,7 +35,6 @@ Spdz2kPrep::~Spdz2kPrep() delete bit_prep; delete bit_proc; delete bit_MC; - delete bit_protocol; } } @@ -50,10 +48,8 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) // just dummies bit_pos = DataPositions(proc->P.num_players()); bit_prep = new MascotTriplePrep(bit_proc, bit_pos); - bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_prep->params.amplify = false; - bit_protocol = new typename BitShare::Protocol(proc->P); - bit_prep->set_protocol(*bit_protocol); + bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_MC->set_prep(*bit_prep); this->proc->MC.set_prep(*this); } @@ -65,7 +61,7 @@ void MaliciousRingPrep::buffer_bits() RingPrep::buffer_bits_without_check(); assert(this->protocol != 0); auto& protocol = *this->protocol; - protocol.init_dotprod(this->proc); + protocol.init_dotprod(); auto one = T::constant(1, protocol.P.my_num(), this->proc->MC.get_alphai()); GlobalPRNG G(protocol.P); for (auto& bit : this->bits) @@ -238,12 +234,29 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, } template -NamedCommStats Spdz2kPrep::comm_stats() +void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, + ThreadQueues* queues) { - auto res = OTPrep::comm_stats(); - if (bit_prep) - res += bit_prep->comm_stats(); - return res; + RunningTimer timer; +#ifndef NONPERSONAL_EDA + this->buffer_edabits_from_personal(strict, n_bits, queues); +#else + assert(this->proc != 0); + ShuffleSacrifice shuffle_sacrifice; + typedef typename T::bit_type::part_type bit_type; + vector> bits; + vector sums; + this->buffer_edabits_without_check(n_bits, sums, bits, + shuffle_sacrifice.minimum_n_inputs(), queues); + vector>& checked = this->edabits[{strict, n_bits}]; + shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, + n_bits, *this->proc, strict, -1, queues); + if (strict) + this->sanitize(checked, n_bits, -1, queues); +#endif +#ifdef VERBOSE_EDA + cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; +#endif } #endif diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index afbf2c85..c12b4f5f 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -38,22 +38,23 @@ public: SpdzWise(Player& P); virtual ~SpdzWise(); - Player& branch(); + typename T::Protocol branch(); - void init(SubProcessor* proc); + void init(Preprocessing&, typename T::MAC_Check& MC); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor*); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); void add_to_check(const T& x); void check(); + void maybe_check(); int get_n_relevant_players() { return internal.get_n_relevant_players(); } diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 40f3cee7..2ea08ba4 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -19,34 +19,40 @@ SpdzWise::~SpdzWise() } template -Player& SpdzWise::branch() +typename T::Protocol SpdzWise::branch() { - return P; + typename T::Protocol res(P); + res.mac_key = mac_key; + return res; } template -void SpdzWise::init(SubProcessor* proc) +void SpdzWise::init(Preprocessing&, typename T::MAC_Check& MC) { - assert(proc != 0); - mac_key = proc->MC.get_alphai(); + mac_key = MC.get_alphai(); +} + +template +void SpdzWise::maybe_check() +{ + assert(not mac_key.is_zero()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) check(); } template -void SpdzWise::init_mul(SubProcessor* proc) +void SpdzWise::init_mul() { - init(proc); + maybe_check(); internal.init_mul(); internal2.init_mul(); } template -typename T::clear SpdzWise::prepare_mul(const T& x, const T& y, int) +void SpdzWise::prepare_mul(const T& x, const T& y, int) { internal.prepare_mul(x.get_share(), y.get_share()); internal.prepare_mul(x.get_mac(), y.get_share()); - return {}; } template @@ -67,9 +73,9 @@ void SpdzWise::exchange() } template -void SpdzWise::init_dotprod(SubProcessor* proc) +void SpdzWise::init_dotprod() { - init(proc); + maybe_check(); internal.init_dotprod(); internal2.init_dotprod(); } diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index ef7f549b..e0d508e5 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -12,6 +12,7 @@ SpdzWiseInput::SpdzWiseInput(SubProcessor* proc, Player& P) : { assert(proc != 0); mac_key = proc->MC.get_alphai(); + checker.init(proc->DataF, proc->MC); } template @@ -76,7 +77,7 @@ void SpdzWiseInput::exchange() shares[i][j].set_mac(honest_mult.finalize_mul()); checker.results.push_back(shares[i][j]); } - checker.init(proc); + checker.maybe_check(); } template diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index f88e97d6..9cb86017 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -9,19 +9,21 @@ #include "MaliciousShamirShare.h" #include "SquarePrep.h" #include "Math/gfp.h" +#include "ProtocolSet.h" #include "ReplicatedPrep.hpp" #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" #include "MaliciousShamirPO.hpp" +#include "GC/RepPrep.hpp" template void SpdzWisePrep::buffer_triples() { assert(this->protocol != 0); assert(this->proc != 0); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); generate_triples_initialized(this->triples, OnlineOptions::singleton.batch_size, this->protocol); } @@ -38,8 +40,11 @@ void SpdzWisePrep>>::buffer_bits() { typedef MaliciousRep3Share part_type; vector bits; - typename part_type::Honest::Protocol protocol(this->protocol->P); - bits_from_random(bits, protocol); + ProtocolSet set(this->proc->P, {}); + auto& protocol = set.protocol; + auto& prep = set.preprocessing; + for (int i = 0; i < buffer_size; i++) + bits.push_back(prep.get_bit()); protocol.init_mul(); for (auto& bit : bits) protocol.prepare_mul(bit, this->proc->MC.get_alphai()); @@ -99,7 +104,7 @@ void SpdzWisePrep::buffer_inputs(int player) vector rs(OnlineOptions::singleton.batch_size); auto& P = this->proc->P; this->inputs.resize(P.num_players()); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); for (auto& r : rs) { r = this->protocol->get_random(); diff --git a/Protocols/SpdzWiseRing.hpp b/Protocols/SpdzWiseRing.hpp index 30904c38..36e638d1 100644 --- a/Protocols/SpdzWiseRing.hpp +++ b/Protocols/SpdzWiseRing.hpp @@ -36,7 +36,7 @@ void SpdzWiseRing::zero_check(check_type t) while(bits.size() > 1) { auto& protocol = zero_proc.protocol; - protocol.init_mul(&zero_proc); + protocol.init_mul(); for (int i = bits.size() - 2; i >= 0; i -= 2) protocol.prepare_mul(bits[i], bits[i + 1]); protocol.exchange(); diff --git a/Protocols/SquarePrep.h b/Protocols/SquarePrep.h index fcdc2c23..be0913b3 100644 --- a/Protocols/SquarePrep.h +++ b/Protocols/SquarePrep.h @@ -10,7 +10,7 @@ template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc); + U* protocol); template class SquarePrep : public BufferPrep @@ -22,8 +22,8 @@ class SquarePrep : public BufferPrep void buffer_squares() { - generate_squares(this->squares, this->buffer_size, &this->proc->protocol, - this->proc); + generate_squares(this->squares, this->buffer_size, + &this->proc->protocol); } public: diff --git a/README.md b/README.md index daa658a5..bd107512 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ The following table lists all protocols that are fully supported. | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | See [this paper](https://eprint.iacr.org/2020/300) for an explanation -of the various security models and high-level introduction to +of the various security models and a high-level introduction to multi-party computation. ##### Finding the most efficient protocol @@ -131,8 +131,8 @@ there are a few things to consider: dot products. - Fixed-point multiplication: Three- and four-party replicated secret - sharing modulo a power of two allow a special probabilistic - truncation protocol (see [Dalskov et + sharing as well semi-honest full-threshold protocols allow a special + probabilistic truncation protocol (see [Dalskov et al.](https://eprint.iacr.org/2019/131) and [Dalskov et al.](https://eprint.iacr.org/2020/1330)). You can activate it by adding `program.use_trunc_pr = True` at the beginning of your diff --git a/Scripts/decompile.py b/Scripts/decompile.py new file mode 100755 index 00000000..0142ba69 --- /dev/null +++ b/Scripts/decompile.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +import sys, os + +sys.path.append('.') + +from Compiler.instructions_base import Instruction +from Compiler.program import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +for tapename in Program.read_tapes(sys.argv[1]): + with open('Programs/Bytecode/%s.asm' % tapename, 'w') as out: + for i, inst in enumerate(Tape.read_instructions(tapename)): + print(inst, '#', i, file=out) diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py new file mode 100755 index 00000000..15959ee6 --- /dev/null +++ b/Scripts/memory-usage.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +for tapename in Program.read_tapes(sys.argv[1]): + for inst in Tape.read_instructions(tapename): + t = inst.type + if issubclass(t, DirectMemoryInstruction): + res[t.arg_format[0]] = max(inst.args[1].i + inst.size, + res[t.arg_format[0]]) + for arg in inst.args: + if isinstance(arg, RegisterArgFormat): + m = max(m, arg.i + inst.size) + +print (res) +print (m) + diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 3c0891e6..7e5e6d44 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -34,39 +34,26 @@ run_player() { if ! test -e $SPDZROOT/logs; then mkdir $SPDZROOT/logs fi - if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then - params="$prog $* -pn $port -h localhost" - if [[ ! ($bin =~ 'rep' || $bin =~ 'brain' || $bin =~ 'yao') ]]; then - params="$params -N $players" - fi - else - params="$port localhost $prog $*" + params="$prog $* -pn $port -h localhost" + if $SPDZROOT/$bin 2>&1 | grep -q '^-N,'; then + params="$params -N $players" fi - rem=$(($players - 2)) if test "$prog"; then log_prefix=$prog- fi - for i in $(seq 0 $rem); do + set -o pipefail + for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | { if test $i = 0; then tee $log; else cat > $log; fi; } & + codes[$i]=$! + done + for i in $(seq 0 $[players-1]); do + wait ${codes[$i]} || return 1 done - last_player=$(($players - 1)) - i=$last_player - >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params - $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 - wait } -sleep 0.5 - -#mkdir /dev/shm/Player-Data - players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} - -#. Scripts/setup.sh - -mkdir logs 2> /dev/null diff --git a/Scripts/test_streaming.sh b/Scripts/test_streaming.sh index 0ff2fb33..62a49308 100755 --- a/Scripts/test_streaming.sh +++ b/Scripts/test_streaming.sh @@ -15,3 +15,7 @@ done ./stream-fake-mascot-triples.x & Scripts/mascot.sh test_thread_mul -f || exit 1 + +./stream-fake-mascot-triples.x & + +Scripts/mascot.sh test_thread_mul -f || exit 1 diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index 5dd4f45d..ed6c0144 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -27,7 +27,8 @@ if test "$flags"; then cpu=amd64 fi - cp -av bin/`uname`-$cpu/* . + cp -av bin/`uname`-$cpu/* . || { echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2; exit 1; } fi mkdir Player-Data 2> /dev/null +exit 0 diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 4ef3406f..567e5788 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -9,6 +9,15 @@ #include #include +void BitVector::assign(const BitVector& K) +{ + if (nbits != K.nbits) + { + resize(K.nbits); + } + memcpy(bytes, K.bytes, nbytes); +} + void BitVector::resize_zero(size_t new_nbits) { size_t old_nbytes = nbytes; diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 05561051..54d9ed10 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -33,14 +33,7 @@ class BitVector public: - void assign(const BitVector& K) - { - if (nbits != K.nbits) - { - resize(K.nbits); - } - memcpy(bytes, K.bytes, nbytes); - } + void assign(const BitVector& K); void assign_bytes(char* new_bytes, int len) { resize(len*8); diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index c669081f..9dd15804 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -26,7 +26,7 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, bool BufferBase::is_pipe() { struct stat buf; - if (stat(filename.c_str(), &buf)) + if (stat(filename.c_str(), &buf) == 0) return S_ISFIFO(buf.st_mode); else return false; @@ -113,6 +113,17 @@ void BufferBase::prune() rename(tmp_name.c_str(), filename.c_str()); file->open(filename.c_str(), ios::in | ios::binary); } +#ifdef VERBOSE + else + { + cerr << "Not pruning " << filename << " because it's "; + if (file) + cerr << "closed"; + else + cerr << "unused"; + cerr << endl; + } +#endif } void BufferBase::purge() diff --git a/Tools/Bundle.h b/Tools/Bundle.h index ed4b982e..7859e3e4 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -31,7 +31,7 @@ public: { } - void compare(Player& P) + void compare(PlayerBase& P) { P.unchecked_broadcast(*this); for (auto& os : *this) diff --git a/Tools/TimerWithComm.cpp b/Tools/TimerWithComm.cpp new file mode 100644 index 00000000..2a5e8e12 --- /dev/null +++ b/Tools/TimerWithComm.cpp @@ -0,0 +1,23 @@ +/* + * TimerWithComm.cpp + * + */ + +#include "TimerWithComm.h" + +void TimerWithComm::start(const NamedCommStats& stats) +{ + Timer::start(); + last_stats = stats; +} + +void TimerWithComm::stop(const NamedCommStats& stats) +{ + Timer::stop(); + total_stats += stats - last_stats; +} + +double TimerWithComm::mb_sent() +{ + return total_stats.sent * 1e-6; +} diff --git a/Tools/TimerWithComm.h b/Tools/TimerWithComm.h new file mode 100644 index 00000000..2f3976a2 --- /dev/null +++ b/Tools/TimerWithComm.h @@ -0,0 +1,23 @@ +/* + * TimerWithComm.h + * + */ + +#ifndef TOOLS_TIMERWITHCOMM_H_ +#define TOOLS_TIMERWITHCOMM_H_ + +#include "time-func.h" +#include "Networking/Player.h" + +class TimerWithComm : public Timer +{ + NamedCommStats total_stats, last_stats; + +public: + void start(const NamedCommStats& stats = {}); + void stop(const NamedCommStats& stats = {}); + + double mb_sent(); +}; + +#endif /* TOOLS_TIMERWITHCOMM_H_ */ diff --git a/Tools/benchmarking.cpp b/Tools/benchmarking.cpp new file mode 100644 index 00000000..e956f15e --- /dev/null +++ b/Tools/benchmarking.cpp @@ -0,0 +1,15 @@ +/* + * benchmarking.cpp + * + */ + +#include "benchmarking.h" + +void insecure_fake() +{ +#if defined(INSECURE) or defined(INSECURE_FAKE) + cerr << "WARNING: insecure preprocessing" << endl; +#else + insecure("preprocessing"); +#endif +} diff --git a/Tools/benchmarking.h b/Tools/benchmarking.h index 0ca65b76..13fa9c36 100644 --- a/Tools/benchmarking.h +++ b/Tools/benchmarking.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; // call before insecure benchmarking functionality @@ -26,4 +27,6 @@ inline void insecure(string message, bool warning = true) #endif } +void insecure_fake(); + #endif /* TOOLS_BENCHMARKING_H_ */ diff --git a/Tools/octetStream.h b/Tools/octetStream.h index df920a30..cd90b0e9 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -35,7 +35,9 @@ class bigint; class FlexBuffer; /** - * Buffer for networking communication with a pointer for sequential reading + * Buffer for network communication with a pointer for sequential reading. + * When sent over the network or stored in a file, the length is prefixed + * as eight bytes in little-endian order. */ class octetStream { diff --git a/Tools/random.cpp b/Tools/random.cpp index 7a0cd1da..7cf1924f 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -13,7 +13,7 @@ using namespace std; PRNG::PRNG() : - cnt(0), n_cached_bits(0), cached_bits(0) + cnt(0), n_cached_bits(0), cached_bits(0), initialized(false) { #if defined(__AES__) || !defined(__x86_64__) #ifdef USE_AES @@ -83,6 +83,7 @@ void PRNG::SecureSeed(Player& player) void PRNG::InitSeed() { + initialized = true; #ifdef USE_AES if (useC) { aes_schedule(KeyScheduleC,seed); } @@ -122,6 +123,7 @@ void PRNG::print_state() const void PRNG::hash() { + assert(initialized); #ifndef USE_AES unsigned char tmp[RAND_SIZE + SEED_SIZE]; randombytes_buf_deterministic(tmp, sizeof tmp, seed); diff --git a/Tools/random.h b/Tools/random.h index d22be6e8..5e65d835 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -61,6 +61,8 @@ class PRNG int n_cached_bits; word cached_bits; + bool initialized; + void hash(); // Hashes state to random and sets cnt=0 void next(); diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index e8026a95..f1158cfa 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -387,7 +387,7 @@ int generate(ez::ezOptionParser& opt); int main(int argc, const char** argv) { - insecure("preprocessing"); + insecure_fake(); bigint::init_thread(); FakeParams params; diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp new file mode 100644 index 00000000..45e5f337 --- /dev/null +++ b/Utils/binary-example.cpp @@ -0,0 +1,140 @@ +/* + * binary-example.cpp + * + */ + +#include "GC/TinierSecret.h" +#include "GC/PostSacriSecret.h" +#include "GC/CcdSecret.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/AtlasSecret.h" +#include "GC/TinyMC.h" +#include "GC/VectorInput.h" +#include "GC/PostSacriBin.h" +#include "Protocols/ProtocolSet.h" + +#include "GC/ShareSecret.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/RepPrep.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ThreadMaster.hpp" +#include "Protocols/Atlas.hpp" +#include "Protocols/MaliciousRepPrep.hpp" +#include "Protocols/Share.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Shamir.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Machines/ShamirMachine.hpp" +#include "Machines/Rep4.hpp" + +template +void run(int argc, char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol [bit length [threshold]]]" + << endl; + exit(1); + } + + string protocol = "Tinier"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "Tinier") + run>(argc, argv); + else if (protocol == "Rep3") + run(argc, argv); + else if (protocol == "Rep4") + run(argc, argv); + else if (protocol == "PS") + run(argc, argv); + else if (protocol == "Semi") + run(argc, argv); + else if (protocol == "CCD" or protocol == "MalCCD" or protocol == "Atlas") + { + int nparties = (atoi(argv[2])); + int threshold = (nparties - 1) / 2; + if (argc > 5) + threshold = atoi(argv[5]); + assert(2 * threshold < nparties); + ShamirOptions::s().threshold = threshold; + ShamirOptions::s().nparties = nparties; + + if (protocol == "CCD") + run>>(argc, argv); + else if (protocol == "MalCCD") + run>(argc, argv); + else + run(argc, argv); + } + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(int argc, char** argv) +{ + // run 16-bit computation by default + int n_bits = 16; + if (argc > 4) + n_bits = atoi(argv[4]); + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + BinaryProtocolSetup setup(P); + + // set of protocols (input, multiplication, output) + BinaryProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; + + int n = 10; + vector a(n), b(n); + + input.reset_all(P); + for (int i = 0; i < n; i++) + input.add_from_all(i + P.my_num(), n_bits); + input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = input.finalize(0, n_bits); + b[i] = input.finalize(1, n_bits); + } + + protocol.init_mul(); + for (int i = 0; i < n; i++) + protocol.prepare_mul(a[i], b[i], n_bits); + protocol.exchange(); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + auto c = protocol.finalize_mul(n_bits); + output.prepare_open(c); + } + output.exchange(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + protocol.check(); + output.Check(P); +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp new file mode 100644 index 00000000..532d705e --- /dev/null +++ b/Utils/mixed-example.cpp @@ -0,0 +1,137 @@ +/* + * mixed-example.cpp + * + */ + +#include "Protocols/ProtocolSet.h" + +#include "Machines/SPDZ.hpp" +#include "Machines/Semi2k.hpp" +#include "Machines/Rep.hpp" +#include "Machines/Rep4.hpp" +#include "Machines/Atlas.hpp" + +template +void run(char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol]" + << endl; + exit(1); + } + + string protocol = "SPDZ2k"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "SPDZ2k") + run>(argv); + else if (protocol == "Semi2k") + run>(argv); + else if (protocol == "Rep3") + run>(argv); + else if (protocol == "Rep4") + run>(argv); + else if (protocol == "Atlas") + run>>(argv); + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(char** argv) +{ + // reduce batch size + OnlineOptions::singleton.bucket_size = 5; + OnlineOptions::singleton.batch_size = 100; + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + MixedProtocolSetup setup(P); + + // set of protocols (bit_input, multiplication, output) + MixedProtocolSet set(P, setup); + auto& output = set.output; + auto& bit_input = set.binary.input; + auto& bit_protocol = set.binary.protocol; + auto& bit_output = set.binary.output; + auto& prep = set.preprocessing; + + int n = 10; + int n_bits = 16; + vector a(n), b(n); + + // inputs in binary domain + bit_input.reset_all(P); + for (int i = 0; i < n; i++) + bit_input.add_from_all(i + P.my_num(), n_bits); + bit_input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = bit_input.finalize(0, n_bits); + b[i] = bit_input.finalize(1, n_bits); + } + + // compute AND in binary domain + bit_protocol.init_mul(); + for (int i = 0; i < n; i++) + bit_protocol.prepare_mul(a[i], b[i], n_bits); + bit_protocol.exchange(); + bit_protocol.check(); + bit_output.init_open(P, n * n_bits); + PointerVector> dabits; + for (int i = 0; i < n; i++) + { + auto c = bit_protocol.finalize_mul(n_bits); + + // mask result with dabits and open + for (int j = 0; j < n_bits; j++) + { + dabits.push_back({}); + auto& dabit = dabits.back(); + prep.get_dabit(dabit.first, dabit.second); + bit_output.prepare_open( + typename T::bit_type::part_type( + dabit.second.get_bit(0) + c.get_bit(j))); + } + } + bit_output.exchange(P); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + T res; + // unmask via XOR and recombine + for (int j = 0; j < n_bits; j++) + { + typename T::clear masked = bit_output.finalize_open().get_bit(0); + auto mask = dabits.next().first; + res += (mask - mask * masked * 2 + + T::constant(masked, P.my_num(), setup.get_mac_key())) + << j; + } + output.prepare_open(res); + } + output.exchange(P); + bit_output.Check(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + output.Check(P); +} diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 87247fee..9cae6953 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -11,8 +11,10 @@ #include "Machines/SPDZ.hpp" #include "Machines/MalRep.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/Semi2k.hpp" #include "Protocols/CowGearShare.h" #include "Protocols/CowGearPrep.hpp" +#include "Protocols/ProtocolSet.h" template void run(char** argv, int prime_length); @@ -42,6 +44,8 @@ int main(int argc, char** argv) run>>(argv, prime_length); else if (protocol == "SPDZ2k") run>(argv, 0); + else if (protocol == "Semi2k") + run>(argv, 0); else if (protocol == "Shamir" or protocol == "MalShamir") { int nparties = (atoi(argv[2])); @@ -74,35 +78,14 @@ void run(char** argv, int prime_length) Names N(my_number, n_parties, "localhost", port_base); CryptoPlayer P(N); - // initialize fields - T::clear::init_default(prime_length); - T::clear::next::init_default(prime_length, false); + // protocol setup (domain, MAC key if needed etc) + ProtocolSetup setup(P, prime_length); - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - - // output protocol - typename T::MAC_Check output(mac_key); - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - - // input protocol - typename T::Input input(processor, output); - - // multiplication protocol - typename T::Protocol protocol(P); + // set of protocols (input, multiplication, output) + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; int n = 1000; vector a(n), b(n); @@ -119,19 +102,23 @@ void run(char** argv, int prime_length) b[i] = input.finalize(1); } - protocol.init_dotprod(&processor); + protocol.init_dotprod(); for (int i = 0; i < n; i++) protocol.prepare_dotprod(a[i], b[i]); protocol.next_dotprod(); protocol.exchange(); c = protocol.finalize_dotprod(n); + + // protocol check before revealing results + protocol.check(); + output.init_open(P); output.prepare_open(c); output.exchange(P); result = output.finalize_open(); cout << "result: " << result << endl; - output.Check(P); - T::LivePrep::teardown(); + // result check after opening + output.Check(P); } diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp index 5aa85a05..517056e7 100644 --- a/Utils/stream-fake-mascot-triples.cpp +++ b/Utils/stream-fake-mascot-triples.cpp @@ -27,13 +27,18 @@ void* run(void* arg) int count = 0; while (true) { - gfpvar triple[3]; - for (int i = 0; i < 2; i++) - triple[i].randomize(G); - triple[2] = triple[0] * triple[1]; - for (int i = 0; i < 3; i++) - files.output_shares(triple[i]); - count++; + for (int i = 0; i < 100000; i++) + { + gfpvar triple[3]; + for (int i = 0; i < 2; i++) + triple[i].randomize(G); + triple[2] = triple[0] * triple[1]; + for (int i = 0; i < 3; i++) + files.output_shares(triple[i]); + count++; + } + // take a break to make them wait + sleep(1); } cerr << "failed after " << count << endl; return 0; @@ -41,7 +46,7 @@ void* run(void* arg) int main() { - insecure("preprocessing"); + insecure_fake(); typedef Share T; int nplayers = 2; int lgp = 128; diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 074fb340..749ba287 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -58,9 +58,6 @@ public: int get_n_worker_threads() { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } - - NamedCommStats comm_stats() - { return super::comm_stats() + player.comm_stats; } }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index e6ae6cda..647369a1 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -120,8 +120,3 @@ void YaoGarbler::process_receiver_inputs() receiver_input_keys.pop_front(); } } - -NamedCommStats YaoGarbler::comm_stats() -{ - return super::comm_stats() + player.comm_stats; -} diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 038fe432..0608336c 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -71,8 +71,6 @@ public: int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } - - NamedCommStats comm_stats(); }; inline YaoGarbler& YaoGarbler::s() diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h index ddaf3b9c..92f3ec61 100644 --- a/Yao/YaoWire.h +++ b/Yao/YaoWire.h @@ -23,6 +23,10 @@ public: static void xors(GC::Processor& processor, const vector& args, size_t start, size_t end); + template + static void andm(GC::Processor& processor, + const BaseInstruction& instruction); + void XOR(const YaoWire& left, const YaoWire& right) { key_ = left.key_ ^ right.key_; diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp index bb3b1406..aa04fe35 100644 --- a/Yao/YaoWire.hpp +++ b/Yao/YaoWire.hpp @@ -46,4 +46,24 @@ void YaoWire::xors(GC::Processor& processor, const vector& args, processor.xors(args, start, end); } +template +void YaoWire::andm(GC::Processor& processor, + const BaseInstruction& instruction) +{ + + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto &dest = processor.S[instruction.get_r(0) + i]; + int n = min(unsigned(unit), instruction.get_n() - i * unit); + dest.resize_regs(n); + for (int j = 0; j < n; j++) + if (processor.C[instruction.get_r(2) + i].get_bit(j)) + dest.get_reg(j) = + processor.S[instruction.get_r(1) + i].get_reg(j); + else + dest.get_reg(j).public_input(0); + } +} + #endif /* YAO_YAOWIRE_HPP_ */ diff --git a/doc/Doxyfile b/doc/Doxyfile index 771f8cf1..3dd29940 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/conf.py b/doc/conf.py index 57f730ad..86bb12d4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -21,7 +21,7 @@ exec(compile(open('gen-instructions.py').read(), 'gen', 'exec')) # -- Project information ----------------------------------------------------- project = u'MP-SPDZ' -copyright = u'2021, CSIRO\'s Data61' +copyright = u'2022, CSIRO\'s Data61' author = u'Marcel Keller' # The short X.Y version @@ -185,7 +185,8 @@ epub_exclude_files = ['search.html'] breathe_projects = {'mp-spdz': 'xml'} breathe_default_project = 'mp-spdz' import subprocess -subprocess.call('doxygen', shell=True) +if (subprocess.call('doxygen', shell=True)): + raise Exception('doxygen failed') def setup(app): app.add_css_file('custom.css') diff --git a/doc/index.rst b/doc/index.rst index d7a13e94..d2a2c4dc 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,10 +1,16 @@ Welcome to MP-SPDZ's documentation! =================================== -This documentation provides a reference to the most important -high-level functionality provided by the MP-SPDZ compiler. For a -tutorial and documentation on how to run programs, the -implemented protocols etc. see https://github.com/data61/MP-SPDZ. +If you're new to MP-SPDZ, consider the following: + +1. `Quickstart tutorial `_ +2. `Implemented protocols `_ +3. :ref:`troubleshooting` + +Unlike the `Readme +`_, this +documentation provides a reference for more detailed aspects of the +software. Compilation process ------------------- diff --git a/doc/io.rst b/doc/io.rst index 5184ab33..a4d00cee 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -83,6 +83,8 @@ covering both client code and server-side high-level code. :py:func:`Compiler.types.MultiArray.reveal_to_clients`. The same functions are available for :py:class:`~Compiler.types.sfix` and :py:class:`~Compiler.types.Array`, respectively. +See also :ref:`client ref` below. + Secret Shares ~~~~~~~~~~~~~ @@ -114,3 +116,11 @@ etc. Note also that all types based on :py:class:`~Compiler.types.sfix`) share the same memory, and that the address is only a base address. This means that vectors will be written to the memory starting at the given address. + +.. _client ref: + +Reference +~~~~~~~~~ + +.. doxygenclass:: Client + :members: diff --git a/doc/low-level.rst b/doc/low-level.rst index 0aaf3708..c70bf5b6 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -83,109 +83,24 @@ number of parties. .. code-block:: cpp - // initialize fields - T::clear::init_default(prime_length); + ProtocolSetup setup(P, prime_length); We have to use a specific prime for computation modulo a prime. This deterministically generates one of the desired length if necessary. For computation modulo a power of two, this does not do -anything. +anything. Some protocols use an information-theoretic tag that is +constant throughout the protocol. This code reads it from storage if +available or generates a fresh one otherwise. .. code-block:: cpp - T::clear::next::init_default(prime_length, false); + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; -For computation modulo a prime, it is more efficient to use Montgomery -representation, which is not compatible with the MASCOT offline phase -however. This line initializes another field instance for MASCOT -without using Montgomery representation. - -.. code-block:: cpp - - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - -Some protocols use an information-theoretic tag that is constant -throughout the protocol. This codes reads it from storage if available -or generates a fresh one otherwise. - -.. code-block:: cpp - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - -Many protocols for a dishonest majority use oblivious transfer. This -block runs a few instances to seed the oblivious transfer -extension. The resulting setup only works for one thread. For several -threads, you need to add sufficiently many instances to -:member:`ot_setups` and set :member:`BaseMachine::thread_num` -(thread-local) to a different consecutive number in every thread. - -.. code-block:: cpp - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - -To help keeping track of the required preprocessing, it is necessary -to initialize preprocessing instances with a :class:`DataPositions` -variable that will store the usage. - -.. code-block:: cpp - - // initialize binary computation - T::bit_type::mac_key_type::init_field(); - typename T::bit_type::mac_key_type binary_mac_key; - T::bit_type::part_type::read_or_generate_mac_key("", P, binary_mac_key); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, binary_mac_key, usage); - -While this example only uses arithmetic computation, you need to -initialize binary computation as well unless you use the compile-time -option ``NO_MIXED_CIRCUITS``. - -.. code-block:: cpp - - // output protocol - typename T::MAC_Check output(mac_key); - -Some output protocols use the MAC key to check the correctness. - -.. code-block:: cpp - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - -In this example we use live preprocessing, but it is also possible to -read preprocessing data from disk by using :class:`Sub_Data_Files` -instead. You can use a live preprocessing instances to generate -preprocessing data independently, but many protocols require that a -:class:`SubProcessor` instance has been created as well. The latter -essentially glues an instance of the output and the preprocessing -protocol together, which is necessary for Beaver-based multiplication -protocols. - -.. code-block:: cpp - - // input protocol - typename T::Input input(processor, output); - -Some input protocols depend on preprocessing and an output protocol, -which is reflect in the standard constructor. Other constructors are -available depending on the protocol. - -.. code-block:: cpp - - // multiplication protocol - typename T::Protocol protocol(P); - -This instantiates a multiplication protocol. :var:`P` is required -because some protocols start by exchanging keys for pseudo-random -secret sharing. +The :class:`ProtocolSet` contains one instance for every essential +protocol step. .. code-block:: cpp @@ -235,6 +150,14 @@ The initialization of the multiplication sets the preprocessing and output instances to use in Beaver multiplication. :func:`next_dotprod` separates dot products in the data preparation phase. +.. code-block:: cpp + + protocol.check(); + +Some protocols require a check of all multiplications up to a certain +point. To guarantee that outputs do not reveal secret information, it +has to be run before using the output protocol. + .. code-block:: cpp output.init_open(P); @@ -245,8 +168,8 @@ separates dot products in the data preparation phase. cout << "result: " << result << endl; output.Check(P); -The output protocol follows the same blueprint except that it is -necessary to call the checking in order to verify the outputs. +The output protocol follows the same blueprint as the multiplication +protocol. .. code-block:: cpp @@ -281,6 +204,9 @@ Domain Types the time of writing, 4, 8, 28, 40, 63, and 128 are supported if the storage type is large enough. + +.. _share-type-reference: + Share Types ------------ @@ -385,6 +311,28 @@ Share Types ``MaliciousShamirShare`` or ``MaliciousRep3Share``. +Protocol Setup +-------------- + +.. doxygenclass:: ProtocolSetup + :members: + +.. doxygenclass:: ProtocolSet + :members: + +.. doxygenclass:: BinaryProtocolSetup + :members: + +.. doxygenclass:: BinaryProtocolSet + :members: + +.. doxygenclass:: MixedProtocolSetup + :members: + +.. doxygenclass:: MixedProtocolSet + :members: + + Protocol Interfaces ------------------- diff --git a/doc/networking.rst b/doc/networking.rst index 16908681..a1c61b98 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -18,7 +18,7 @@ individually setting ports: coordination server being run as a thread of party 0. The hostname of the coordination server has to be given with the command-line parameter ``--hostname``, and the coordination server runs on the - base port number minus one, thus defaulting to 4999. Furthermore, you + base port number, thus defaulting to 5000. Furthermore, you can specify a party's listening port using ``--my-port``. 2. The parties read the information from a local file, which needs to @@ -40,7 +40,9 @@ change this by either using ``--encrypted/-e`` or If using encryption, the certificates (``Player-Data/*.pem``) must be the same on all hosts, and you have to run ``c_rehash Player-Data`` on -all of them. +all of them. ``Scripts/setup-ssl.sh`` can be used to generate the +necessary certificates. The common name has to be ``P`` +for computing parties and ``C`` for clients. .. _network-reference: diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 5fe8df1f..bcdbbd3a 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -7,8 +7,8 @@ domains (modulus other than two) only comes in three flavors throughout MP-SPDZ: Unknown prime modulus - This approach goes back to `Catrina and Saxena - `_. It crucially relies on + This approach goes back to `Catrina and de Hoogh + `_. It crucially relies on the use of secret random bits in the arithmetic domain. Enough such bits allow to mask a secret value so that it is secure to reveal the masked value. This can then be split in bits as it is diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 3dadcfae..1441e352 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -16,7 +16,7 @@ is a thread created by control flow instructions such as The exceptions to the general rule are edaBit generation with malicious security and AND triples with malicious security and honest -majority, both when use bucket size three. Bucket size three implies +majority, both when using bucket size three. Bucket size three implies batches of over a million to achieve 40-bit statistical security, and in honest-majority binary computation the item size is 64, which makes the actual batch size 64 million triples. In multithreaded programs, @@ -27,3 +27,65 @@ jump whenever another batch is generated. Note that, while some protocols are flexible with the batch size and can thus be controlled using ``-b``, others mandate a batch size, which can be as large as a million. + + +Separate preprocessing +====================== + +It is possible to separate out the preprocessing from the +input-dependent ("online") phase. This is done by either option ``-F`` +or ``-f`` on the virtual machines. In both cases, the preprocessing +data is read from files, either all data per type from a single file +(``-F``) or one file per thread (``-f``). The latter allows to use +named pipes. + +The file name depends on the protocol and the computation domain. It +is generally ``/--/--P[-T]``. For example, the +triples for party 1 in SPDZ modulo a 128-bit prime can be found in +``Player-Data/2-p-128/Triples-p-P1``. The protocol shorthand can be +found by calling ``::type_short()``. See +:ref:`share-type-reference` for a description of the share types. + +Preprocessing files start with a header describing the protocol and +computation domain to avoid errors due to mismatches. The header is as +follows: + +- Length to follow (little-endian 8-byte number) +- Protocol descriptor +- Domain descriptor + +The protocol descriptor is defined by ``::type_string()``. For SPDZ modulo a prime it is ``SPDZ gfp``. + +The domain descriptor depends on the kind of domain: + +Modulo a prime + Serialization of the prime + + - Sign bit (0 as 1 byte) + - Length to follow (little-endian 4-byte number) + - Prime (big-endian) + +Modulo a power of two: + Exponent (little-endian 4-byte number) + +:math:`GF(2^n)` + - Storage size in bytes (little-endian 8-byte number). Default is 16. + - :math:`n` (little-endian 4-byte number) + +As an example, the following output of ``hexdump -C`` describes SPDZ +modulo the default 128-bit prime +(170141183460469231731687303715885907969):: + + 00000000 1d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |........SPDZ gfp| + 00000010 00 10 00 00 00 80 00 00 00 00 00 00 00 00 00 00 |................| + 00000020 00 00 1b 80 01 |.....| + 00000025 + + +``Fake-Offline.x`` generates preprocessing data insecurely for a range +of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +sufficient preprocessing data for a specific high-level program with +MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 1c096d98..6a79ea19 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -1,3 +1,5 @@ +.. _troubleshooting: + Troubleshooting --------------- @@ -57,10 +59,23 @@ second batch is necessary the cost shoots up. Other preprocessing methods allow for a variable batch size, which can be changed using ``-b``. Smaller batch sizes generally reduce the communication cost while potentially increasing the number of communication rounds. Try -adding ``-b 10`` to the virtal machine (or script) arguments for very +adding ``-b 10`` to the virtual machine (or script) arguments for very short computations. +Disparities in round figures +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The number of virtual machine rounds given by the compiler are not an +exact prediction of network rounds but the number of relevant protocol +calls (such as multiplication, input, output etc) in the program. The +actual number of network rounds is determined by the choice of +protocol, which might use several rounds per protocol +call. Furthermore, communication at the beginning and the end of a +computation such as random key distribution and MAC checks further +increase the number of network rounds. + + Handshake failures ~~~~~~~~~~~~~~~~~~ @@ -82,8 +97,8 @@ use the client facility. Connection failures ~~~~~~~~~~~~~~~~~~~ -MP-SPDZ requires at least one TCP port per party to be open to other -parties. In the default setting, it's 4999 and 5000 on party 0, and +MP-SPDZ requires one TCP port per party to be open to other +parties. In the default setting, it's 5000 on party 0, and 5001 on party 1 etc. You change change the base port (5000) using ``--portnumbase`` and individual ports for parties using ``--my-port``. The scripts in use a random base port number, which you From aaa90a20bbf33e4ebac270968200f04255b46eca Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 12 Jan 2022 18:57:40 +1100 Subject: [PATCH 012/265] Don't overwrite persistence files at beginning. --- Processor/Machine.hpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index d7d1a3ec..909a8f3a 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -94,13 +94,6 @@ Machine::Machine(int my_number, Names& playerNames, load_schedule(progname_str); - // remove persistence if necessary - for (auto& prog : progs) - { - if (prog.writes_persistance) - ofstream(Binary_File_IO::filename(my_number), ios::out); - } - #ifdef VERBOSE progs[0].print_offline_cost(); #endif From 962919c3cf592fb2c79a6282b2d62842d35e1dc9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 13 Jan 2022 14:09:53 +1100 Subject: [PATCH 013/265] Bug in regint optimizer. --- Compiler/allocator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 7ce9896b..9871d97f 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -583,13 +583,6 @@ class RegintOptimizer: self.cache[inst.args[0]] = res instructions[i] = ldint(inst.args[0], res, add_to_prog=False) - elif isinstance(inst, addint_class): - if inst.args[1] in self.cache and \ - self.cache[inst.args[1]] == 0: - instructions[i] = inst.args[0].link(inst.args[2]) - elif inst.args[2] in self.cache and \ - self.cache[inst.args[2]] == 0: - instructions[i] = inst.args[0].link(inst.args[1]) elif isinstance(inst, IndirectMemoryInstruction): if inst.args[1] in self.cache: instructions[i] = inst.get_direct(self.cache[inst.args[1]]) @@ -606,7 +599,4 @@ class RegintOptimizer: if op == 0: instructions[i] = ldsi(inst.args[0], 0, add_to_prog=False) - elif op == 1: - instructions[i] = None - inst.args[0].link(inst.args[1]) instructions[:] = list(filter(lambda x: x is not None, instructions)) From f343d73b25ae6ddc62d47aaf9cd146362bfcbf47 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 13 Jan 2022 14:10:06 +1100 Subject: [PATCH 014/265] Bug in for_range_opt. --- Compiler/library.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Compiler/library.py b/Compiler/library.py index 7bab1951..4f6c2de1 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -12,6 +12,7 @@ import inspect,math import random import collections import operator +import copy from functools import reduce def get_program(): @@ -1031,12 +1032,14 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], state = tuplify(initializer()) k = 0 block = get_block() + pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ and block is get_block(): j = i + k state = reducer(tuplify(loop_body(j)), state) k += 1 + _link(pre, loop_body.__globals__) r = reducer(mem_state, state) write_state_to_memory(r) global n_opt_loops @@ -1395,9 +1398,12 @@ def do_loop(condition, loop_fn): def _run_and_link(function, g=None): if g is None: g = function.__globals__ - import copy pre = copy.copy(g) res = function() + _link(pre, g) + return res + +def _link(pre, g): if g: from .types import _single for name, var in pre.items(): @@ -1407,7 +1413,6 @@ def _run_and_link(function, g=None): raise CompilerError('cannot reassign constants in blocks') if id(new_var) != id(var): new_var.link(var) - return res def do_while(loop_fn, g=None): """ Do-while loop. The loop is stopped if the return value is zero. From 0f9d5de6979915ce35e57ea747bdb5214d2dfd61 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 12 Jan 2022 20:11:28 +1100 Subject: [PATCH 015/265] Allow overwriting of persistence files. --- Compiler/instructions.py | 5 +++-- Compiler/types.py | 42 +++++++++++++++++++++++++----------- Processor/Binary_File_IO.h | 3 ++- Processor/Binary_File_IO.hpp | 18 ++++++++++++++-- Processor/Instruction.hpp | 4 ++-- Processor/Machine.hpp | 13 +++++++++++ Processor/Processor.h | 2 +- Processor/Processor.hpp | 6 ++++-- 8 files changed, 70 insertions(+), 23 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 1533fc52..a85fb25a 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1727,14 +1727,15 @@ class writesharestofile(base.IOInstruction): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: number of shares (int) + :param: number of arguments to follow / number of shares plus one (int) + :param: position (regint, -1 for appending) :param: source (sint) :param: (repeat from source)... """ __slots__ = [] code = base.opcodes['WRITEFILESHARE'] - arg_format = itertools.repeat('s') + arg_format = tools.chain(['ci'], itertools.repeat('s')) def has_var_args(self): return True diff --git a/Compiler/types.py b/Compiler/types.py index 33df2e37..77de5d71 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2329,16 +2329,20 @@ class sint(_secret, _int): return stop, shares @staticmethod - def write_to_file(shares): + def write_to_file(shares, position=None): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). - :param: shares (list or iterable of sint) + :param shares: (list or iterable of sint) + :param position: start position (int/regint/cint), + defaults to end of file """ for share in shares: assert isinstance(share, sint) assert share.size == 1 - writesharestofile(*shares) + if position is None: + position = -1 + writesharestofile(regint.conv(position), *shares) @vectorized_classmethod def load_mem(cls, address, mem_type=None): @@ -3922,13 +3926,15 @@ class _single(_number, _secret_structure): return stop, [cls._new(x) for x in shares] @classmethod - def write_to_file(cls, shares): + def write_to_file(cls, shares, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. - :param: shares (list or iterable of sfix) + :param shares: (list or iterable of sfix) + :param position: start position (int/regint/cint), + defaults to end of file """ - cls.int_type.write_to_file([x.v for x in shares]) + cls.int_type.write_to_file([x.v for x in shares], position) def store_in_mem(self, address): """ Store in memory by public address. """ @@ -5389,11 +5395,14 @@ class Array(_vectorizable): self.assign(shares) return stop - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ - self.value_type.write_to_file(list(self)) + self.value_type.write_to_file(list(self), position) def __add__(self, other): """ Vector addition. @@ -5723,13 +5732,20 @@ class SubMultiArray(_vectorizable): def _(i): self[i].input_from(player, budget=budget, raw=raw) - def write_to_file(self): + def write_to_file(self, position=None): """ Write shares of integer representation to - ``Persistence/Transactions-P.data`` (appending at the end). + ``Persistence/Transactions-P.data``. + + :param position: start position (int/regint/cint), + defaults to end of file """ @library.for_range(len(self)) def _(i): - self[i].write_to_file() + if position is None: + my_pos = None + else: + my_pos = position + i * self[i].total_size() + self[i].write_to_file(my_pos) def read_from_file(self, start): """ Read content from ``Persistence/Transactions-P.data``. diff --git a/Processor/Binary_File_IO.h b/Processor/Binary_File_IO.h index 4e38cd16..c19a129a 100644 --- a/Processor/Binary_File_IO.h +++ b/Processor/Binary_File_IO.h @@ -27,7 +27,8 @@ class Binary_File_IO * Throws file_error. */ template - void write_to_file(const string filename, const vector< T >& buffer); + void write_to_file(const string filename, const vector& buffer, + long start_pos); /* * Read from posn in the filename the binary values until the buffer is full. diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index 9878f4a6..ef735279 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -14,18 +14,32 @@ inline string Binary_File_IO::filename(int my_number) } template -void Binary_File_IO::write_to_file(const string filename, const vector< T >& buffer) + +void Binary_File_IO::write_to_file(const string filename, + const vector& buffer, long start_pos) { ofstream outf; - outf.open(filename, ios::out | ios::binary | ios::app); + outf.open(filename, ios::out | ios::binary | ios::ate | ios::in); if (outf.fail()) { throw file_error(filename); } + if (start_pos != -1) + { + long write_pos = start_pos * T::size(); + // fill with zeros if needed + for (long i = outf.tellp(); i < write_pos; i++) + outf.put(0); + outf.seekp(write_pos); + } + for (unsigned int i = 0; i < buffer.size(); i++) { buffer[i].output(outf, false); } + if (outf.fail()) + throw runtime_error("failed writing to " + filename); + outf.close(); } diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e45a8504..25fa666f 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -273,7 +273,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_vector(2, start, s); break; // open instructions + read/write instructions with variable length args - case WRITEFILESHARE: case OPEN: case GOPEN: case MULS: @@ -376,6 +375,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case BITDECINT: case EDABIT: case SEDABIT: + case WRITEFILESHARE: num_var_args = get_int(s) - 1; r[0] = get_int(s); get_vector(num_var_args, start, s); @@ -1175,7 +1175,7 @@ inline void Instruction::execute(Processor& Proc) const break; case WRITEFILESHARE: // Write shares to file system - Proc.write_shares_to_file(start); + Proc.write_shares_to_file(Proc.read_Ci(r[0]), start); break; case READFILESHARE: // Read shares from file system diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 909a8f3a..a43c9d47 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -94,6 +94,19 @@ Machine::Machine(int my_number, Names& playerNames, load_schedule(progname_str); + // initialize persistence if necessary + for (auto& prog : progs) + { + if (prog.writes_persistance) + { + string filename = Binary_File_IO::filename(my_number); + ifstream pers(filename); + if (pers.fail()) + ofstream pers(filename, ios::binary); + break; + } + } + #ifdef VERBOSE progs[0].print_offline_cost(); #endif diff --git a/Processor/Processor.h b/Processor/Processor.h index a78058cd..c91b677b 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -239,7 +239,7 @@ class Processor : public ArithmeticProcessor // Read and write secret numeric data to file (name hardcoded at present) void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); - void write_shares_to_file(const vector& data_registers); + void write_shares_to_file(long start_pos, const vector& data_registers); cint get_inverse2(unsigned m); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index caea1e67..c55a6dfc 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -370,7 +370,9 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ // Append share data in data_registers to end of file. Expects Persistence directory to exist. template -void Processor::write_shares_to_file(const vector& data_registers) { +void Processor::write_shares_to_file(long start_pos, + const vector& data_registers) +{ string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -382,7 +384,7 @@ void Processor::write_shares_to_file(const vector& data_regist inpbuf[i] = get_Sp_ref(data_registers[i]); } - binary_file_io.write_to_file(filename, inpbuf); + binary_file_io.write_to_file(filename, inpbuf, start_pos); } template From 9fcffad831ca4c452e6d9c322ce49d47f64187ef Mon Sep 17 00:00:00 2001 From: jvmncs Date: Tue, 18 Jan 2022 08:49:00 -0500 Subject: [PATCH 016/265] approx_sigmoid is attributed to an earlier paper --- Compiler/ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 7e53a78f..5c4664be 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -104,7 +104,7 @@ def sigmoid_prime(x): @vectorize def approx_sigmoid(x, n=3): """ Piece-wise approximate sigmoid as in - `Dahl et al. `_ + `Hong et al. `_ :param x: input :param n: number of pieces, 3 (default) or 5 From fc3a2a0f320e27910bbfdba1c96ef64e6633337f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 24 Jan 2022 13:24:51 +1100 Subject: [PATCH 017/265] Personal array functionality. --- Compiler/types.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index 77de5d71..3fdc6cf0 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1698,6 +1698,12 @@ class personal(Tape._no_truth): def _div_san(self): return self._v.conv((library.get_player_id() == self.player)._v).if_else(self._v, 1) + def __setitem__(self, index, value): + self._san(value) + self._v[index] = value + + __getitem__ = lambda self, index: personal(self.player, self._v[index]) + __add__ = lambda self, other: personal(self.player, self._san(other) + other) __sub__ = lambda self, other: personal(self.player, self._san(other) - other) __mul__ = lambda self, other: personal(self.player, self._san(other) * other) @@ -5500,6 +5506,14 @@ class Array(_vectorizable): """ self.get_vector().binary_output(player) + def reveal_to(self, player): + """ Reveal secret array to :py:obj:`player`. + + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` containing an array + """ + return personal(player, self.create_from(self[:].reveal_to(player)._v)) + def sort(self, n_threads=None): """ Sort in place using Batchers' odd-even merge mergesort From 5584e1818dd52eefbe89b3b0e37aee52d11daa2d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 31 Jan 2022 14:29:44 +1100 Subject: [PATCH 018/265] Bugs in binary register conversion. --- GC/Secret.hpp | 2 +- Yao/YaoGarbleWire.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 01c70247..88f1926a 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -140,7 +140,7 @@ T& GC::Secret::get_new_reg() template void Secret::load_clear(int n, const Integer& x) { - if ((unsigned)n < 8 * sizeof(x) and abs(x.get()) > (1LL << n)) + if ((unsigned)n < 8 * sizeof(x) and (unsigned long) abs(x.get()) > (1ul << n)) throw out_of_range("public value too long"); #ifdef DEBUG_ROUNDS2 cout << "secret from integer " << hex << this << dec << " " << endl; diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 37931df4..05a8646d 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -241,7 +241,7 @@ void YaoGarbleWire::convcbit2s(GC::Processor& processor, int n = min(unsigned(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) - dest.get_reg(i).public_input( + dest.get_reg(j).public_input( processor.C[instruction.get_r(1) + i].get_bit(j)); } } From d50e97fde91369256caf339c0c55e19071d542f2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 1 Feb 2022 13:54:53 +1100 Subject: [PATCH 019/265] Simplify code. --- Compiler/GC/types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 53da15ba..94d52082 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -111,8 +111,7 @@ class bits(Tape.Register, _structure, _bit): if mem_type == 'sd': return cls.load_dynamic_mem(address) else: - for i in range(res.size): - cls.mem_op(cls.load_inst, res[i], address + i) + cls.mem_op(cls.load_inst, res, address) return res def store_in_mem(self, address): self.mem_op(self.store_inst, self, address) From 61d40b7d8392ee836e973df7a26bc53154c3d6a7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 4 Feb 2022 11:16:12 +1100 Subject: [PATCH 020/265] Fix bugs in mathematical functions using binary circuits. --- Compiler/GC/types.py | 10 +++++++++- Compiler/mpc_math.py | 10 +++++----- Compiler/types.py | 1 + 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 94d52082..13619c7f 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -811,7 +811,7 @@ class sbitvec(_vec): def store_in_mem(self, address): for i, x in enumerate(self.elements()): x.store_in_mem(address + i) - def bit_decompose(self, n_bits=None, security=None): + def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): @@ -1267,6 +1267,9 @@ class sbitfixvec(_fix): int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) clear_type = cbitfix + @property + def bit_type(self): + return type(self.v[0]) @classmethod def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) @@ -1284,6 +1287,8 @@ class sbitfixvec(_fix): if isinstance(value, (list, tuple)): self.v = self.int_type.from_vec(sbitvec([x.v for x in value])) else: + if isinstance(value, sbitvec): + value = self.int_type(value) super(sbitfixvec, self).__init__(value, *args, **kwargs) def elements(self): return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] @@ -1293,9 +1298,12 @@ class sbitfixvec(_fix): else: return super(sbitfixvec, self).mul(other) def __xor__(self, other): + if util.is_zero(other): + return self return self._new(self.v ^ other.v) def __and__(self, other): return self._new(self.v & other.v) + __rxor__ = __xor__ @staticmethod def multipliable(other, k, f, size): class cls(_fix): diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 322989b3..47253dc4 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -290,7 +290,7 @@ def exp2_fx(a, zero_output=False, as19=False): # how many bits to use from integer part n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits - sint = types.sint + sint = a.int_type if types.program.options.ring and not as19: intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k @@ -367,17 +367,17 @@ def exp2_fx(a, zero_output=False, as19=False): bits = a.v.bit_decompose(a.k, maybe_mixed=True) lower = sint.bit_compose(bits[:a.f]) higher_bits = bits[a.f:n_bits] - s = sint.conv(bits[-1]) + s = a.bit_type.conv(bits[-1]) bits_to_check = bits[n_bits:-1] if not as19: - c = types.sfix._new(lower, k=a.k, f=a.f) + c = a._new(lower, k=a.k, f=a.f) assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] d = floatingpoint.Pow2_from_bits(pow2_bits) g = exp_from_parts(d, c) - small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, + small_result = a._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, - nearest=types.sfix.round_nearest), + nearest=a.round_nearest), k=a.k, f=a.f) if zero_output: t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), diff --git a/Compiler/types.py b/Compiler/types.py index 3fdc6cf0..0063fdc1 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -4274,6 +4274,7 @@ class sfix(_fix): :params _v: int/float/regint/cint/sint/sfloat """ int_type = sint + bit_type = sintbit clear_type = cfix @vectorized_classmethod From 0f7020d791a667ede375aa365f109ac286e89d43 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 17 Feb 2022 13:21:19 +1100 Subject: [PATCH 021/265] Semi-honest computation based on threshold semi-homomorphic encryption. --- CHANGELOG.md | 13 +- CONFIG | 1 + Compiler/GC/instructions.py | 37 ++- Compiler/GC/types.py | 28 +- Compiler/allocator.py | 20 +- Compiler/instructions.py | 177 +++++++------ Compiler/instructions_base.py | 57 ++++- Compiler/library.py | 57 ++++- Compiler/ml.py | 351 +++++++++++++++++++++++--- Compiler/oram.py | 1 + Compiler/program.py | 11 +- Compiler/types.py | 97 +++++-- FHE/FHE_Keys.cpp | 16 +- FHE/FHE_Keys.h | 2 + FHE/FHE_Params.cpp | 15 ++ FHE/FHE_Params.h | 6 +- FHE/NTL-Subs.cpp | 11 +- FHE/NTL-Subs.h | 2 +- FHE/NoiseBounds.cpp | 5 +- FHE/Ring_Element.cpp | 1 + FHE/Rq_Element.cpp | 9 +- FHE/Rq_Element.h | 8 +- FHEOffline/DataSetup.cpp | 2 +- FHEOffline/Multiplier.cpp | 7 + FHEOffline/Multiplier.h | 3 + FHEOffline/PairwiseSetup.cpp | 14 +- FHEOffline/PairwiseSetup.h | 2 +- FHEOffline/SimpleDistDecrypt.cpp | 8 + FHEOffline/SimpleDistDecrypt.h | 1 + FHEOffline/TemiSetup.cpp | 59 +++++ FHEOffline/TemiSetup.h | 34 +++ GC/Memory.h | 2 +- GC/ShareSecret.h | 1 + GC/TinySecret.h | 1 + GC/instructions.h | 2 +- Machines/ShamirMachine.hpp | 1 + Machines/temi-party.cpp | 37 +++ Makefile | 5 +- Math/FixedVec.h | 5 - Math/Zp_Data.h | 2 +- Math/gf2n.cpp | 16 +- Math/mpn_fixed.h | 6 + Networking/Player.h | 1 + OT/BaseOT.cpp | 18 +- Processor/Binary_File_IO.hpp | 13 +- Processor/Input.h | 19 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 2 + Processor/Instruction.hpp | 46 ++-- Processor/Machine.hpp | 15 +- Processor/Memory.h | 4 +- Processor/Memory.hpp | 7 +- Processor/PrivateOutput.h | 12 +- Processor/PrivateOutput.hpp | 33 ++- Processor/Processor.h | 6 +- Processor/Processor.hpp | 36 ++- Processor/Program.cpp | 2 +- Processor/Program.h | 4 +- Processor/SpecificPrivateOutput.h | 65 +++++ Programs/Source/falcon_alex.mpc | 100 ++++++++ Programs/Source/keras_cifar_lenet.mpc | 45 ++++ Programs/Source/keras_mnist_dense.mpc | 3 +- Programs/Source/keras_mnist_lenet.mpc | 13 + Programs/Source/mnist_full_A.mpc | 6 + Programs/Source/mnist_full_C.mpc | 8 +- Protocols/Atlas.hpp | 6 + Protocols/Hemi.hpp | 2 +- Protocols/HemiMatrixPrep.h | 5 +- Protocols/HemiMatrixPrep.hpp | 68 +++-- Protocols/HemiPrep.h | 3 + Protocols/HemiPrep.hpp | 14 + Protocols/HemiShare.h | 1 + Protocols/LowGearKeyGen.hpp | 8 +- Protocols/MAC_Check.h | 17 +- Protocols/MAC_Check.hpp | 1 + Protocols/MAC_Check_Base.h | 4 + Protocols/MalRepRingShare.h | 4 +- Protocols/MaliciousRep3Share.h | 3 +- Protocols/MaliciousShamirPO.h | 3 +- Protocols/MaliciousShamirShare.h | 4 +- Protocols/MamaShare.h | 6 - Protocols/PostSacriRepFieldShare.h | 4 +- Protocols/PostSacriRepRingShare.h | 4 +- Protocols/ProtocolSet.h | 25 +- Protocols/Rep3Share.h | 7 +- Protocols/Rep3Share2k.h | 3 +- Protocols/Rep4Input.h | 1 - Protocols/Rep4Input.hpp | 6 - Protocols/Replicated.h | 7 - Protocols/Replicated.hpp | 6 +- Protocols/ReplicatedPrep.hpp | 33 ++- Protocols/ReplicatedPrivateOutput.h | 26 -- Protocols/ReplicatedPrivateOutput.hpp | 30 --- Protocols/Semi.h | 6 + Protocols/SemiInput.h | 29 +-- Protocols/SemiInput.hpp | 62 ++++- Protocols/Shamir.h | 1 - Protocols/Shamir.hpp | 36 +-- Protocols/ShamirInput.h | 7 +- Protocols/ShamirInput.hpp | 33 ++- Protocols/ShamirMC.h | 4 + Protocols/ShamirMC.hpp | 13 + Protocols/ShamirShare.h | 7 +- Protocols/Share.h | 1 + Protocols/ShareInterface.h | 1 + Protocols/SpdzWiseInput.h | 3 - Protocols/SpdzWiseInput.hpp | 18 -- Protocols/SpdzWiseMC.h | 2 +- Protocols/SpdzWisePrep.hpp | 1 - Protocols/TemiPrep.h | 72 ++++++ Protocols/TemiPrep.hpp | 129 ++++++++++ Protocols/TemiShare.h | 42 +++ Protocols/fake-stuff.hpp | 9 +- README.md | 33 ++- Scripts/prep-usage.py | 23 ++ Scripts/temi.sh | 8 + Scripts/test_tutorial.sh | 2 +- Tools/Buffer.h | 4 + Tools/Exceptions.cpp | 4 +- Tools/Exceptions.h | 2 +- Tools/octetStream.h | 2 + Utils/binary-example.cpp | 4 +- Utils/mixed-example.cpp | 4 +- Utils/paper-example.cpp | 4 +- doc/instructions.rst | 10 +- doc/low-level.rst | 5 + doc/non-linear.rst | 2 +- doc/preprocessing.rst | 34 ++- doc/requirements.txt | 1 + 129 files changed, 1973 insertions(+), 539 deletions(-) create mode 100644 FHEOffline/TemiSetup.cpp create mode 100644 FHEOffline/TemiSetup.h create mode 100644 Machines/temi-party.cpp create mode 100644 Processor/SpecificPrivateOutput.h create mode 100644 Programs/Source/falcon_alex.mpc create mode 100644 Programs/Source/keras_cifar_lenet.mpc delete mode 100644 Protocols/ReplicatedPrivateOutput.h delete mode 100644 Protocols/ReplicatedPrivateOutput.hpp create mode 100644 Protocols/TemiPrep.h create mode 100644 Protocols/TemiPrep.hpp create mode 100644 Protocols/TemiShare.h create mode 100755 Scripts/prep-usage.py create mode 100755 Scripts/temi.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b75d24f..6a0406a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,17 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.2.9 (Jan 11, 2021) +## 0.3.0 (Feb 17, 2022) + +- Semi-honest computation based on threshold semi-homomorphic encryption +- Batch normalization backward propagation +- AlexNet for CIFAR-10 +- Specific private output protocols +- Semi-honest additive secret sharing without communication +- Sending of personal values +- Allow overwriting of persistence files +- Protocol signature in persistence files + +## 0.2.9 (Jan 11, 2022) - Disassembler - Run-time parameter for probabilistic truncation error diff --git a/CONFIG b/CONFIG index ba6855ea..05b3683d 100644 --- a/CONFIG +++ b/CONFIG @@ -42,6 +42,7 @@ else AVX_OT = 1 endif else +ARCH = AVX_OT = 0 endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index fc64ae2d..ef9c14a3 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -497,7 +497,7 @@ class movsb(NonVectorInstruction): code = opcodes['MOVSB'] arg_format = ['sbw','sb'] -class trans(base.VarArgsInstruction): +class trans(base.VarArgsInstruction, base.DynFormatInstruction): """ Secret bit register vector transpose. The first destination vector will contain the least significant bits of all source vectors etc. @@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction): code = opcodes['TRANS'] is_vec = lambda self: True def __init__(self, *args): - self.arg_format = ['int'] + ['sbw'] * args[0] + \ - ['sb'] * (len(args) - 1 - args[0]) super(trans, self).__init__(*args) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + n = next(args) + for i in range(n): + yield 'sbw' + next(args) + while True: + try: + yield 'sb' + next(args) + except StopIteration: + break + class bitb(NonVectorInstruction): """ Copy fresh secret random bit to secret bit register. @@ -560,7 +572,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1]) class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, - base.Mergeable): + base.Mergeable, base.DynFormatInstruction): """ Copy private input to secret bit registers bit by bit. The input is read as floating-point number, multiplied by a power of two, rounded to an integer, and then decomposed into bits. @@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, code = opcodes['INPUTBVEC'] def __init__(self, *args, **kwargs): - self.arg_format = [] - for x in self.get_arg_tuples(args): - self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3) super(inputbvec, self).__init__(*args, **kwargs) + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + yield 'p' + for j in range(n - 3): + yield 'sbw' + yield 'int' + @staticmethod def get_arg_tuples(args): i = 0 @@ -590,10 +609,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, i += args[i] assert i == len(args) - def merge(self, other): - self.args += other.args - self.arg_format += other.arg_format - def add_usage(self, req_node): for x in self.get_arg_tuples(self.args): req_node.increment(('bit', 'input', x[2]), x[0] - 3) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 13619c7f..38c37a26 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -41,7 +41,7 @@ class bits(Tape.Register, _structure, _bit): return cls.types[length] @classmethod def conv(cls, other): - if isinstance(other, cls): + if isinstance(other, cls) and cls.n == other.n: return other elif isinstance(other, MemValue): return cls.conv(other.read()) @@ -246,14 +246,20 @@ class cbits(bits): assert n == res.n assert n == other.size cls.conv_cint_vec(cint(other, size=other.size), res) + @classmethod + def conv(cls, other): + if isinstance(other, cbits) and cls.n != None and \ + cls.n // cls.unit == other.n // cls.unit: + return other + else: + return super(cbits, cls).conv(other) types = {} def load_int(self, value): - if self.n <= 64: - tmp = regint(value) - elif value == self.long_one(): - tmp = cint(1, size=self.n) - else: - raise CompilerError('loading long integers to cbits not supported') + n_limbs = math.ceil(self.n / self.unit) + tmp = regint(size=n_limbs) + for i in range(n_limbs): + tmp[i].load_int(value % 2 ** self.unit) + value >>= self.unit self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) @@ -1163,14 +1169,14 @@ class cbitfix(object): @classmethod def _new(cls, value): res = cls() + if cls.k < value.unit: + bits = value.bit_decompose(cls.k) + sign = bits[-1] + value += (sign << (cls.k)) * -1 res.v = value return res def output(self): v = self.v - if self.k < v.unit: - bits = self.v.bit_decompose(self.k) - sign = bits[-1] - v += (sign << (self.k)) * -1 inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 9871d97f..cf2f13ef 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -403,6 +403,20 @@ class Merger: add_edge(last_input[t][1], n) last_input[t][0] = n + def keep_text_order(inst, n): + if inst.get_players() is None: + # switch + for x in list(last_input.keys()): + if isinstance(x, int): + add_edge(last_input[x][0], n) + del last_input[x] + keep_merged_order(instr, n, None) + elif last_input[None][0] is not None: + keep_merged_order(instr, n, None) + else: + for player in inst.get_players(): + keep_merged_order(instr, n, player) + for n,instr in enumerate(block.instructions): outputs,inputs = instr.get_def(), instr.get_used() @@ -427,7 +441,7 @@ class Merger: # will be merged if isinstance(instr, TextInputInstruction): - keep_merged_order(instr, n, TextInputInstruction) + keep_text_order(instr, n) elif isinstance(instr, RawInputInstruction): keep_merged_order(instr, n, RawInputInstruction) @@ -479,10 +493,6 @@ class Merger: last_print_str = n elif isinstance(instr, PublicFileIOInstruction): keep_order(instr, n, instr.__class__) - elif isinstance(instr, startprivateoutput_class): - keep_order(instr, n, startprivateoutput_class, 2) - elif isinstance(instr, stopprivateoutput_class): - keep_order(instr, n, stopprivateoutput_class, 2) elif isinstance(instr, prep_class): keep_order(instr, n, instr.args[0]) elif isinstance(instr, StackInstruction): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a85fb25a..e0679768 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -421,6 +421,10 @@ class use_matmul(base.Instruction): code = base.opcodes['USE_MATMUL'] arg_format = ['int','int','int','int'] + @classmethod + def get_usage(cls, args): + return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i} + class run_tape(base.Instruction): """ Start tape/bytecode file in another thread. @@ -1229,15 +1233,20 @@ class inverse(base.DataInstruction): @base.gf2n @base.vectorize class inputmask(base.Instruction): - r""" Load secret $s_i$ with the next input mask for player $p$ and - write the mask on player $p$'s private output. """ + """ Store fresh random input mask(s) in secret register (vector) and clear + register (vector) of the relevant player. + + :param: mask (sint) + :param: mask (cint, player only) + :param: player (int) + """ __slots__ = [] code = base.opcodes['INPUTMASK'] - arg_format = ['sw', 'p'] + arg_format = ['sw', 'cw', 'p'] field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[1]), \ + req_node.increment((self.field_type, 'input', self.args[2]), \ self.get_size()) @base.vectorize @@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[1::2]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[1::2] @base.vectorize class inputfix(base.TextInputInstruction): @@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction): arg_format = tools.cycle(['sw', 'int', 'p']) field_type = 'modp' - def add_usage(self, req_node): - for player in self.args[2::3]: - req_node.increment((self.field_type, 'input', player), \ - self.get_size()) + def get_players(self): + return self.args[2::3] @base.vectorize class inputfloat(base.TextInputInstruction): @@ -1322,7 +1327,7 @@ class inputfloat(base.TextInputInstruction): req_node.increment((self.field_type, 'input', player), \ 4 * self.get_size()) -class inputmixed_base(base.TextInputInstruction): +class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction): __slots__ = [] field_type = 'modp' # the following has to match TYPE: (N_DEST, N_PARAM) @@ -1341,22 +1346,30 @@ class inputmixed_base(base.TextInputInstruction): type_id = self.type_ids[name] super(inputmixed_base, self).__init__(type_id, *args) - @property - def arg_format(self): - for i in self.bases(): - t = self.args[i] - yield 'int' + @classmethod + def dynamic_arg_format(self, args): + yield 'int' + for i, t in self.bases(iter(args)): for j in range(self.types[t][0]): yield 'sw' for j in range(self.types[t][1]): yield 'int' yield self.player_arg_type + yield 'int' - def bases(self): + @classmethod + def bases(self, args): i = 0 - while i < len(self.args): - yield i - i += sum(self.types[self.args[i]]) + 2 + while True: + try: + t = next(args) + except StopIteration: + return + yield i, t + n = sum(self.types[t]) + i += n + 2 + for j in range(n + 1): + next(args) @base.vectorize class inputmixed(inputmixed_base): @@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base): player_arg_type = 'p' def add_usage(self, req_node): - for i in self.bases(): - t = self.args[i] + for i, t in self.bases(iter(self.args)): player = self.args[i + sum(self.types[t]) + 1] n_dest = self.types[t][0] req_node.increment((self.field_type, 'input', player), \ n_dest * self.get_size()) + def get_players(self): + for i, t in self.bases(iter(self.args)): + yield self.args[i + sum(self.types[t]) + 1] + @base.vectorize class inputmixedreg(inputmixed_base): """ Store private input in secret registers (vectors). The input is @@ -1412,6 +1428,9 @@ class inputmixedreg(inputmixed_base): # player 0 as proxy req_node.increment((self.field_type, 'input', 0), float('inf')) + def get_players(self): + pass + @base.gf2n @base.vectorize class rawinput(base.RawInputInstruction, base.Mergeable): @@ -1433,7 +1452,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable): req_node.increment((self.field_type, 'input', player), \ self.get_size()) -class inputpersonal(base.Instruction, base.Mergeable): +class personal_base(base.Instruction, base.Mergeable): + __slots__ = [] + field_type = 'modp' + + def __init__(self, *args): + super(personal_base, self).__init__(*args) + for i in range(0, len(args), 4): + assert args[i + 2].size == args[i] + assert args[i + 3].size == args[i] + + def add_usage(self, req_node): + for i in range(0, len(self.args), 4): + player = self.args[i + 1] + req_node.increment((self.field_type, 'input', player), \ + self.args[i]) + +class inputpersonal(personal_base): """ Private input from cint. :param: vector size (int) @@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable): __slots__ = [] code = base.opcodes['INPUTPERSONAL'] arg_format = tools.cycle(['int','p','sw','c']) - field_type = 'modp' + +class privateoutput(personal_base): + """ Private input from cint. + + :param: vector size (int) + :param: player (int) + :param: destination (cint) + :param: source (sint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['PRIVATEOUTPUT'] + arg_format = tools.cycle(['int','p','cw','s']) + +class sendpersonal(base.Instruction, base.Mergeable): + """ Private input from cint. + + :param: vector size (int) + :param: destination player (int) + :param: destination (cint) + :param: source player (int) + :param: source (cint) + :param: (repeat from vector size)... + """ + __slots__ = [] + code = base.opcodes['SENDPERSONAL'] + arg_format = tools.cycle(['int','p','cw','p','c']) def __init__(self, *args): - super(inputpersonal, self).__init__(*args) - for i in range(0, len(args), 4): + super(sendpersonal, self).__init__(*args) + for i in range(0, len(args), 5): assert args[i + 2].size == args[i] - assert args[i + 3].size == args[i] - - def add_usage(self, req_node): - for i in range(0, len(self.args), 4): - player = self.args[i + 1] - req_node.increment((self.field_type, 'input', player), \ - self.args[i]) + assert args[i + 4].size == args[i] @base.gf2n @base.vectorize @@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction): code = base.opcodes['FLOATOUTPUT'] arg_format = ['p','c','c','c','c'] -@base.gf2n -@base.vectorize -class startprivateoutput(base.Instruction): - r""" Initiate private output to $n$ of $s_j$ via $s_i$. """ - __slots__ = [] - code = base.opcodes['STARTPRIVATEOUTPUT'] - arg_format = ['sw','s','p'] - field_type = 'modp' - - def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[2]), \ - self.get_size()) - -@base.gf2n -@base.vectorize -class stopprivateoutput(base.Instruction): - r""" Previously iniated private output to $n$ via $c_i$. """ - __slots__ = [] - code = base.opcodes['STOPPRIVATEOUTPUT'] - arg_format = ['cw','c','p'] - @base.vectorize class rand(base.Instruction): """ Store insecure random value of specified length in clear integer @@ -2210,7 +2244,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction): @base.gf2n @base.vectorize -class dotprods(base.VarArgsInstruction, base.DataInstruction): +class dotprods(base.VarArgsInstruction, base.DataInstruction, + base.DynFormatInstruction): """ Dot product of secret registers (vectors). Note that the vectorized version works element-wise. @@ -2238,31 +2273,29 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction): flat_args += [x, y] base.Instruction.__init__(self, *flat_args) - @property - def arg_format(self): + @classmethod + def dynamic_arg_format(self, args): field = 'g' if self.is_gf2n() else '' - for i in self.bases(): - yield 'int' + yield 'int' + for i, n in self.bases(args): yield 's' + field + 'w' - for j in range(self.args[i] - 2): + for j in range(n - 2): yield 's' + field + yield 'int' - gf2n_arg_format = arg_format - - def bases(self): - i = 0 - while i < len(self.args): - yield i - i += self.args[i] + @property + def gf2n_arg_format(self): + return self.arg_format() def get_repeat(self): - return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 + for i, n in self.bases(iter(self.args))) * self.get_size() def get_def(self): - return [self.args[i + 1] for i in self.bases()] + return [self.args[i + 1] for i, n in self.bases(iter(self.args))] def get_used(self): - for i in self.bases(): + for i, n in self.bases(iter(self.args)): for reg in self.args[i + 2:i + self.args[i]]: yield reg diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index fb2a67b8..d6c647ad 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -105,6 +105,7 @@ opcodes = dict( MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -128,6 +129,7 @@ opcodes = dict( INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -364,6 +366,7 @@ def gf2n(instruction): arg_format = copy.deepcopy(instruction_cls.arg_format) reformat(arg_format) + @classmethod def is_gf2n(self): return True @@ -505,8 +508,12 @@ def cisc(function): for arg in self.args: try: new_regs.append(type(arg)(size=size)) - except: + except TypeError: break + except: + print([call[0][0].size for call in self.calls]) + raise + assert len(new_regs) > 1 base = 0 for call in self.calls: for new_reg, reg in zip(new_regs[1:], call[0][1:]): @@ -854,6 +861,7 @@ class Instruction(object): def is_vec(self): return False + @classmethod def is_gf2n(self): return False @@ -902,6 +910,10 @@ class Instruction(object): new_args.append(arg) return new_args + @staticmethod + def get_usage(args): + return {} + # String version of instruction attempting to replicate encoded version def __str__(self): @@ -949,9 +961,18 @@ class ParsedInstruction: if name == 'cisc': arg_format = itertools.chain(['str'], itertools.repeat('int')) else: - arg_format = itertools.repeat('int') - self.args = [ArgFormats[next(arg_format)](f) - for i in range(n_args)] + def arg_iter(): + i = 0 + while True: + try: + yield self.args[i].i + except AttributeError: + yield None + i += 1 + arg_format = t.dynamic_arg_format(arg_iter()) + self.args = [] + for i in range(n_args): + self.args.append(ArgFormats[next(arg_format)](f)) def __str__(self): name = self.type.__name__ @@ -963,6 +984,9 @@ class ParsedInstruction: res += ', '.join(str(arg) for arg in self.args) return res + def get_usage(self): + return self.type.get_usage(self.args) + class VarArgsInstruction(Instruction): def has_var_args(self): return True @@ -974,6 +998,26 @@ class VectorInstruction(Instruction): def get_code(self): return super(VectorInstruction, self).get_code(len(self.args[0])) +class DynFormatInstruction(Instruction): + __slots__ = [] + + @property + def arg_format(self): + return self.dynamic_arg_format(iter(self.args)) + + @classmethod + def bases(self, args): + i = 0 + while True: + try: + n = next(args) + except StopIteration: + return + yield i, n + i += n + for j in range(n - 1): + next(args) + ### ### Basic arithmetic ### @@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction): """ Input from text file or stdin """ __slots__ = [] + def add_usage(self, req_node): + for player in self.get_players(): + req_node.increment((self.field_type, 'input', player), \ + self.get_size()) + ### ### Data access instructions ### diff --git a/Compiler/library.py b/Compiler/library.py index 4f6c2de1..3f31499b 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -223,7 +223,7 @@ def crash(condition=None): if isinstance(condition, localint): # allow crash on local values condition = condition._v - if condition == None: + if condition is None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -284,8 +284,8 @@ def get_arg(): def make_array(l): if isinstance(l, program.Tape.Register): - res = Array(1, type(l)) - res[0] = l + res = Array(len(l), type(l)) + res[:] = l else: l = list(l) res = Array(len(l), type(l[0]) if l else cint) @@ -1032,6 +1032,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], state = tuplify(initializer()) k = 0 block = get_block() + assert not isinstance(n_loops, int) or n_loops > 0 pre = copy.copy(loop_body.__globals__) while (not util.is_constant(n_loops) or k < n_loops) \ and (len(get_block()) < budget or k == 0) \ @@ -1211,7 +1212,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ if t != regint: raise CompilerError('Not implemented for other than regint') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') - state = tuple(initializer()) + state = initializer() + if len(state) == 0: + state_type = cint + elif isinstance(state, (tuple, list)): + state_type = type(state[0]) + else: + state_type = type(state) def f(inc): base = args[get_arg()][0] if not util.is_constant(thread_rounds): @@ -1224,8 +1231,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) - mem_state = Array(len(state), type(state[0]) \ - if state else cint, args[get_arg()][1]) + mem_state = Array(len(state), state_type, args[get_arg()][1]) @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): @@ -1257,14 +1263,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) - if state: + if len(state): if thread_rounds: for i in range(n_threads - remainder): - state = reducer(Array(len(state), type(state[0]), \ + state = reducer(Array(len(state), state_type, \ args[remainder + i][1]), state) if remainder: for i in range(remainder): - state = reducer(Array(len(state), type(state[0]).reg_type, \ + state = reducer(Array(len(state), state_type, \ args[i][1]), state) def returner(): return untuplify(state) @@ -1300,6 +1306,39 @@ def map_sum_opt(n_threads, n_loops, types): """ return map_sum(n_threads, None, n_loops, len(types), types) +def map_sum_simple(n_threads, n_loops, type, size): + """ Vectorized multi-threaded sum reduction. The following computes a + 100 sums of ten squares in three threads:: + + @map_sum_simple(3, 10, sint, 100) + def summer(i): + return sint(regint.inc(100, i, 0)) ** 2 + + result = summer() + + :param n_threads: number of threads (int) + :param n_loops: number of loop runs (regint/cint/int) + :param type: return type, must match the return statement + in the loop + :param size: vector size, must match the return statement + in the loop + + """ + initializer = lambda: type(0, size=size) + def summer(*args): + assert len(args) == 2 + args = list(args) + for i in (0, 1): + if isinstance(args[i], tuple): + assert len(args[i]) == 1 + args[i] = args[i][0] + for i in (0, 1): + assert len(args[i]) == size + if isinstance(args[i], Array): + args[i] = args[i][:] + return args[0] + args[1] + return map_reduce(n_threads, 1, n_loops, initializer, summer) + def tree_reduce_multithread(n_threads, function, vector): inputs = vector.Array(len(vector)) inputs.assign_vector(vector) diff --git a/Compiler/ml.py b/Compiler/ml.py index 5c4664be..c521934f 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -223,6 +223,7 @@ class Layer: thetas = lambda self: () debug_output = False back_batch_size = 128 + print_random_update = False @property def shape(self): @@ -254,6 +255,9 @@ class Layer: def __str__(self): return type(self).__name__ + str(self._Y.sizes) + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.Y.sizes) + class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None output_weights = lambda *args: None @@ -459,6 +463,10 @@ class MultiOutput(MultiOutputBase): self.debug = debug self.true_X = sfix.Array(N) + def __repr__(self): + return '%s(%s, %s, approx=%s)' % \ + (type(self).__name__, self.N, self.d_out, self.approx) + def _forward(self, batch): N = len(batch) d_out = self.X.sizes[1] @@ -609,10 +617,11 @@ class DenseBase(Layer): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + @multithread(self.n_threads, self.d_in) def _(base, size): - A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), batch.get_vector(), @@ -622,16 +631,24 @@ class DenseBase(Layer): progress('nabla W (matmul)') - if self.d_in * self.d_out < 200000: - print('reduce at once') - @multithread(self.n_threads, self.d_in * self.d_out) - def _(base, size): - self.nabla_W.assign_vector( - tmp.get_vector(base, size).reduce_after_mul(), base=base) - else: - @for_range_opt(self.d_in) - def _(i): - self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() + @multithread(self.n_threads, self.d_in * self.d_out, + max_size=get_program().budget) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % self.d_in + j = regint.get_random(64) % self.d_out + print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', + str(self.nabla_W), i, j, tmp[i][j].v.reveal(), + self.nabla_W[i][j].reveal(), + A.get_column(j).reveal(), + B.get_column_by_row_indices( + batch.get_vector(), i).reveal()) + print_ln('batch=%s B=%s', batch, + [self.X[bi][0][i].reveal() for bi in batch]) progress('nabla W') @@ -699,6 +716,7 @@ class Dense(DenseBase): self.d_in = d_in self.d_out = d_out self.d = d + self.activation = activation self.X = MultiArray([N, d, d_in], sfix) self.Y = MultiArray([N, d, d_out], sfix) @@ -721,12 +739,17 @@ class Dense(DenseBase): else: self.f_input = self.Y + def __repr__(self): + return '%s(%s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, + self.d_out, repr(self.activation)) + def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) print('Initializing dense weights in [%f,%f]' % (-r, r)) - self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) + self.W.randomize(-r, r) self.b.assign_all(0) def input_from(self, player, raw=False): @@ -820,6 +843,12 @@ class Dense(DenseBase): regint.inc(self.d_in))), base) + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + progress('nabla X') self.backward_params(f_schur_Y, batch=batch) @@ -890,6 +919,10 @@ class Dropout(NoVariableLayer): self.alpha = alpha self.B = MultiArray([N, d1, d2], sint) + def __repr__(self): + return '%s(%s, %s, alpha=%s)' % \ + (type(self).__name__, self.N, self.d1, self.alpha) + def forward(self, batch, training=False): if training: n_bits = -math.log(self.alpha, 2) @@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer): def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), padding='VALID'): assert len(shape) == 4 + assert min(shape) > 0, shape for x in strides, ksize: for i in 0, 3: assert x[i] == 1 @@ -1033,12 +1067,18 @@ class MaxPool(NoVariableLayer): self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize + self.padding = padding self.nabla_X = Tensor(shape, sfix) self.nabla_Y = Tensor(output_shape, sfix) self.N = shape[0] self.comparisons = MultiArray([self.N, self.X.sizes[3], ksize[1] * ksize[2]], sint) + def __repr__(self): + return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \ + (type(self).__name__, self.X.sizes, self.strides, + self.ksize, self.padding) + def _forward(self, batch): def process(pool, bi, k, i, j): def m(a, b): @@ -1165,7 +1205,7 @@ class Add(NoVariableLayer): self.Y[batch[0]].assign_vector(tmp, base) class FusedBatchNorm(Layer): - """ Fixed-point fused batch normalization layer. + """ Fixed-point fused batch normalization layer (inference only). :param shape: input/output shape (tuple/list of four int) """ @@ -1192,6 +1232,153 @@ class FusedBatchNorm(Layer): self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() + self.bias.get_vector()) +class BatchNorm(Layer): + """ Fixed-point batch normalization layer. + + :param shape: input/output shape (tuple/list of four int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=True, args=None): + assert len(shape) in (2, 3, 4) + if len(shape) == 4: + shape = [shape[0], shape[1] * shape[2], shape[3]] + elif len(shape) == 2: + shape = [shape[0], 1, shape[1]] + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.var, self.mu, self.weights, self.bias = arrays + arrays = (sfix.Array(shape[2]) for i in range(4)) + self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays + self.epsilon = 2 ** (-sfix.f + 1) + self.momentum = 0.1 + if args != None: + approx = 'precisebn' not in args + self.approx = approx + if approx: + print('Approximate square root inverse in batch normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in batch normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): + self.bias.assign_all(0) + self.weights.assign_all(1) + self.mu_hat.assign_all(0) + self.var_hat.assign_all(0) + + def _output(self, batch, mu, var): + factor = sfix.Array(len(mu)) + factor[:] = self.InvertSqrt(var[:] + self.epsilon) + @for_range_opt_multithread(self.n_threads, + [len(batch), self.X.sizes[1]]) + def _(i, j): + tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:] + self.Y[i][j][:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + if training: + d = self.X.sizes[1] + d_in = self.X.sizes[2] + s = sfix.Array(d_in) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.X[batch[i]][j].get_vector()) + s.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.mu.assign_vector( + s.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + item = self.X[batch[i]][j].get_vector() + return ((item - self.mu[:]) ** 2) + self.var.assign(_()) + @multithread(self.n_threads, d_in) + def _(base, size): + self.var.assign_vector( + self.var.get_vector(base, size) / (len(batch) * d - 1), + base) + for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var): + x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:] + self._output(batch, self.mu, self.var) + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + else: + self._output(batch, self.mu_hat, self.var_hat) + + def backward(self, batch, compute_nabla_X=True): + factor = Array.create_from( + self.InvertSqrt(self.var[:] + self.epsilon)) + mynYf = self.X.same_shape() + gamnY = self.X.same_shape() + gamnYd = self.X.same_shape() + nYdf = self.X.same_shape() + d = self.X.sizes[1] + d_in = self.X.sizes[2] + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + tmp = self.weights[:] * self.nabla_Y[i][j][:] + gamnY[i][j] = tmp + gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:]) + mynYf[i][j] = tmp * factor[:] + nYdf[i][j] = self.nabla_Y[i][j][:] * \ + (self.X[i][j][:] - self.mu[:]) * factor[:] + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (self.nabla_Y[i][j][:]) + self.nabla_bias.assign(_()) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (nYdf[i][j]) + self.nabla_weights.assign(_()) + factor3 = Array.create_from(factor[:] ** 3) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (mynYf[i][j]) + s1 = Array.create_from(_()) + @multithread(self.n_threads, len(s1)) + def _(base, size): + s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base) + @map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in) + def _(i, j): + return (gamnYd[i][j][:] * factor3[:]) + s2 = Array.create_from(_()) + @multithread(self.n_threads, len(s2)) + def _(base, size): + s2.assign_vector( + s2.get_vector(base, size) / (len(batch) * d - 1), base) + @for_range_opt_multithread(self.n_threads, [len(batch), d]) + def _(i, j): + self.nabla_X[i][j][:] = mynYf[i][j][:] \ + - s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:] + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.nabla_bias, self.nabla_weights: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k, + self.nabla_Y[i][j][k].reveal(), + self.nabla_X[i][j][k].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1298,6 +1485,8 @@ class ConvBase(BaseLayer): self.padding.append(pad_total // 2) elif padding == 'VALID': self.padding = [0, 0] + elif isinstance(padding, int): + self.padding = [padding, padding] else: self.padding = padding @@ -1323,6 +1512,12 @@ class ConvBase(BaseLayer): assert(len(output_shape) == 4) assert(len(weight_shape) == 4) + def __repr__(self): + return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \ + (type(self).__name__, self.X.sizes, self.weight_shape, + self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), + self.tf_weight_format) + def input_from(self, player, raw=False): self.input_params_from(player) self.weights.input_from(player, budget=100000, raw=raw) @@ -1545,20 +1740,20 @@ class FixConv2d(Conv2d, FixBase): self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) if compute_nabla_X: - assert tuple(self.padding) == (0, 0) assert tuple(self.stride) == (1, 1) reverse_weights = MultiArray( [n_channels_in, weights_h, weights_w, n_channels_out], sfix) - @for_range(n_channels_out) - def _(i): + @for_range_opt_multithread(self.n_threads, n_channels_in) + def _(l): @for_range(weights_h) def _(j): @for_range(weights_w) def _(k): - @for_range(n_channels_in) - def _(l): - reverse_weights[l][weights_h-j-1][k][i] = \ - self.weights[i][j][weights_w-k-1][l] + addresses = regint.inc(n_channels_out, + self.weights[0][j][weights_w-k-1].get_address(l), + reduce(operator.mul, self.weights.sizes[1:])) + reverse_weights[l][weights_h-j-1][k].assign_vector( + self.weights.value_type.load_mem(addresses)) padded_w = inputs_w + 2 * padding_w padded_h = inputs_h + 2 * padding_h if padding_h or padding_w: @@ -1579,14 +1774,16 @@ class FixConv2d(Conv2d, FixBase): unreduced_sfix._new(res).reduce_after_mul(), i, None, None, j) if padding_h or padding_w: - @for_range(N) + @for_range_opt_multithread(self.n_threads, N) def _(i): @for_range(inputs_h) def _(j): @for_range(inputs_w) def _(k): + jj = j + padding_w + kk = k + padding_w self.nabla_X[i][j][k].assign_vector( - output[i][j][k].get_vector()) + output[i][jj][kk].get_vector()) if self.debug_output: @for_range(len(batch)) @@ -1806,6 +2003,7 @@ class Optimizer: self.report_loss = report_loss self.X_by_label = None self.print_update_average = False + self.print_random_update = False self.print_losses = False self.print_loss_reduction = False self.i_epoch = MemValue(0) @@ -1846,6 +2044,7 @@ class Optimizer: def batch_for(self, layer, batch): if layer in (self.layers[0], self.layers[-1]): + assert not isinstance(layer, BatchNorm) return batch else: batch = regint.Array(len(batch)) @@ -1876,6 +2075,21 @@ class Optimizer: if i != len(self.layers) - 1 or run_last: layer.forward(batch=self.batch_for(layer, batch), training=training) + if self.print_random_update: + print_ln('forward layer %s', layer) + l = min(100, layer.Y[i].total_size()) + i = regint.get_random(64) % len(batch) + if l < 100: + j = 0 + else: + j = regint.get_random(64) % \ + (layer.Y[i].total_size() - l) + print_ln('forward layer %s at (%s, %s): %s', layer, i, j, + layer.Y[i].to_array().get_vector(j, l).reveal()) + i = regint.get_random(64) % layer.Y[0].total_size() + print_ln('forward layer %s vertical at %s: %s', layer, i, + [layer.Y[j].to_array()[i].reveal() + for j in range(len(batch))]) if self.time_layers: stop_timer(100 + i) break_point() @@ -1979,7 +2193,11 @@ class Optimizer: label * n) self.forward(batch=batch, training=True) self.backward(batch=batch) + if self.time_layers: + start_timer(1000) self.update(i, batch=batch) + if self.time_layers: + stop_timer(1000) loss_sum.iadd(self.layers[-1].l) if self.print_loss_reduction: before = self.layers[-1].average_loss(N) @@ -2070,6 +2288,8 @@ class Optimizer: if 'nomom' in program.args: self.momentum = 0 self.print_losses = 'print_losses' in program.args + self.print_random_update = 'print_random_update' in program.args + Layer.print_random_update = self.print_random_update self.time_layers = 'time_layers' in program.args self.revealing_correctness = not 'no_acc' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args @@ -2099,6 +2319,16 @@ class Optimizer: print_ln('loss %s', self.layers[-1].l.reveal()) self.output_weights() return + if 'bench10' in program.args or 'bench1' in program.args: + n = 1 if 'bench1' in program.args else 10 + print('benchmarking %s iterations' % n) + @for_range(n) + def _(i): + batch = Array.create_from(regint.inc(batch_size)) + self.forward(batch=batch, training=True) + self.backward(batch=batch) + self.update(0, batch=batch) + return @for_range(n_runs) def _(i): if not acc_first: @@ -2115,6 +2345,7 @@ class Optimizer: cfix(self.n_correct, k=63, f=31) / n_trained, self.n_correct, n_trained) if test_X and test_Y: + print('use test set') n_test = len(test_Y) n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size) @@ -2211,7 +2442,8 @@ class Adam(Optimizer): util.max, abs_g.get_vector()) scale = MemValue(sfix._new(library.AppRcr( max_g.v, max_g.k, max_g.f, simplex_flag=True))) - @multithread(self.n_threads, m.total_size()) + @multithread(self.n_threads, m.total_size(), + max_size=get_program().budget) def _(base, size): m_part = m.get_vector(base, size) v_part = v.get_vector(base, size) @@ -2333,20 +2565,33 @@ class SGD(Optimizer): print_ln_if((x > limit) + (x < -limit), 'theta epoch=%s %s index=%s %s', i_epoch.read(), str(theta), i, x) - index = regint.get_random(64) % len(a) - print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, - aa[1][index], aa[0][index], aa[2][index]) + if self.print_random_update: + print_ln('update') + l = min(100, nabla.total_size()) + if l < 100: + index = 0 + else: + index = regint.get_random(64) % (nabla.total_size() - l) + print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), + index, nabla.to_array().get_vector(index, l).reveal(), + delta_theta.to_array().get_vector(index, l).reveal(), + theta.to_array().get_vector(index, l).reveal()) self.gamma.imul(1 - 10 ** - 6) def apply_padding(input_shape, kernel_size, strides, padding): + if isinstance(padding, int): + input_shape = [x + 2 * padding for x in input_shape] + padding = 'valid' if padding == 'valid': - return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ + res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \ (input_shape[1] - kernel_size[1] + 1) // strides[1], + assert min(res) > 0, (input_shape, kernel_size, strides, padding) + return res elif padding == 'same': - return (input_shape[1]) // strides[0], \ - (input_shape[2]) // strides[1], + return (input_shape[0]) // strides[0], \ + (input_shape[1]) // strides[1], else: - raise Exception('invalid padding: ' + padding) + raise Exception('invalid padding: %s' % padding) class keras: class layers: @@ -2354,7 +2599,7 @@ class keras: Dense = lambda *args, **kwargs: ('dense', args, kwargs) def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', - activation=None): + activation=None, input_shape=None): return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, 'strides': strides, 'padding': padding, 'activation': activation} @@ -2369,6 +2614,13 @@ class keras: raise Exception('rate needs to be a power of two') return 'dropout', rate + def Activation(activation): + assert(activation == 'relu') + return activation, + + def BatchNormalization(): + return 'batchnorm', + class optimizers: SGD = lambda *args, **kwargs: ('sgd', args, kwargs) Adam = lambda *args, **kwargs: ('adam', args, kwargs) @@ -2383,12 +2635,25 @@ class keras: def compile(self, optimizer): self.optimizer = optimizer + def compile_by_args(self, program): + if 'adam' in program.args: + self.optimizer = 'adam', [], {} + elif 'amsgrad' in program.args: + self.optimizer = 'adam', [], {'amsgrad': True} + else: + self.optimizer = 'sgd', [], {} + @property def trainable_variables(self): if self.opt == None: raise Exception('need to run build() or fit() first') return list(self.opt.thetas) + def summary(self): + sizes = [var.total_size() for var in self.trainable_variables] + print(sizes) + print('Trainable params:', sum(sizes)) + def build(self, input_shape, batch_size=128): data_input_shape = input_shape if self.opt != None and \ @@ -2415,12 +2680,11 @@ class keras: if i == len(self.layers) - 1: if layer[2].get('activation', 'softmax') in \ ('softmax', 'sigmoid'): - del layer[2]['activation'] + layer[2].pop('activation', None) layers.append(Dense(N, n_units, layer[1][0], **layer[2])) + input_shape = layers[-1].Y.sizes elif name == 'conv2d': - if len(layers) != 0: - input_shape = layers[-1].Y.sizes input_shape = list(input_shape) + \ [1] * (4 - len(input_shape)) print (layer[1]) @@ -2437,9 +2701,13 @@ class keras: output_shape = [batch_size] + list( apply_padding(input_shape[1:3], kernel_size, strides, padding)) + [filters] + padding = padding.upper() if isinstance(padding, str) \ + else padding layers.append(FixConv2d(input_shape, weight_shape, (filters,), output_shape, - strides, padding.upper())) + strides, padding)) + input_shape = output_shape + print('conv output shape', output_shape) elif name == 'maxpool': pool_size = layer[1]['pool_size'] strides = layer[1]['strides'] @@ -2450,16 +2718,23 @@ class keras: strides = (strides, strides) if strides == None: strides = pool_size - layers.append(MaxPool(layers[-1].Y.sizes, + layers.append(MaxPool(input_shape, [1] + list(strides) + [1], [1] + list(pool_size) + [1], - padding.upper())) + padding)) + input_shape = layers[-1].Y.sizes elif name == 'dropout': layers.append(Dropout(batch_size, reduce( operator.mul, layers[-1].Y.sizes[1:]), alpha=layer[1])) + input_shape = layers[-1].Y.sizes elif name == 'flatten': pass + elif name == 'relu': + layers.append(Relu(layers[-1].Y.sizes)) + elif name == 'batchnorm': + input_shape = layers[-1].Y.sizes + layers.append(BatchNorm(layers[-1].Y.sizes)) else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: diff --git a/Compiler/oram.py b/Compiler/oram.py index 443d826c..543fc4aa 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1493,6 +1493,7 @@ class PackedIndexStructure(object): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) + print_ln('packed ORAM init done') print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple diff --git a/Compiler/program.py b/Compiler/program.py index 5dad8e51..36672330 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -580,10 +580,19 @@ class Program(object): @staticmethod def read_tapes(schedule): + m = re.search(r'([^/]*)\.mpc', schedule) + if m: + schedule = m.group(1) if not os.path.exists(schedule): schedule = 'Programs/Schedules/%s.sch' % schedule - lines = open(schedule).readlines() + try: + lines = open(schedule).readlines() + except FileNotFoundError: + print('%s not found, have you compiled the program?' % schedule, + file=sys.stderr) + sys.exit(1) + for tapename in lines[2].split(' '): yield tapename.strip() diff --git a/Compiler/types.py b/Compiler/types.py index 0063fdc1..1dbe1f90 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1675,6 +1675,13 @@ class localint(Tape._no_truth): __ne__ = lambda self, other: localint(self._v != other) class personal(Tape._no_truth): + """ Value known to one player. Supports operations with public + values and personal values known to the same player. Can be used + with :py:func:`~Compiler.library.print_ln_to`. + + :param player: player (int) + :param value: cleartext value (cint, cfix, cfloat) or array thereof + """ def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -1685,8 +1692,24 @@ class personal(Tape._no_truth): self._v = value def binary_output(self): + """ Write binary output to + ``Player-Data/Binary-Output-P-`` if + supported by underlying type. Player must be known at compile time.""" self._v.binary_output(self.player) + def reveal_to(self, player): + """ Pass personal value to another player. """ + if isinstance(self._v, Array): + source = self._v[:] + else: + source = self._v + source = cint.conv(source) + res = cint(size=source.size) + sendpersonal(source.size, player, res, self.player, source) + if isinstance(self._v, Array): + res = Array.create_from(res) + return personal(player, res) + def bit_decompose(self, length): return [personal(self.player, x) for x in self._v.bit_decompose(length)] @@ -1858,8 +1881,13 @@ class _secret(_register, _secret_structure): @vectorized_classmethod @set_instruction_type def get_random_input_mask_for(cls, player): - res = cls() - inputmask(res, player) + """ Secret random input mask according to security model. + + :return: mask (sint), mask (personal cint) + :param size: vector size (int, default 1) + """ + res = cls(), personal(player, cls.clear_type()) + inputmask(res[0], res[1]._v, player) return res @classmethod @@ -2071,15 +2099,13 @@ class _secret(_register, _secret_structure): @set_instruction_type def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result written to ``Player-Data/Private-Output-P`` :param player: int - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ - masked = self.__class__() - res = personal(player, self.clear_type()) - startprivateoutput(masked, self, player) - stopprivateoutput(res._v, masked.reveal(), player) + mask = self.get_random_input_mask_for(player) + masked = self + mask[0] + res = personal(player, masked.reveal() - mask[1]) return res @@ -2633,21 +2659,20 @@ class sint(_secret, _int): @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result potentially written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. - :param player: public integer (int/regint/cint): - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :param player: public integer (int/regint/cint) + :returns: :py:class:`personal` """ - if not util.is_constant(player) or self.size > 1: + if not util.is_constant(player): secret_mask = sint() player_mask = cint() inputmaskreg(secret_mask, player_mask, regint.conv(player)) return personal(player, (self + secret_mask).reveal() - player_mask) else: - return super(sint, self).reveal_to(player) + res = personal(player, self.clear_type()) + privateoutput(self.size, player, res._v, self) + return res def private_division(self, divisor, active=True, dividend_length=None, divisor_length=None): @@ -4366,12 +4391,9 @@ class sfix(_fix): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Raw representation possibly written to - ``Player-Data/Private-Output-P``, but not if - :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint) - :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` + :returns: :py:class:`personal` """ return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) @@ -5221,6 +5243,9 @@ class Array(_vectorizable): return self.assign(value, addresses) self._store(value, self.get_address(index)) + def to_array(self): + return self + def get_sub(self, start, stop=None): if stop is None: stop = start @@ -5471,6 +5496,10 @@ class Array(_vectorizable): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def randomize(self, *args): + """ Randomize according to data type. """ + self.assign_vector(self.value_type.get_random(*args, size=len(self))) + def reveal(self): """ Reveal the whole array. @@ -5596,6 +5625,9 @@ class SubMultiArray(_vectorizable): def __iter__(self): return (self[i] for i in range(len(self))) + def to_array(self): + return Array(self.total_size(), self.value_type, address=self.address) + def assign_all(self, value): """ Assign the same value to all entries. @@ -5958,6 +5990,7 @@ class SubMultiArray(_vectorizable): """ assert len(self.sizes) == 2 assert len(other.sizes) == 2 + assert other.address != None if indices is None: assert self.sizes[1] == other.sizes[1] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] @@ -6145,6 +6178,16 @@ class SubMultiArray(_vectorizable): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def randomize(self, *args): + """ Randomize according to data type. """ + if self.total_size() < program.options.budget: + self.assign_vector( + self.value_type.get_random(*args, size=self.total_size())) + else: + @library.for_range(self.sizes[0]) + def _(i): + self[i].randomize(*args) + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6251,6 +6294,22 @@ class Matrix(MultiArray): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + def get_column(self, index): + """ Get column as vector. + + :param index: regint/cint/int + """ + assert self.value_type.n_elements() == 1 + addresses = regint.inc(self.sizes[0], self.address + index, + self.sizes[1]) + return self.value_type.load_mem(addresses) + + def get_column_by_row_indices(self, rows, column): + assert self.value_type.n_elements() == 1 + addresses = rows * self.sizes[1] + \ + regint.inc(len(rows), self.address + column, 0) + return self.value_type.load_mem(addresses) + def set_column(self, index, vector): """ Change column. diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 2a4d6b12..20dfb1bb 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G) } void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) +{ + Rq_Element a(*this); + a.randomize(G); + partial_key_gen(sk, a, G, noise_boost); +} + +void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost) { FHE_PK& PK = *this; - // Generate the main public key - PK.a0.randomize(G); + a0 = a; // b0=a0*s+p*e0 Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); @@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) mul(es,es,PK.pr); add(PK.Sw_b,PK.Sw_b,es); - // Lowering level as we only decrypt at level 0 - sk.lower_level(); - // bs=bs-p1*s^2 Rq_Element s2; mul(s2,sk,sk); // Mult at level 0 @@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, template void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) { - check(*params, pk, pr); + check(*params, pk, FieldD.get_prime()); pk.check_noise(*this); if (decrypt(pk.encrypt(Plaintext_(FieldD)), FieldD) != Plaintext_(FieldD)) diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 72a7ddfa..30ecc292 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -150,6 +150,8 @@ class FHE_PK Rq_Element sample_secret_key(PRNG& G); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); + void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, + int noise_boost = 1); void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 8ae6c288..0de8bb1e 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -3,6 +3,11 @@ #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +FHE_Params::FHE_Params(int n_mults) : + FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1) +{ +} + void FHE_Params::set(const Ring& R, const vector& primes) { @@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec) throw runtime_error("distributed decryption bound is zero"); } +void FHE_Params::set_matrix_dim(int matrix_dim) +{ + assert(matrix_dim > 0); + if (FFTData[0].get_prime() != 0) + throw runtime_error("cannot change matrix dimension after parameter generation"); + this->matrix_dim = matrix_dim; +} + bigint FHE_Params::Q() const { bigint res = FFTData[0].get_prime(); @@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const Chi.pack(o); Bval.pack(o); o.store(sec_p); + o.store(matrix_dim); } void FHE_Params::unpack(octetStream& o) @@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o) Chi.unpack(o); Bval.unpack(o); o.get(sec_p); + o.get(matrix_dim); } bool FHE_Params::operator!=(const FHE_Params& other) const diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8ac40083..9407b0ba 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -26,10 +26,11 @@ class FHE_Params // Data for distributed decryption int sec_p; bigint Bval; + int matrix_dim; public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} + FHE_Params(int n_mults = 1); int n_mults() const { return FFTData.size() - 1; } @@ -37,6 +38,9 @@ class FHE_Params void set(const vector& primes); void set_sec(int sec); + void set_matrix_dim(int matrix_dim); + int get_matrix_dim() const { return matrix_dim; } + const vector& FFTD() const { return FFTData; } const bigint& p0() const { return FFTData[0].get_prime(); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index c6e294a6..7c46a74f 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2) template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FFT_Data& FTD, bool round_up) + FHE_Params& params, FFT_Data& FTD, bool round_up, int n) { int m = 1024; int lgp = plaintext_length; @@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec, while (true) { tmp_params = params; - SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params); bigint p1 = 2 * p * m, p0 = p; while (nb.min_p0(params.n_mults() > 0, p1) > p0) @@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec, template <> int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, P2Data& P2D, bool round_up) + FHE_Params& params, P2Data& P2D, bool round_up, int n) { if (params.n_mults() > 0) throw runtime_error("only implemented for 0-level BGV"); gf2n_short::init_field(plaintext_length); int m; char_2_dimension(m, plaintext_length); - SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec, + SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec, numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); int lgp0 = numBits(nb.min_p0(false, 0)); int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up); @@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 16: + m = 13107; + break; default: throw runtime_error("field size not supported"); break; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index c0a2ecfe..acaba70b 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2, // semi-homomorphic, includes slack template int generate_semi_setup(int plaintext_length, int sec, - FHE_Params& params, FD& FieldD, bool round_up); + FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index 7ab8e517..f2e151c4 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); + int matrix_dim = params.get_matrix_dim(); #ifdef NOISY cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; @@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "log(slack): " << slack << endl; cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; + cout << "matrix dimension: " << matrix_dim << endl; #endif - drown = 1 + n * (bigint(1) << sec); + assert(matrix_dim > 0); + drown = 1 + matrix_dim * n * (bigint(1) << sec); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 812560a3..554d4dc1 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -50,6 +50,7 @@ void Ring_Element::prepare_push() void Ring_Element::allocate() { + assert(FFTD); element.resize(FFTD->phi_m()); } diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index af7a664b..531df90f 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b) } } +void Rq_Element::add(octetStream& os) +{ + Rq_Element tmp(*this); + tmp.unpack(os); + *this += tmp; +} + void Rq_Element::randomize(PRNG& G,int l) { set_level(l); @@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p) // Now add delta back onto a0 Rq_Element bb(b0,b1); - add(*this,*this,bb); + ::add(*this,*this,bb); // Now divide by p1 mod p0 modp p1_inv,pp; diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index d5e71841..a58cb7de 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -93,12 +93,14 @@ protected: friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); + void add(octetStream& os); + template Rq_Element& operator+=(const vector& other); - Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; } + Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; } - Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; } + Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; } Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; } template Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; } @@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector& other) { Rq_Element tmp = *this; tmp.from(Iterator(other), lev); - add(*this, *this, tmp); + ::add(*this, *this, tmp); return *this; } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 0f5d1fe8..48a8a6ef 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -203,7 +203,7 @@ template void PartSetup::secure_init(Player& P, MachineBase& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); } template diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 732904b3..92632002 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -130,6 +130,13 @@ void Multiplier::report_size(ReportType type, MemoryUsage& res) res += memory_usage; } +template +const vector& Multiplier::get_multiplicands( + const vector >& others_ct, const FHE_PK&) +{ + return others_ct[P.get_full_player().get_player(-P.get_offset())]; +} + template class Multiplier; template class Multiplier; diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index e2e1ce66..9ab517a6 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -55,6 +55,9 @@ public: size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); size_t report_volatile() { return volatile_capacity; } + + const vector& get_multiplicands( + const vector>& others_ct, const FHE_PK&); }; #endif /* FHEOFFLINE_MULTIPLIER_H_ */ diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index bba83b5f..047c84f2 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -9,6 +9,7 @@ #include "Math/Setup.h" #include "FHEOffline/Proof.h" #include "FHEOffline/PairwiseMachine.h" +#include "FHEOffline/TemiSetup.h" #include "Tools/Commit.h" #include "Tools/Bundle.h" #include "Processor/OnlineOptions.h" @@ -53,7 +54,7 @@ void PairwiseSetup::init(const Player& P, int sec, int plaintext_length, template void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) { - ::secure_init(*this, P, machine, plaintext_length, sec); + ::secure_init(*this, P, machine, plaintext_length, sec, params); alpha = FieldD; machine.sk = FHE_SK(params, FieldD.get_prime()); for (auto& pk : machine.other_pks) @@ -62,13 +63,14 @@ void PairwiseSetup::secure_init(Player& P, PairwiseMachine& machine, int pla template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec) + int plaintext_length, int sec, FHE_Params& params) { machine.sec = sec; sec = max(sec, 40); machine.drown_sec = sec; string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); @@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine, { cout << "Finding parameters for security " << sec << " and field size ~2^" << plaintext_length << endl; - setup.params = setup.params.n_mults(); setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; @@ -208,5 +209,8 @@ void PairwiseSetup::set_alphai(T alphai) template class PairwiseSetup; template class PairwiseSetup; -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); -template void secure_init(PartSetup&, Player&, MachineBase&, int, int); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(PartSetup&, Player&, MachineBase&, int, int, FHE_Params& params); + +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); +template void secure_init(TemiSetup&, Player&, MachineBase&, int, int, FHE_Params& params); diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index 8e16eaf3..f6482ede 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -15,7 +15,7 @@ class MachineBase; template void secure_init(T& setup, Player& P, U& machine, - int plaintext_length, int sec); + int plaintext_length, int sec, FHE_Params& params); template class PairwiseSetup diff --git a/FHEOffline/SimpleDistDecrypt.cpp b/FHEOffline/SimpleDistDecrypt.cpp index 3774cd3c..c8b92312 100644 --- a/FHEOffline/SimpleDistDecrypt.cpp +++ b/FHEOffline/SimpleDistDecrypt.cpp @@ -18,7 +18,12 @@ void SimpleDistDecrypt::reshare(Plaintext& EC) { (void)EC; + m = reshare(cm); +} +template +Plaintext_ SimpleDistDecrypt::reshare(const Ciphertext& cm) +{ PRNG G; G.ReSeed(); this->f.randomize(G, Full); @@ -27,10 +32,13 @@ void SimpleDistDecrypt::reshare(Plaintextrun(cm); // Step 4 + Plaintext_ m(this->f.get_field()); if (this->P.my_num()==0) { sub(m,this->mf,this->f); } else { m=this->f; m.negate(); } + + return m; } diff --git a/FHEOffline/SimpleDistDecrypt.h b/FHEOffline/SimpleDistDecrypt.h index 9589f15a..c929a799 100644 --- a/FHEOffline/SimpleDistDecrypt.h +++ b/FHEOffline/SimpleDistDecrypt.h @@ -20,6 +20,7 @@ public: void reshare(Plaintext& m, const Ciphertext& cm, EncCommitBase& EC); + Plaintext_ reshare(const Ciphertext& cm); }; #endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */ diff --git a/FHEOffline/TemiSetup.cpp b/FHEOffline/TemiSetup.cpp new file mode 100644 index 00000000..fc222ed5 --- /dev/null +++ b/FHEOffline/TemiSetup.cpp @@ -0,0 +1,59 @@ +/* + * TemiSetup.cpp + * + */ + +#include "TemiSetup.h" +#include "PairwiseSetup.h" +#include "FHE/NTL-Subs.h" +#include "Protocols/HemiOptions.h" + +template +TemiSetup::TemiSetup() +{ + this->params = FHE_Params(0); + this->pk = {this->params, 0}; + this->sk = {this->params, 0}; + this->calpha = this->params; + this->params.set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); +} + +template +void TemiSetup::secure_init(Player& P, int plaintext_length) +{ + MachineBase machine; + ::secure_init(*this, P, machine, plaintext_length, 0, this->params); +} + +template +void TemiSetup::generate(Player& P, MachineBase&, + int plaintext_length, int sec) +{ + generate_semi_setup(plaintext_length, sec, this->params, this->FieldD, + false, P.num_players()); + this->sk = {this->params, this->FieldD.get_prime()}; + this->pk = {this->params, this->FieldD.get_prime()}; +} + +template +void TemiSetup::key_and_mac_generation(Player& P, MachineBase&, int, + true_type) +{ + Rq_Element a(this->params); + GlobalPRNG GG(P); + a.randomize(GG); + SeededPRNG G; + auto sk = this->pk.sample_secret_key(G); + this->sk.assign(sk); + this->pk.partial_key_gen(sk, a, G); + TreeSum ts; + vector pks; + pks.push_back(this->pk.b()); + ts.run(pks, P); + this->pk.assign(this->pk.a(), pks[0]); +} + +template class TemiSetup; +template class TemiSetup; diff --git a/FHEOffline/TemiSetup.h b/FHEOffline/TemiSetup.h new file mode 100644 index 00000000..483cb0ee --- /dev/null +++ b/FHEOffline/TemiSetup.h @@ -0,0 +1,34 @@ +/* + * TemiSetup.h + * + */ + +#ifndef FHEOFFLINE_TEMISETUP_H_ +#define FHEOFFLINE_TEMISETUP_H_ + +#include "FHE/FHE_Keys.h" +#include "FHEOffline/SimpleMachine.h" + +template +class TemiSetup : public PartSetup +{ +public: + static string name() + { + return "TemiParams"; + } + + static string protocol_name(int) + { + return "Temi"; + } + + TemiSetup(); + + void secure_init(Player& P, int plaintext_length); + void generate(Player& P, MachineBase&, int plaintext_length, int sec); + + void key_and_mac_generation(Player& P, MachineBase&, int, true_type); +}; + +#endif /* FHEOFFLINE_TEMISETUP_H_ */ diff --git a/GC/Memory.h b/GC/Memory.h index 359677a2..006a91d9 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -47,11 +47,11 @@ inline void Memory::check_index(Integer index) const ss << T::type_string() << " memory overflow: " << i << "/" << vector::size(); throw Processor_Error(ss.str()); } -#endif #ifdef DEBUG_MEMORY cout << typeid(T).name() << " at " << this << " index " << i << ": " << vector::operator[](i) << endl; #endif +#endif } template diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 48f75b8f..6d9f2652 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -122,6 +122,7 @@ public: static const bool dishonest_majority = false; static const bool variable_players = false; static const bool needs_ot = false; + static const bool has_mac = false; static string type_string() { return "replicated secret"; } static string phase_name() { return "Replicated computation"; } diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9b6c8478..9cdde3dc 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -49,6 +49,7 @@ public: static const bool dishonest_majority = T::dishonest_majority; static const bool variable_players = T::variable_players; static const bool needs_ot = T::needs_ot; + static const bool has_mac = T::has_mac; static const bool expensive_triples = false; static const int default_length = 64; diff --git a/GC/instructions.h b/GC/instructions.h index 66ae46d2..49443cc2 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -55,7 +55,7 @@ X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ X(SHRCBI, C0 = PC1 >> IMM) \ X(SHLCBI, C0 = PC1 << IMM) \ - X(LDBITS, S0.load_clear(REG1, IMM)) \ + X(LDBITS, S0.load_clear(REG1, int(IMM))) \ X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \ X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \ X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 7697c512..9f18d3a6 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -23,6 +23,7 @@ #include "Protocols/Shamir.hpp" #include "Protocols/ShamirMC.hpp" #include "Protocols/MaliciousShamirMC.hpp" +#include "Protocols/MaliciousShamirPO.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" diff --git a/Machines/temi-party.cpp b/Machines/temi-party.cpp new file mode 100644 index 00000000..12e99dc2 --- /dev/null +++ b/Machines/temi-party.cpp @@ -0,0 +1,37 @@ +/* + * temi-party.cpp + * + */ + +#include "Protocols/TemiShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "FHE/P2Data.h" +#include "Tools/ezOptionParser.h" +#include "GC/SemiSecret.h" +#include "GC/SemiPrep.h" + +#include "Processor/FieldMachine.hpp" +#include "Protocols/TemiPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/SemiPrep.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/Hemi.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/SemiHonestRepPrep.h" +#include "Math/gfp.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + HemiOptions::singleton = {opt, argc, argv}; + DishonestMajorityFieldMachine(argc, argv, + opt); +} diff --git a/Makefile b/Makefile index e40528b8..4f558e1d 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr all: overdrive she-offline -arithmetic: hemi-party.x soho-party.x gear +arithmetic: semi-he gear -include $(DEPS) include $(wildcard *.d static/*.d) @@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x +semi-he: hemi-party.x soho-party.x temi-party.x rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x @@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) +temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) @@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) +static/temi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 55983e0b..c0b2373e 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -14,11 +14,6 @@ using namespace std; #include "Tools/random.h" #include "field_types.h" -template class ReplicatedMC; -template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Replicated; - template class FixedVec { diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index f30e7103..13d700fc 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* if (mpn_cmp(ans+T,prA,T+1)>=0) { mpn_sub_fixed_n(z,ans+T,prA); } else - { inline_mpn_copyi(z,ans+T,T); } + { inline_mpn_copyi(z,ans+T); } #else Mont_Mult(z, x, y, t); #endif diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 1a6fe41d..f9491fb7 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -18,15 +18,21 @@ bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 6 +#define num_2_fields 7 /* Require * 2*(n-1)-64+t1<64 */ -int fields_2[num_2_fields][4] = { - {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1}, - }; - +int fields_2[num_2_fields][4] = +{ + { 4, 1, 0, 0 }, + { 8, 4, 3, 1 }, + { 16, 5, 3, 1 }, + { 28, 1, 0, 0 }, + { 40, 20, 15, 10 }, + { 63, 1, 0, 0 }, + { 128, 7, 2, 1 }, +}; template void gf2n_::init_tables() diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index b55a6b7e..b1c5642b 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si avx_memcpy(dest, src, size * sizeof(mp_limb_t)); } +template +inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src) +{ + avx_memcpy(dest, src); +} + inline void debug_print(const char* name, const mp_limb_t* x, int n) { (void)name, (void)x, (void)n; diff --git a/Networking/Player.h b/Networking/Player.h index 9c90dbd1..ff4bdcd1 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -542,6 +542,7 @@ public: int other_player_num() const { return P.get_player(offset); } int num_players() const { return 2; } int get_offset() const { return offset; } + Player& get_full_player() const { return P; } void send(octetStream& o) const { P.send_to(P.get_player(offset), o); } void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 98856585..730ffa6f 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs) receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); } } + +#ifdef BASE_OT_DEBUG + for (j = 0; j < 4; j++) + for (k = 0; k < AES_BLK_SIZE; k++) + { + printf("%4d-th receiver key:", i+j); + for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); + printf("\n"); + } + + printf("\n"); +#endif } } @@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs) for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]); printf("\n"); } - if (ot_role & RECEIVER) - { - printf("%4d-th receiver key:", i+j); - for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]); - printf("\n"); - } } printf("\n"); diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index ef735279..ea8239a5 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename, if (start_pos != -1) { - long write_pos = start_pos * T::size(); + long write_pos = file_signature().get_total_length() + start_pos * T::size(); // fill with zeros if needed for (long i = outf.tellp(); i < write_pos; i++) outf.put(0); @@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, inf.open(filename, ios::in | ios::binary); if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } + check_file_signature(inf, filename).get_length(); + auto data_start = inf.tellg(); + int size_in_bytes = T::size() * buffer.size(); int n_read = 0; char read_buffer[size_in_bytes]; - inf.seekg(start_posn * T::size()); + inf.seekg(start_posn * T::size(), iostream::cur); do { inf.read(read_buffer + n_read, size_in_bytes - n_read); @@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, if (inf.eof()) { stringstream ss; - ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; + ss << "Got to EOF when reading from disk (expecting " << size_in_bytes + << " bytes from " << (long(data_start) + start_posn * T::size()) + << ")."; throw file_error(ss.str()); } if (inf.fail()) @@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, } while (n_read < size_in_bytes); - end_posn = inf.tellg() / T::size(); + end_posn = (inf.tellg() - data_start) / T::size(); assert (end_posn == start_posn + int(buffer.size())); //Check if at end of file by getting 1 more char. diff --git a/Processor/Input.h b/Processor/Input.h index 98c6c83b..728c81f6 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -32,6 +32,15 @@ protected: Buffer buffer; Timer timer; + // Send my inputs (not generally available) + virtual void send_mine() { throw not_implemented(); } + // Get share for next input of mine (not generally available) + virtual T finalize_mine() { throw not_implemented(); } + // Store share for next input from ``player`` from buffer ``o`` + // in ``target`` (not generally available) + virtual void finalize_other(int, T&, octetStream&, int = -1) + { throw not_implemented(); } + public: vector os; int values_input; @@ -61,18 +70,12 @@ public: /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input, int n_bits = -1); + void add_from_all(const typename T::open_type& input, int n_bits = -1); - /// Send my inputs - virtual void send_mine() = 0; /// Run input protocol for all players virtual void exchange(); - /// Get share for next input of mine - virtual T finalize_mine() = 0; - /// Store share for next input from ``player`` from buffer ``o`` in ``target`` - virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; - /// Get share for next input from ``player` + /// Get share for next input from ``player`` virtual T finalize(int player, int n_bits = -1); void raw_input(SubProcessor& proc, const vector& args, int size); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index b9f7a77a..246c9eb1 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -113,7 +113,7 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input, int n_bits) +void InputBase::add_from_all(const typename T::open_type& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ca062cbc..a7e1e318 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -106,6 +106,7 @@ enum MATMULSM = 0xAB, CONV2DS = 0xAC, CHECK = 0xAF, + PRIVATEOUTPUT = 0xAD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -127,6 +128,7 @@ enum INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, INPUTPERSONAL = 0xF5, + SENDPERSONAL = 0xF6, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 25fa666f..1bc46f94 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case USE: case USE_INP: case USE_EDABIT: + case DIGESTC: + case INPUTMASK: + case GINPUTMASK: + get_ints(r, s, 2); + n = get_int(s); + break; case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: case STOPPRIVATEOUTPUT: case GSTOPPRIVATEOUTPUT: - case DIGESTC: - get_ints(r, s, 2); - n = get_int(s); - break; + throw runtime_error("two-stage private output not supported any more"); case USE_MATMUL: get_ints(r, s, 3); n = get_int(s); @@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRINTREGB: case GPRINTREG: case LDINT: - case INPUTMASK: - case GINPUTMASK: case INV2M: case CONDPRINTSTR: case CONDPRINTSTRB: @@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RAWINPUT: case GRAWINPUT: case INPUTPERSONAL: + case SENDPERSONAL: + case PRIVATEOUTPUT: case TRUNC_PR: case RUN_TAPE: num_var_args = get_int(s); @@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const case PUBINPUT: case FLOATOUTPUT: case READSOCKETC: + case PRIVATEOUTPUT: return CINT; default: if (is_gf2n_instruction()) @@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const skip = 1; break; case INPUTPERSONAL: + case PRIVATEOUTPUT: size_offset = -2; offset = 2; skip = 4; break; + case SENDPERSONAL: + size_offset = -2; + offset = 2; + skip = 5; + break; case READSOCKETS: case READSOCKETC: case READSOCKETINT: @@ -939,13 +949,11 @@ inline void Instruction::execute(Processor& Proc) const break; case INPUTMASK: Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); - if (n == Proc.P.my_num()) - Proc.temp.rrp.output(Proc.private_output, false); + Proc.write_Cp(r[1], Proc.temp.rrp); break; case GINPUTMASK: Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); - if (n == Proc.P.my_num()) - Proc.temp.ans2.output(Proc.private_output, false); + Proc.write_C2(r[1], Proc.temp.ans2); break; case INPUT: sint::Input::template input>(Proc.Procp, start, size); @@ -974,6 +982,12 @@ inline void Instruction::execute(Processor& Proc) const case INPUTPERSONAL: Proc.Procp.input_personal(start); return; + case SENDPERSONAL: + Proc.Procp.send_personal(start); + return; + case PRIVATEOUTPUT: + Proc.Procp.private_output(start); + return; // Note: Fp version has different semantics for NOTC than GNOTC case NOTC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); @@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor& Proc) const Proc.binary_output.write((char*) &tmp, sizeof(double)); } break; - case STARTPRIVATEOUTPUT: - Proc.privateOutputp.start(n,r[0],r[1]); - break; - case GSTARTPRIVATEOUTPUT: - Proc.privateOutput2.start(n,r[0],r[1]); - break; - case STOPPRIVATEOUTPUT: - Proc.privateOutputp.stop(n,r[0],r[1]); - break; - case GSTOPPRIVATEOUTPUT: - Proc.privateOutput2.stop(n,r[0],r[1]); - break; case PREP: Procp.DataF.get(Proc.Procp.get_S(), r, start, size); return; diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index a43c9d47..cd318f1a 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -97,12 +97,19 @@ Machine::Machine(int my_number, Names& playerNames, // initialize persistence if necessary for (auto& prog : progs) { - if (prog.writes_persistance) + if (prog.writes_persistence) { string filename = Binary_File_IO::filename(my_number); ifstream pers(filename); - if (pers.fail()) - ofstream pers(filename, ios::binary); + try + { + check_file_signature(pers, filename); + } + catch (signature_mismatch&) + { + ofstream pers(filename, ios::binary); + file_signature().output(pers); + } break; } } @@ -418,12 +425,14 @@ void Machine::run() cerr << "Full broadcast" << endl; #endif +#ifdef CHOP_MEMORY // Reduce memory size to speed up unsigned max_size = 1 << 20; if (M2.size_s() > max_size) M2.resize_s(max_size); if (Mp.size_s() > max_size) Mp.resize_s(max_size); +#endif // Write out the memory to use next time ofstream outf(memory_filename(), ios::out | ios::binary); diff --git a/Processor/Memory.h b/Processor/Memory.h index 9ec02d2b..1fbeda7e 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -44,9 +44,9 @@ class Memory static void check_index(const vector& M, size_t i) { (void) M, (void) i; -#ifdef NO_CHECK_INDEX +#ifndef NO_CHECK_INDEX if (i >= M.size()) - throw overflow("memory", i, M.size()); + throw overflow(U::type_string() + " memory", i, M.size()); #endif } diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index c3c3e01b..ef767441 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -19,6 +19,9 @@ void MemoryPart::minimum_size(size_t size) { if (size > this->size()) this->resize(size); +#ifdef DEBUG_MEMORY_SIZE + cerr << T::type_string() << " memory has now size " << this->size() << endl; +#endif } catch (bad_alloc&) { @@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory& M) int len; s >> len; - M.resize_s(len); + M.MS.minimum_size(len); s >> len; - M.resize_c(len); + M.MC.minimum_size(len); s.seekg(1, istream::cur); for (unsigned int i=0; i& proc; + typename T::MAC_Check MC; deque masks; public: - PrivateOutput(SubProcessor& proc) : proc(proc) { }; + PrivateOutput(SubProcessor& proc); + ~PrivateOutput(); - void start(int player, int target, int source); - void stop(int player, int dest, int source); - - T start(int player, const T& source); - typename T::clear stop(int player, const typename T::clear& masked); + void prepare_sending(const T& source, int player); + void exchange(); + typename T::clear finalize(int player); }; #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index 977e7e15..d2cee8a1 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -7,13 +7,21 @@ #include "Processor.h" template -void PrivateOutput::start(int player, int target, int source) +PrivateOutput::PrivateOutput(SubProcessor& proc) : + proc(proc), MC(proc.MC.get_alphai()) { - proc.get_S_ref(target) = start(player, proc.get_S_ref(source)); + MC.init_open(proc.P); + MC.set_prep(proc.DataF); } template -T PrivateOutput::start(int player, const T& source) +PrivateOutput::~PrivateOutput() +{ + MC.Check(proc.P); +} + +template +void PrivateOutput::prepare_sending(const T& source, int player) { assert (player < proc.P.num_players()); open_type mask; @@ -24,26 +32,25 @@ T PrivateOutput::start(int player, const T& source) if (player == proc.P.my_num()) masks.push_back(mask); - return res; + MC.prepare_open(res); } template -void PrivateOutput::stop(int player, int dest, int source) +void PrivateOutput::exchange() { - auto& value = proc.get_C_ref(dest); - value = stop(player, proc.get_C_ref(source)); - if (proc.Proc) - value.output(proc.Proc->private_output, false); + MC.exchange(proc.P); } template -typename T::clear PrivateOutput::stop(int player, const typename T::clear& source) +typename T::clear PrivateOutput::finalize(int player) { - typename T::clear value; + auto res = MC.finalize_open(); + if (player == proc.P.my_num()) { - value = source - masks.front(); + res -= masks.front(); masks.pop_front(); } - return value; + + return res; } diff --git a/Processor/Processor.h b/Processor/Processor.h index c91b677b..38ea7f25 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -71,6 +71,8 @@ public: void conv2ds(const Instruction& instruction); void input_personal(const vector& args); + void send_personal(const vector& args); + void private_output(const vector& args); CheckVector& get_S() { @@ -110,7 +112,6 @@ public: ifstream private_input; ifstream public_input; ofstream public_output; - ofstream private_output; ofstream binary_output; int sent, rounds; @@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor SubProcessor Proc2; SubProcessor Procp; - typename sgf2n::PrivateOutput privateOutput2; - typename sint::PrivateOutput privateOutputp; - unsigned int PC; TempVars temp; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index c55a6dfc..d74594b3 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -4,9 +4,8 @@ #include "Processor/Processor.h" #include "Processor/Program.h" #include "GC/square64.h" +#include "SpecificPrivateOutput.h" -#include "Protocols/ReplicatedInput.hpp" -#include "Protocols/ReplicatedPrivateOutput.hpp" #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" @@ -63,7 +62,6 @@ Processor::Processor(int thread_num,Player& P, share_thread(DataF.DataFb, P, machine.get_bit_mac_key()), Procb(machine.bit_memories), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), - privateOutput2(Proc2),privateOutputp(Procp), external_clients(P.my_num()), binary_file_io(Binary_File_IO()) { @@ -74,7 +72,6 @@ Processor::Processor(int thread_num,Player& P, private_input_filename = (get_filename(PREP_DIR "Private-Input-",true)); private_input.open(private_input_filename.c_str()); public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); - private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out); binary_output.open( get_parameterized_filename(P.my_num(), thread_num, PREP_DIR "Binary-Output"), ios_base::out); @@ -654,6 +651,37 @@ void SubProcessor::input_personal(const vector& args) S[args[i + 2] + j] = input.finalize(args[i + 1]); } +template +void SubProcessor::private_output(const vector& args) +{ + typename T::PrivateOutput output(*this); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + { + int player = args[i + 1]; + output.prepare_sending(S.at(args[i + 3] + j), player); + } + output.exchange(); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + C.at(args[i + 2] + j) = output.finalize(args[i + 1]); +} + +template +void SubProcessor::send_personal(const vector& args) +{ + octetStreams to_send(P), to_receive(P); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 3] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 4] + j].pack(to_send[args[i + 1]]); + P.send_receive_all(to_send, to_receive); + for (size_t i = 0; i < args.size(); i += 5) + if (args[i + 1] == P.my_num()) + for (int j = 0; j < args[i]; j++) + C[args[i + 2] + j].unpack(to_receive[args[i + 3]]); +} + template typename sint::clear Processor::get_inverse2(unsigned m) { diff --git a/Processor/Program.cpp b/Processor/Program.cpp index c3303942..dac73400 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -23,7 +23,7 @@ void Program::compute_constants() max_mem[reg_type] = max(max_mem[reg_type], p[i].get_mem(RegType(reg_type))); } - writes_persistance |= p[i].opcode == WRITEFILESHARE; + writes_persistence |= p[i].opcode == WRITEFILESHARE; } } diff --git a/Processor/Program.h b/Processor/Program.h index a41c9e2a..87a263f0 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -30,10 +30,10 @@ class Program public: - bool writes_persistance; + bool writes_persistence; Program(int nplayers) : offline_data_used(nplayers), - unknown_usage(false), writes_persistance(false) + unknown_usage(false), writes_persistence(false) { compute_constants(); } // Read in a program diff --git a/Processor/SpecificPrivateOutput.h b/Processor/SpecificPrivateOutput.h new file mode 100644 index 00000000..7878db1c --- /dev/null +++ b/Processor/SpecificPrivateOutput.h @@ -0,0 +1,65 @@ +/* + * SpecificPrivateOutput.h + * + */ + +#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ +#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ + +template +class SpecificPrivateOutput +{ + deque secrets; + vector pos; + Player& P; + vector active; + +public: + SpecificPrivateOutput(SubProcessor& proc) : + P(proc.P) + { + for (int i = 0; i < P.num_players(); i++) + pos.push_back(new typename T::PO(proc.P)); + active.resize(P.num_players()); + } + + ~SpecificPrivateOutput() + { + for (auto& x : pos) + delete x; + } + + void prepare_sending(const T& secret, int player) + { + pos[player]->prepare_sending(secret, player); + if (P.my_num() == player) + secrets.push_back(secret); + active[player] = true; + } + + void exchange() + { + for (int i = 0; i < this->P.num_players(); i++) + if (active[i]) + { + if (i == this->P.my_num()) + pos[i]->receive(); + else + pos[i]->send(i); + } + } + + typename T::clear finalize(int player) + { + if (player == this->P.my_num()) + { + T secret = secrets.front(); + secrets.pop_front(); + return pos[player]->finalize(secret); + } + else + return {}; + } +}; + +#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */ diff --git a/Programs/Source/falcon_alex.mpc b/Programs/Source/falcon_alex.mpc new file mode 100644 index 00000000..3c535248 --- /dev/null +++ b/Programs/Source/falcon_alex.mpc @@ -0,0 +1,100 @@ +from Compiler.ml import keras +import Compiler.ml as tf + +try: + n_epochs = int(program.args[1]) +except (ValueError, IndexError): + n_epochs = 10 + +try: + batch_size = int(program.args[2]) +except (ValueError, IndexError): + batch_size = 128 + +try: + n_threads = int(program.args[3]) +except (ValueError, IndexError): + n_threads = 36 + +#Instantiation +AlexNet = [] + +padding = 'same' +batchnorm = 'batchnorm' in program.args + +#1st Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9)) +AlexNet.append(keras.layers.Activation('relu')) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2))) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) + +#2nd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1)) + +#3rd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#4th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#5th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1)) +AlexNet.append(keras.layers.Activation('relu')) + +#Passing it to a Fully Connected layer +# 1st Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#2nd Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#Output Layer +AlexNet.append(keras.layers.Dense(10)) + + +tf.set_n_threads(n_threads) +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +if 'no_acc' not in program.args: + training_labels.input_from(0) + training_samples.input_from(0) + + test_labels.input_from(0) + test_samples.input_from(0) + +model = tf.keras.models.Sequential(AlexNet) + +model.compile_by_args(program) + +model.build(training_samples.sizes) +model.summary() + +opt = model.fit( + training_samples, + training_labels, + epochs=n_epochs, + batch_size=batch_size, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_cifar_lenet.mpc b/Programs/Source/keras_cifar_lenet.mpc new file mode 100644 index 00000000..882d2e18 --- /dev/null +++ b/Programs/Source/keras_cifar_lenet.mpc @@ -0,0 +1,45 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) +test_samples.input_from(0) + +from Compiler import ml +tf = ml +ml.set_n_threads(36) + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.Adam(amsgrad=True) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=10, + batch_size=128, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/keras_mnist_dense.mpc b/Programs/Source/keras_mnist_dense.mpc index a525c065..76b1e23f 100644 --- a/Programs/Source/keras_mnist_dense.mpc +++ b/Programs/Source/keras_mnist_dense.mpc @@ -21,7 +21,8 @@ tf = ml layers = [ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128), + tf.keras.layers.Activation('relu'), tf.keras.layers.Dense(10, activation='softmax') ] diff --git a/Programs/Source/keras_mnist_lenet.mpc b/Programs/Source/keras_mnist_lenet.mpc index 9fdac27f..674cf403 100644 --- a/Programs/Source/keras_mnist_lenet.mpc +++ b/Programs/Source/keras_mnist_lenet.mpc @@ -20,8 +20,21 @@ tf = ml layers = [ tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), +] + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), +] + + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ tf.keras.layers.MaxPooling2D(2), tf.keras.layers.Flatten(), tf.keras.layers.Dropout(0.5), diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 9dc8a685..37cd73d2 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -21,6 +21,8 @@ elif 'debug' in program.args: n_test = 100 elif 'debug5000' in program.args: N = n_test = 5000 +elif 'mini' in program.args: + N = n_test = 10 else: N = 60000 n_test = 10000 @@ -39,6 +41,7 @@ except: batch_size = N N = min(N, 10000) +batch_size = min(batch_size, N) ml.Layer.back_batch_size = batch_size try: @@ -71,6 +74,9 @@ else: ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, 10, debug=debug_ml)] +if 'batchnorm' in program.args: + layers.insert(1, ml.BatchNorm([N, n_inner])) + if 'dropout' in program.args: for i in range(len(layers) - 1, 0, -1): layers.insert(i, ml.Dropout(N, n_inner)) diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 6ea76b26..04ca11ad 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -53,7 +53,7 @@ except: ml.Layer.back_batch_size = batch_size layers = [ - ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'), + ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'), ml.MaxPool([N, 24, 24, 20]), ml.Relu([N, 12, 12, 20]), ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'), @@ -66,6 +66,12 @@ layers = [ layers += [ml.MultiOutput.from_args(program, n_examples, 10)] +if 'batchnorm' in program.args: + for arg in program.args: + assert not arg.startswith('dropout') + layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args)) + layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args)) + if 'dropout' in program.args or 'dropout2' in program.args: layers.insert(8, ml.Dropout(N, 500)) elif 'dropout.25' in program.args: diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index c3a919b3..9c6f0b9c 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -85,6 +85,12 @@ void Atlas::exchange() resharing.add_mine(e); } + for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++) + { + int j = (base_king + i) % P.num_players(); + resharing.add_sender(j); + } + resharing.exchange(); } diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index e67b28a9..1eebd3b7 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -27,7 +27,7 @@ HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, new HemiMatrixPrep(dims[0], dims[1], dims[2], - dynamic_cast&>(processor.DataF))}); + dynamic_cast(processor.DataF))}); return *matrix_preps.at(dims); } diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index e48d9257..ea5a7211 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -18,17 +18,18 @@ template class HemiMatrixPrep : public BufferPrep> { typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; int n_rows, n_inner, n_cols; bool swapped; DataPositions* usage; - HemiPrep* prep; + LivePrep* prep; HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: - HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep& prep) : + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) : super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 82b28431..f4221299 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -87,11 +87,10 @@ void HemiMatrixPrep::buffer_triples() assert(prep); auto& multipliers = prep->get_multipliers(); - assert(prep->pairwise_machine); - auto& FTD = prep->pairwise_machine->setup_p.FieldD; - auto& pk = prep->pairwise_machine->pk; + auto& FTD = prep->get_FTD(); + auto& pk = prep->get_pk(); int n_matrices = FTD.num_slots() / n_rows; -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, n_inner, n_cols); fflush(stderr); @@ -103,20 +102,23 @@ void HemiMatrixPrep::buffer_triples() AddableVector> C(n_matrices); MatrixRandMultJob job(C, A, B); - if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + if (T::local_mul) { - auto& queues = BaseMachine::s().queues; - int start = queues.distribute(job, n_matrices); - job.begin = start; - job.end = n_matrices; - matrix_rand_mult(job); - queues.wrap_up(job); - } - else - { - job.begin = 0; - job.end = n_matrices; - matrix_rand_mult(job); + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + { + auto& queues = BaseMachine::s().queues; + int start = queues.distribute(job, n_matrices); + job.begin = start; + job.end = n_matrices; + matrix_rand_mult(job); + queues.wrap_up(job); + } + else + { + job.begin = 0; + job.end = n_matrices; + matrix_rand_mult(job); + } } #ifdef VERBOSE_HE @@ -130,26 +132,35 @@ void HemiMatrixPrep::buffer_triples() assert(prep->proc); auto& P = prep->proc->P; - Bundle bundle(P); - bundle.mine.store(diag.ciphertexts); - P.unchecked_broadcast(bundle); vector> others_ct; - for (auto& os : bundle) + + if (T::local_mul or OnlineOptions::singleton.direct) { - others_ct.push_back({}); - os.get(others_ct.back(), Ciphertext(pk)); + Bundle bundle(P); + bundle.mine.store(diag.ciphertexts); + P.unchecked_broadcast(bundle); + for (auto& os : bundle) + { + others_ct.push_back({}); + os.get(others_ct.back(), Ciphertext(pk)); + } + } + else + { + others_ct.push_back(diag.ciphertexts); + TreeSum().run(others_ct[0], P); } for (int j = 0; j < n_cols; j++) for (auto m : multipliers) { -#ifdef VERBOSE +#ifdef VERBOSE_HE fprintf(stderr, "column %d with party offset %d at %f\n", j, m->get_offset(), timer.elapsed()); fflush(stderr); #endif Ciphertext C(pk); - auto& multiplicands = others_ct[P.get_player(-m->get_offset())]; + auto& multiplicands = m->get_multiplicands(others_ct, pk); if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) { auto& queues = BaseMachine::s().queues; @@ -160,7 +171,7 @@ void HemiMatrixPrep::buffer_triples() CipherPlainMultJob job(products, multiplicands, multiplicands2, true); int start = queues.distribute(job, n_inner); #ifdef VERBOSE_HE - fprintf(stderr, "from %d in central thread\n", start); + fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed()); fflush(stderr); #endif for (int i = start; i < n_inner; i++) @@ -185,7 +196,10 @@ void HemiMatrixPrep::buffer_triples() m->add(products[j], C, BOTH, n_inner); } - C += diag.dediag(products, n_matrices); + if (T::local_mul) + C += diag.dediag(products, n_matrices); + else + C = diag.dediag(products, n_matrices); for (int i = 0; i < n_matrices; i++) if (swapped) diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index c43b43e9..b2b510aa 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -34,6 +34,9 @@ public: static void basic_setup(Player& P); static void teardown(); + static const FHE_PK& get_pk(); + static const FD& get_FTD(); + HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index 6cdd7547..c456424e 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -34,6 +34,20 @@ void HemiPrep::basic_setup(Player& P) T::clear::template init(); } +template +const FHE_PK& HemiPrep::get_pk() +{ + assert(pairwise_machine); + return pairwise_machine->pk; +} + +template +const typename T::clear::FD& HemiPrep::get_FTD() +{ + assert(pairwise_machine); + return pairwise_machine->setup().FieldD; +} + template HemiPrep::~HemiPrep() diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index d299fb18..4a85cbe3 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -27,6 +27,7 @@ public: typedef HemiPrep LivePrep; static const bool needs_ot = false; + static const bool local_mul = true; static true_type triple_matmul; HemiShare() diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 9ff92fb0..be0fac61 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -140,12 +140,12 @@ void KeyGenProtocol::output_to(int player, vector& opened, vector& shares) { PrivateOutput po(*proc); - vector masked; for (auto& share : shares) - masked.push_back(po.start(player, share)); - MC->POpen(opened, masked, P); + po.prepare_sending(share, player); + po.exchange(); + opened.resize(shares.size()); for (auto& x : opened) - x = po.stop(player, x); + x = po.finalize(player); } template diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 571f391e..2250417d 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -52,6 +52,7 @@ public: virtual ~TreeSum(); void run(vector& values, const Player& P); + T run(const T& value, const Player& P); octetStream& get_buffer() { return os; } @@ -210,6 +211,14 @@ void TreeSum::run(vector& values, const Player& P) finish(values, P); } +template +T TreeSum::run(const T& value, const Player& P) +{ + vector values = {value}; + run(values, P); + return values[0]; +} + template size_t TreeSum::report_size(ReportType type) { @@ -244,14 +253,6 @@ void add_openings(vector& values, const Player& P, int sum_players, int last_ MC.player_timers[sender].start(); P.wait_receive(sender, oss[j]); MC.player_timers[sender].stop(); - if ((unsigned)oss[j].get_length() < values.size() * T::size()) - { - stringstream ss; - ss << "Not enough information received, expected " - << values.size() * T::size() << " bytes, got " - << oss[j].get_length(); - throw Processor_Error(ss.str()); - } MC.timers[SUM].start(); for (unsigned int i=0; i::Check(const Player& P) auto& vals = this->vals; auto& macs = this->macs; auto& popen_cnt = this->popen_cnt; + assert(int(macs.size()) <= popen_cnt); if (popen_cnt < 10) { diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index c7d477ad..5a60281c 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -12,6 +12,8 @@ using namespace std; #include "Networking/Player.h" #include "Tools/PointerVector.h" +template class Preprocessing; + /** * Abstract base class for opening protocols */ @@ -61,6 +63,8 @@ public: virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); virtual const Player& get_check_player(const Player& P) const { return P; } + + virtual void set_prep(Preprocessing&) {} }; #endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */ diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index 63bfe63a..ff33a6ee 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share> { typedef SignedZ2 T; typedef MaliciousRep3Share super; + typedef MalRepRingShare This; public: const static int BIT_LENGTH = K; @@ -26,7 +27,8 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index f98e9797..e6f3a8a6 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -13,6 +13,7 @@ template class Beaver; template class MaliciousRepPrepWithBits; template class MaliciousRepPO; template class MaliciousRepPrep; +template class SpecificPrivateOutput; namespace GC { @@ -30,8 +31,8 @@ public: typedef HashMaliciousRepMC> MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; - typedef ::PrivateOutput> PrivateOutput; typedef MaliciousRepPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef Rep3Share Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MaliciousShamirPO.h b/Protocols/MaliciousShamirPO.h index 65003d10..5bffe4f8 100644 --- a/Protocols/MaliciousShamirPO.h +++ b/Protocols/MaliciousShamirPO.h @@ -9,13 +9,14 @@ template class MaliciousShamirPO { +protected: Player& P; octetStream to_send; vector to_receive; vector shares; - MaliciousShamirMC MC; + typename T::Direct_MC MC; public: MaliciousShamirPO(Player& P); diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 47592981..fee8e829 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -13,6 +13,7 @@ template class MaliciousRepPrepWithBits; template class MaliciousRepPrep; template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -23,14 +24,15 @@ template class MaliciousShamirShare : public ShamirShare { typedef ShamirShare super; + typedef MaliciousShamirShare This; public: typedef Beaver> Protocol; typedef MaliciousShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ShamirShare Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index fa3bc9f0..c90a5e27 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -76,12 +76,6 @@ public: return string(1, T::type_char()); } - static void read_or_generate_mac_key(string, Player&, mac_key_type& key) - { - SeededPRNG G; - key.randomize(G); - } - MamaShare() { } diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index a7fed8af..06196762 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -15,6 +15,7 @@ template class PostSacriRepFieldShare : public MaliciousRep3Share { typedef MaliciousRep3Share super; + typedef PostSacriRepFieldShare This; public: typedef typename super::clear clear; @@ -23,7 +24,8 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MaliciousRepPrepWithBits LivePrep; PostSacriRepFieldShare() diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index d4f2ab0f..7cbd483c 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -17,6 +17,7 @@ template class PostSacriRepRingShare : public Rep3Share2 { typedef Rep3Share2 super; + typedef PostSacriRepRingShare This; public: static const int BIT_LENGTH = K; @@ -33,7 +34,8 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h index e6a8eb52..09be88cb 100644 --- a/Protocols/ProtocolSet.h +++ b/Protocols/ProtocolSet.h @@ -42,8 +42,13 @@ public: { } - ~ProtocolSet() + /** + * Run all protocol checks + */ + void check() { + protocol.check(); + output.Check(processor.P); } }; @@ -73,6 +78,15 @@ public: *thread.protocol), input(output, prep, P) { } + + /** + * Run all protocol checks + */ + void check() + { + protocol.check(); + output.Check(protocol.P); + } }; /** @@ -102,6 +116,15 @@ public: arithmetic.protocol), input(arithmetic.input) { } + + /** + * Run all protocol checks + */ + void check() + { + arithmetic.check(); + binary.check(); + } }; #endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index e85065ac..44853b79 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -15,7 +15,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; -template class PrivateOutput; +template class ReplicatedPO; +template class SpecificPrivateOutput; template class RepShare : public FixedVec, public ShareInterface @@ -99,6 +100,7 @@ template class Rep3Share : public RepShare { typedef RepShare super; + typedef Rep3Share This; public: typedef T clear; @@ -107,7 +109,8 @@ public: typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 23f28cf9..e52d160b 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -24,7 +24,8 @@ public: typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef ReplicatedPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep2k LivePrep; typedef Rep3Share2 Honest; typedef SignedZ2 clear; diff --git a/Protocols/Rep4Input.h b/Protocols/Rep4Input.h index f1bc29af..04acd004 100644 --- a/Protocols/Rep4Input.h +++ b/Protocols/Rep4Input.h @@ -31,7 +31,6 @@ public: void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize_mine(); diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 5600b45c..48844396 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -64,12 +64,6 @@ void Rep4Input::add_other(int player, int) results[player].push_back(res); } -template -void Rep4Input::send_mine() -{ - throw not_implemented(); -} - template void Rep4Input::exchange() { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 67527a20..2357d0f5 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -19,10 +19,6 @@ using namespace std; template class SubProcessor; template class ReplicatedMC; template class ReplicatedInput; -template class ReplicatedPrivateOutput; -template class Share; -template class Rep3Share; -template class MAC_Check_Base; template class Preprocessing; class Instruction; @@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase void trunc_pr(const vector& regs, int size, U& proc, false_type); public: - typedef ReplicatedMC MAC_Check; - typedef ReplicatedInput Input; - static const bool uses_triples = false; Replicated(Player& P); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 374ed89b..1a8a66b9 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -10,6 +10,7 @@ #include "Processor/Processor.h" #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" @@ -162,14 +163,13 @@ void Replicated::prepare_mul(const T& x, } template -inline void Replicated::prepare_reshare(const typename T::clear& share, +void Replicated::prepare_reshare(const typename T::clear& share, int n) { - auto add_share = share; typename T::value_type tmp[2]; for (int i = 0; i < 2; i++) tmp[i].randomize(shared_prngs[i], n); - add_share += tmp[0] - tmp[1]; + auto add_share = share + tmp[0] - tmp[1]; add_share.pack(os[0], n); add_shares.push_back(add_share); } diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 916ee6b8..b12f7f91 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,16 +56,24 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif + auto field_type = T::clear::field_type(); + auto& my_usage = this->usage.files.at(field_type); + this->print_left("triples", triples.size() * T::default_length, type_string, this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) * T::default_length); + size_t used_bits = my_usage.at(DATA_BIT); + if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + // add dabits with computation modulo power of two but without MAC + used_bits += my_usage.at(DATA_DABIT); + this->print_left("bits", bits.size(), type_string, used_bits); + #define X(KIND, TYPE) \ this->print_left(#KIND, KIND.size(), type_string, \ this->usage.files.at(T::clear::field_type()).at(TYPE)); X(squares, DATA_SQUARE) X(inverses, DATA_INVERSE) - X(bits, DATA_BIT) X(dabits, DATA_DABIT) #undef X @@ -601,17 +609,6 @@ void buffer_bits_from_players(vector>& player_bits, for (int i = 0; i < n_relevant_players; i++) for (auto& x : player_bits[i]) x = input.finalize((base_player + i) % P.num_players(), n_bits); -#if !defined(__clang__) && (__GNUC__ == 6) - // mitigate compiler bug - Bundle bundle(P); - P.unchecked_broadcast(bundle); -#endif -#ifdef DEBUG_BIT_SACRIFICE - typename T::MAC_Check MC; - for (int i = 0; i < n_relevant_players; i++) - for (auto& x : player_bits[i]) - assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1)); -#endif } template @@ -1164,18 +1161,18 @@ void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) typename T::clear r; r.randomize(G); input.add_mine(r); - this->inputs[player].push_back({input.finalize_mine(), r}); + this->inputs[player].push_back({input.finalize(player), r}); } - input.send_mine(); + input.exchange(); } else { - octetStream os; - P.receive_player(player, os); - T share; + for (int i = 0; i < buffer_size; i++) + input.add_other(player); + input.exchange(); for (int i = 0; i < buffer_size; i++) { - input.finalize_other(player, share, os); + auto share = input.finalize(player); this->inputs[player].push_back({share, 0}); } } diff --git a/Protocols/ReplicatedPrivateOutput.h b/Protocols/ReplicatedPrivateOutput.h deleted file mode 100644 index b9e546ca..00000000 --- a/Protocols/ReplicatedPrivateOutput.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * ReplicatedPrivateOutput.h - * - */ - -#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ -#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ - -template -class SubProcessor; -template -class Share; - -template -class ReplicatedPrivateOutput -{ - SubProcessor& proc; - -public: - ReplicatedPrivateOutput(SubProcessor& proc); - - void start(int player, int target, int source); - void stop(int player, int source); -}; - -#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */ diff --git a/Protocols/ReplicatedPrivateOutput.hpp b/Protocols/ReplicatedPrivateOutput.hpp deleted file mode 100644 index d3487223..00000000 --- a/Protocols/ReplicatedPrivateOutput.hpp +++ /dev/null @@ -1,30 +0,0 @@ -/* - * ReplicatedPrivateOutput.cpp - * - */ - -#include "ReplicatedPrivateOutput.h" -#include "Processor/Processor.h" -#include "Math/FixedVec.h" -#include "Math/Integer.h" - -template -inline ReplicatedPrivateOutput::ReplicatedPrivateOutput( - SubProcessor& proc) : - proc(proc) -{ -} - -template -void ReplicatedPrivateOutput::start(int player, int target, - int source) -{ - (void)player, (void)target, (void)source; - throw runtime_error("not implemented, use PrivateOutput"); -} - -template -void ReplicatedPrivateOutput::stop(int player, int source) -{ - (void)player, (void)source; -} diff --git a/Protocols/Semi.h b/Protocols/Semi.h index e290ca0e..5f63a9d6 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -71,6 +71,12 @@ public: proc.get_S()[info.source_base + i] >> info.m; } } + + void buffer_random() + { + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->random.push_back(G.get()); + } }; #endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 87a1e08e..4fc265b7 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -14,34 +14,33 @@ template class SemiMC; * Additive secret sharing input protocol */ template -class SemiInput : public IndividualInput +class SemiInput : public InputBase { - SeededPRNG secure_prng; + vector send_prngs; + vector recv_prngs; + Player& P; + vector> shares; public: - SemiInput(SubProcessor& proc, SemiMC& MC) : - IndividualInput(proc) + SemiInput(SubProcessor& proc, SemiMC&) : + SemiInput(&proc, proc.P) { - (void) MC; } - SemiInput(SubProcessor* proc, Player& P) : - IndividualInput(proc, P) - { - } + SemiInput(SubProcessor* proc, Player& P); SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : - SemiInput(P) + SemiInput(0, P) { (void) MC, (void) prep; } - SemiInput(Player& P) : - IndividualInput(0, P) - { - } - + void reset(int player); void add_mine(const typename T::clear& input, int n_bits = -1); + void add_other(int player, int n_bits = -1); + void exchange(); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + T finalize_mine(); }; #endif /* PROTOCOLS_SEMIINPUT_H_ */ diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 28673250..3ed1feef 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -11,22 +11,64 @@ #include "ShamirInput.hpp" template -void SemiInput::add_mine(const typename T::clear& input, int n_bits) +SemiInput::SemiInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P) +{ + shares.resize(P.num_players()); + vector to_send(P.num_players()), to_receive; + for (int i = 0; i < P.num_players(); i++) + { + send_prngs.push_back({}); + to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE); + } + P.send_receive_all(to_send, to_receive); + recv_prngs.resize(P.num_players()); + for (int i = 0; i < P.num_players(); i++) + if (i != P.my_num()) + recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE)); + this->reset_all(P); +} + +template +void SemiInput::reset(int player) +{ + shares[player].clear(); +} + +template +void SemiInput::add_mine(const typename T::clear& input, int) { auto& P = this->P; typename T::open_type sum, share; for (int i = 0; i < P.num_players(); i++) { - if (i < P.num_players() - 1) - share.randomize(secure_prng, n_bits); - else - share = input - sum; - sum += share; - if (i == P.my_num()) - this->shares.push_back(share); - else - share.pack(this->os[i], n_bits); + if (i != P.my_num()) + sum += send_prngs[i].template get(); } + shares[P.my_num()].push_back(input - sum); +} + +template +void SemiInput::add_other(int, int) +{ +} + +template +void SemiInput::exchange() +{ +} + +template +void SemiInput::finalize_other(int player, T& target, octetStream&, + int) +{ + target = recv_prngs[player].template get(); +} + +template +T SemiInput::finalize_mine() +{ + return shares[P.my_num()].next(); } #endif diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index f722886e..402173e9 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -27,7 +27,6 @@ class Shamir : public ProtocolBase { typedef typename T::open_type::Scalar U; - octetStreams os; vector reconstruction; U rec_factor; ShamirInput* resharing; diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 9fe10bde..8bfdf70e 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -69,8 +69,6 @@ int Shamir::get_n_relevant_players() template void Shamir::reset() { - os.reset(P); - if (resharing == 0) { resharing = new ShamirInput(0, P); @@ -78,6 +76,9 @@ void Shamir::reset() for (int i = 0; i < P.num_players(); i++) resharing->reset(i); + + for (int i = 0; i < n_mul_players; i++) + resharing->add_sender(i); } template @@ -92,37 +93,27 @@ template void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; - auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) - resharing->add_mine(add_share); + resharing->add_mine(x * y * rec_factor); } template void Shamir::exchange() { - vector senders(P.num_players(), false); - for (int i = 0; i < n_mul_players; i++) - senders[i] = true; - P.send_receive_all(senders, resharing->os, os); + assert(resharing); + resharing->exchange(); } template void Shamir::start_exchange() { - if (P.my_num() < n_mul_players) - for (int offset = 1; offset < P.num_players(); offset++) - P.send_relative(offset, resharing->os[P.get_player(offset)]); + resharing->start_exchange(); } template void Shamir::stop_exchange() { - for (int offset = 1; offset < P.num_players(); offset++) - { - int receive_from = P.get_player(-offset); - if (receive_from < n_mul_players) - P.receive_player(receive_from, os[receive_from]); - } + resharing->stop_exchange(); } template @@ -136,15 +127,8 @@ template T Shamir::finalize(int n_relevant_players) { ShamirShare res = U(0); - if (P.my_num() < n_relevant_players) - res = resharing->finalize_mine(); for (int i = 0; i < n_relevant_players; i++) - if (i != P.my_num()) - { - T tmp; - resharing->finalize_other(i, tmp, os[i]); - res += tmp; - } + res += resharing->finalize(i); return res; } @@ -259,7 +243,7 @@ vector Shamir::get_randoms(PRNG& G, int t) input.reset_all(P); int buffer_size = OnlineOptions::singleton.batch_size; for (int i = 0; i < buffer_size; i += hyper.size()) - input.add_mine(G.get()); + input.add_from_all(G.get()); input.exchange(); vector inputs; vector random; diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 02346707..91e09309 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput protected: Player& P; octetStreams os; + vector senders; public: IndividualInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P) + PrepLessInput(proc), P(P), senders(P.num_players()) { this->reset_all(P); } @@ -34,10 +35,14 @@ public: } void reset(int player); + void add_sender(int player); void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + + void start_exchange(); + void stop_exchange(); }; /** diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index d84b09a6..6d9992ad 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -20,6 +20,8 @@ void IndividualInput::reset(int player) this->i_share = 0; os.reset(P); } + + senders[player] = false; } template @@ -68,12 +70,20 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) else x.pack(this->os[i]); } + + this->senders[P.my_num()] = true; +} + +template +void IndividualInput::add_sender(int player) +{ + senders[player] = true; } template void IndividualInput::add_other(int player, int) { - (void) player; + add_sender(player); } template @@ -87,7 +97,26 @@ void IndividualInput::send_mine() template void IndividualInput::exchange() { - P.send_receive_all(os, InputBase::os); + P.send_receive_all(senders, os, InputBase::os); +} + +template +void IndividualInput::start_exchange() +{ + if (senders[P.my_num()]) + for (int offset = 1; offset < P.num_players(); offset++) + P.send_relative(offset, os[P.get_player(offset)]); +} + +template +void IndividualInput::stop_exchange() +{ + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + if (senders[receive_from]) + P.receive_player(receive_from, InputBase::os[receive_from]); + } } template diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 8f76d6a7..6bda92df 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -33,9 +33,12 @@ public: template class ShamirMC : public IndirectShamirMC { + typedef typename T::open_type open_type; typedef typename T::open_type::Scalar rec_type; vector reconstruction; + ShamirMC(const ShamirMC&); + void finalize(vector& values, const vector& S); protected: @@ -71,6 +74,7 @@ public: void Check(const Player& P) { (void)P; } vector get_reconstruction(const Player& P); + open_type reconstruct(const vector& shares); }; #endif /* PROTOCOLS_SHAMIRMC_H_ */ diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 6d6af913..e3e7cd3a 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -130,6 +130,19 @@ typename T::open_type ShamirMC::finalize_open() return res; } +template +typename T::open_type ShamirMC::reconstruct(const vector& shares) +{ + assert(reconstruction.size()); + typename T::open_type res; + for (size_t j = 0; j < reconstruction.size(); j++) + { + res += shares[j] * reconstruction[j]; + } + + return res; +} + template void IndirectShamirMC::exchange(const Player& P) { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 6e818c39..e7daabfc 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -13,6 +13,8 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; +template class MaliciousShamirPO; +template class SpecificPrivateOutput; namespace GC { @@ -22,6 +24,8 @@ template class CcdSecret; template class ShamirShare : public T, public ShareInterface { + typedef ShamirShare This; + public: typedef T clear; typedef T open_type; @@ -34,7 +38,8 @@ public: typedef IndirectShamirMC MAC_Check; typedef ShamirMC Direct_MC; typedef ShamirInput Input; - typedef ::PrivateOutput PrivateOutput; + typedef MaliciousShamirPO PO; + typedef SpecificPrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef ShamirShare Honest; diff --git a/Protocols/Share.h b/Protocols/Share.h index 743a2c61..92be4f14 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -55,6 +55,7 @@ class Share_ : public ShareInterface const static bool needs_ot = T::needs_ot; const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; + const static bool has_mac = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index ae6e7b7d..444214e4 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -34,6 +34,7 @@ public: static const bool has_trunc_pr = false; static const bool has_split = false; + static const bool has_mac = false; static const false_type triple_matmul; diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h index e9597527..4c5675e9 100644 --- a/Protocols/SpdzWiseInput.h +++ b/Protocols/SpdzWiseInput.h @@ -36,11 +36,8 @@ public: void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); - void send_mine(); void exchange(); T finalize(int player, int n_bits = -1); - T finalize_mine(); - void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_SPDZWISEINPUT_H_ */ diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index e0d508e5..7aaa14c9 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -85,21 +85,3 @@ T SpdzWiseInput::finalize(int player, int) { return shares[player].next(); } - -template -void SpdzWiseInput::send_mine() -{ - throw runtime_error("use exchange()"); -} - -template -T SpdzWiseInput::finalize_mine() -{ - throw runtime_error("use finalize()"); -} - -template -void SpdzWiseInput::finalize_other(int, T&, octetStream&, int) -{ - throw runtime_error("use finalize()"); -} diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9e953e73..9991dafb 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -32,7 +32,7 @@ public: { } - void init_open(const Player& P, int n) + void init_open(const Player& P, int n = 0) { inner_MC.init_open(P, n); } diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index 9cb86017..1090fc08 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -15,7 +15,6 @@ #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" -#include "MaliciousShamirPO.hpp" #include "GC/RepPrep.hpp" template diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h new file mode 100644 index 00000000..de7406bb --- /dev/null +++ b/Protocols/TemiPrep.h @@ -0,0 +1,72 @@ +/* + * TemiPrep.h + * + */ + +#ifndef PROTOCOLS_TEMIPREP_H_ +#define PROTOCOLS_TEMIPREP_H_ + +#include "ReplicatedPrep.h" +#include "FHEOffline/TemiSetup.h" + +template class HemiMatrixPrep; + +template +class TemiMultiplier +{ + typedef typename T::clear::FD FD; + + vector multiplicands; + + Player& P; + +public: + TemiMultiplier(Player& P); + + vector& get_multiplicands( + vector>& ciphertexts, const FHE_PK& pk); + void add(Plaintext_& res, const Ciphertext& C, OT_ROLE role = BOTH, + int n_summands = 1); + + int get_offset() + { + return 0; + } +}; + +/** + * Semi-honest triple generation with semi-homomorphic encryption + */ +template +class TemiPrep : public SemiHonestRingPrep +{ + friend class HemiMatrixPrep; + + typedef typename T::clear::FD FD; + + static Lock lock; + static TemiSetup* setup; + + vector*> multipliers; + +public: + static void basic_setup(Player& P); + static void teardown(); + + static const FD& get_FTD(); + static const FHE_PK& get_pk(); + static const TemiSetup& get_setup(); + + TemiPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), + BitPrep(proc, usage), RingPrep(proc, usage), + SemiHonestRingPrep(proc, usage) + { + } + + void buffer_triples(); + + vector*>& get_multipliers(); +}; + +#endif /* PROTOCOLS_TEMIPREP_H_ */ diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp new file mode 100644 index 00000000..1088a99c --- /dev/null +++ b/Protocols/TemiPrep.hpp @@ -0,0 +1,129 @@ +/* + * TemiPrep.hppg + * + * + */ + +#ifndef PROTOCOLS_TEMIPREP_HPP_ +#define PROTOCOLS_TEMIPREP_HPP_ + +#include "TemiPrep.h" +#include "FHEOffline/SimpleMachine.h" + +#include "FHEOffline/DataSetup.hpp" + +template +TemiSetup* TemiPrep::setup; + +template +Lock TemiPrep::lock; + +template +void TemiPrep::basic_setup(Player& P) +{ + assert(not setup); + setup = new TemiSetup; + MachineBase machine; + setup->secure_init(P, T::clear::length()); + read_or_generate_secrets(*setup, P, machine, 1, true_type()); + T::clear::template init(); +} + +template +void TemiPrep::teardown() +{ + if (setup) + delete setup; +} + +template +const typename T::clear::FD& TemiPrep::get_FTD() +{ + assert(setup); + return setup->FieldD; +} + +template +inline const FHE_PK& TemiPrep::get_pk() +{ + assert(setup); + return setup->pk; +} + +template +const TemiSetup& TemiPrep::get_setup() +{ + assert(setup); + return *setup; +} + +template +void TemiPrep::buffer_triples() +{ + lock.lock(); + if (setup == 0) + { + PlainPlayer P(this->proc->P.N, "Temi" + T::type_string()); + basic_setup(P); + } + lock.unlock(); + + auto& P = this->proc->P; + auto& FieldD = setup->FieldD; + + Plaintext_ a(FieldD), b(FieldD), c(FieldD); + + SeededPRNG G; + a.randomize(G); + b.randomize(G); + + TreeSum ts; + auto C = ts.run(setup->pk.encrypt(a), P); + C = ts.run(C * b + setup->pk.template encrypt(FieldD), P); + c = SimpleDistDecrypt(P, *setup).reshare(C); + + for (unsigned i = 0; i < a.num_slots(); i++) + this->triples.push_back({{a.element(i), b.element(i), c.element(i)}}); +} + +template +vector*>& TemiPrep::get_multipliers() +{ + assert(setup); + assert( + OnlineOptions::singleton.batch_size + <= setup->params.get_matrix_dim()); + assert(this->proc); + if (multipliers.empty()) + multipliers.push_back(new TemiMultiplier(this->proc->P)); + return multipliers; +} + +template +TemiMultiplier::TemiMultiplier(Player& P) : P(P) +{ +} + +template +vector& TemiMultiplier::get_multiplicands( + vector >& ciphertexts, const FHE_PK& pk) +{ + multiplicands.clear(); + multiplicands.resize(ciphertexts[0].size(), pk); + for (size_t j = 0; j < multiplicands.size(); j++) + for (size_t i = 0; i < ciphertexts.size(); i++) + multiplicands[j] += ciphertexts[i].at(j); + return multiplicands; +} + +template +void TemiMultiplier::add(Plaintext_& res, const Ciphertext& C, + OT_ROLE, int) +{ + TreeSum ts; + SimpleDistDecrypt dd(P, TemiPrep::get_setup()); + auto zero = TemiPrep::get_pk().template encrypt(TemiPrep::get_FTD()); + res += dd.reshare(ts.run(C + zero, P)); +} + +#endif /* PROTOCOLS_TEMIPREP_HPP_ */ diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h new file mode 100644 index 00000000..f4f37dcd --- /dev/null +++ b/Protocols/TemiShare.h @@ -0,0 +1,42 @@ +/* + * TemiShare.h + * + */ + +#ifndef PROTOCOLS_TEMISHARE_H_ +#define PROTOCOLS_TEMISHARE_H_ + +#include "HemiShare.h" + +template class TemiPrep; +template class Hemi; + +template +class TemiShare : public HemiShare +{ + typedef TemiShare This; + typedef HemiShare super; + +public: + typedef SemiMC MAC_Check; + typedef DirectSemiMC Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef typename conditional, Beaver>::type Protocol; + typedef TemiPrep LivePrep; + + static const bool needs_ot = false; + static const bool local_mul = false; + + TemiShare() + { + } + template + TemiShare(const U& other) : + super(other) + { + } + +}; + +#endif /* PROTOCOLS_TEMISHARE_H_ */ diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 951cbfe7..45d92613 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -317,7 +317,14 @@ void read_mac_key(const string& directory, int player_num, int nplayers, U& key) throw mac_key_error(filename); } - key.input(inpf,true); + try + { + key.input(inpf,true); + } + catch(exception&) + { + throw mac_key_error(filename); + } if (inpf.fail()) throw mac_key_error(filename); diff --git a/README.md b/README.md index bd107512..99d0f076 100644 --- a/README.md +++ b/README.md @@ -85,10 +85,31 @@ The following table lists all protocols that are fully supported. | --- | --- | --- | --- | --- | | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | -| Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +Modulo prime and modulo 2^k are the two settings that allow +integer-like computation. For k = 64, the latter corresponds to the +computation available on the widely used 64-bit processors. GF(2^n) +denotes Galois extension fields of order 2^n, which are different to +computation modulo 2^n. In particular, every element has an inverse, +which is not the case modulo 2^n. See [this +article](https://en.wikipedia.org/wiki/Finite_field) for an +introduction. Modulo prime and GF(2^n) are lumped together because the +protocols are very similar due to the mathematical properties. + +Bin. SS stands for binary secret sharing, that is secret sharing +modulo two. In some settings, this requires specific protocols as some +protocols require the domain size to be larger than two. In other +settings, the protocol is the same mathematically speaking, but a +specific implementation allows for optimizations such as using the +inherent parallelism of bit-wise operations on machine words. + +A security model specifies how many parties are "allowed" to misbehave +in what sense. Malicious means that not following the protocol will at +least be detected while semi-honest means that even corrupted parties +are assumed to follow the protocol. See [this paper](https://eprint.iacr.org/2020/300) for an explanation of the various security models and a high-level introduction to multi-party computation. @@ -257,7 +278,9 @@ compute the preprocessing time for a particular computation. add `AVX_OT = 0` in addition. - For optimal results on Linux on ARM, add `ARCH = -march=-march=armv8.2-a+crypto` to `CONFIG.mine`. This enables the - hardware support for AES. + hardware support for AES. See the [GCC + documentation](https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#AArch64-Options) + on available options. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. @@ -501,6 +524,7 @@ The following table shows all programs for dishonest-majority computation using | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | | `chaigear-party.x` | Adapted [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `chaigear.sh` | | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | +| `temi-party.x` | Adapted [CDN01](https://eprint.iacr.org/2000/055) | Mod prime | Semi-honest | `temi.sh` | | `soho-party.x` | Somewhat homomorphic encryption | Mod prime | Semi-honest | `soho.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -538,6 +562,11 @@ Hemi and Soho denote the stripped version version of LowGear and HighGear, respectively, for semi-honest security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. +Temi in turn denotes the adaption of +[Cramer et al.](https://eprint.iacr.org/2000/055) to LWE-based +semi-homomorphic encryption. +Both Hemi and Temi use the diagonal packing by [Halevi and +Shoup](https://eprint.iacr.org/2014/106) for matrix multiplication. We will use MASCOT to demonstrate the use, but the other protocols work similarly. diff --git a/Scripts/prep-usage.py b/Scripts/prep-usage.py new file mode 100755 index 00000000..cb8ca619 --- /dev/null +++ b/Scripts/prep-usage.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +tapename = next(Program.read_tapes(sys.argv[1])) +res = Tape.ReqNum() +for inst in Tape.read_instructions(tapename): + res.update(inst.get_usage()) + +for x in res.pretty(): + print(x) diff --git a/Scripts/temi.sh b/Scripts/temi.sh new file mode 100755 index 00000000..86f46c54 --- /dev/null +++ b/Scripts/temi.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player temi-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 10fe575f..e8c02f6c 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -59,7 +59,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py $compile_opts tutorial for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - atlas mal-shamir sy-shamir hemi semi \ + atlas mal-shamir sy-shamir hemi semi temi \ soho mascot; do test_vm $i $run_opts done diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 941ec425..ffd41123 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -86,6 +86,10 @@ octetStream check_file_signature(ifstream& file, const string& filename) { throw signature_mismatch(filename); } + catch (IO_Error&) + { + throw signature_mismatch(filename); + } if (file_signature() != file_spec) throw signature_mismatch(filename); return file_spec; diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index 96f69b0c..f6f4ba2e 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -35,8 +35,8 @@ wrong_gfp_size::wrong_gfp_size(const char* name, const bigint& p, { } -overflow::overflow(const char* name, size_t i, size_t n) : - runtime_error(string(name) + " overflow: " + to_string(i) + "/" + to_string(n)) +overflow::overflow(const string& name, size_t i, size_t n) : + runtime_error(name + " overflow: " + to_string(i) + "/" + to_string(n)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index 18406cf6..fff8b2de 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -237,7 +237,7 @@ public: class overflow : public runtime_error { public: - overflow(const char* name, size_t i, size_t n); + overflow(const string& name, size_t i, size_t n); }; class unknown_input_type : public runtime_error diff --git a/Tools/octetStream.h b/Tools/octetStream.h index cd90b0e9..676382ea 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -80,6 +80,8 @@ class octetStream size_t get_ptr() const { return ptr; } /// Length size_t get_length() const { return len; } + /// Length including size tag + size_t get_total_length() const { return len + sizeof(len); } /// Allocation size_t get_max_length() const { return mxlen; } /// Data pointer diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index 45e5f337..962b2775 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -129,12 +129,12 @@ void run(int argc, char** argv) output.prepare_open(c); } output.exchange(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - protocol.check(); - output.Check(P); + set.check(); } diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index 532d705e..a36949d6 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -126,12 +126,12 @@ void run(char** argv) output.prepare_open(res); } output.exchange(P); - bit_output.Check(P); + set.check(); cout << "result: "; for (int i = 0; i < n; i++) cout << output.finalize_open() << " "; cout << endl; - output.Check(P); + set.check(); } diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 9cae6953..83571c21 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -110,7 +110,7 @@ void run(char** argv, int prime_length) c = protocol.finalize_dotprod(n); // protocol check before revealing results - protocol.check(); + set.check(); output.init_open(P); output.prepare_open(c); @@ -120,5 +120,5 @@ void run(char** argv, int prime_length) cout << "result: " << result << endl; // result check after opening - output.Check(P); + set.check(); } diff --git a/doc/instructions.rst b/doc/instructions.rst index 1a833994..fb62066e 100644 --- a/doc/instructions.rst +++ b/doc/instructions.rst @@ -85,12 +85,10 @@ Compiler.instructions module .. automodule:: Compiler.instructions :members: :no-undoc-members: - :exclude-members: asm_input, inputmask, lts, print_char4_regint, - print_char_regint, protectmemc, sqrs, - start_grind, startprivateoutput, stop_grind, - stopprivateoutput, writesocketc, writesocketint, - protectmemint, protectmems, print_mem, - matmul_base, g2muls, inputmixed_base, raw_output + :exclude-members: asm_input, sqrs, + start_grind, stop_grind, + writesocketc, writesocketint, + matmul_base, inputmixed_base, raw_output Compiler.GC.instructions module ------------------------------- diff --git a/doc/low-level.rst b/doc/low-level.rst index c70bf5b6..7f5474fd 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -309,6 +309,11 @@ Share Types - ``SpdzWiseShare`` - `SPDZ-wise `_. ``T`` must be ``MaliciousShamirShare`` or ``MaliciousRep3Share``. + * + - ``TemiShare`` + - Semi-honest protocol with Beaver multiplication based on + threshold semi-homomorphic encryption. ``T`` must be + ``gfp_`` or ``gf2n_short``. Protocol Setup diff --git a/doc/non-linear.rst b/doc/non-linear.rst index bcdbbd3a..e5df4c20 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -88,7 +88,7 @@ The following table lists the matching arithmetic and binary protocols. cut-and-choose analysis by `Furukawa et al. `_ * - - Semi, Hemi, Soho, Semi2k + - Semi, Hemi, Temi, Soho, Semi2k - SemiBin (Beaver triples modulo 2 using OT) * - `Malicious Shamir `_ diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 1441e352..21500c45 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -85,7 +85,37 @@ modulo the default 128-bit prime 00000025 -``Fake-Offline.x`` generates preprocessing data insecurely for a range -of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +The actual data is stored is by simple concatenation. For example, +triples are stored as repetitions of ``a, b, ab``, and daBits are +stored as repetitions of ``a, b`` where ``a`` is the arithmetic +share and ``b`` is the binary share. + +For protocols with MAC, the value share is stored before the MAC +share. + +Values are generally stored in little-endian order. Note the following +domain specifics: + +Modulo a prime + Values are stored in `Montgomery representation + `_ + with :math:`R` being the smallest power of :math:`2^{64}` larger than + the prime. For example, :math:`R = 2^{128}` for a 128-bit prime. + Furthermore, the values are stored in the smallest number of 8-byte + blocks necessary, all in little-endian order. + +Modulo a power of two: + Values are stored in the smallest number of 8-byte blocks necessary, + all in little-endian order. + +:math:`GF(2^n)` + Values are stored in blocks according to the storage size above, + all in little-endian order. + +For further details, have a look at ``Utils/Fake-Offline.cpp``, which +contains code that generates preprocessing data insecurely for a range +of protocols (underlying the binary ``Fake-Offline.x``). + +``{mascot,cowgear,mal-shamir}-offline.x`` generate sufficient preprocessing data for a specific high-level program with MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/requirements.txt b/doc/requirements.txt index cd6467ed..32add0c7 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1 +1,2 @@ breathe +sphinx-rtd-theme==0.5.2 From 08ea9b3bd0b33aa5331c9da7f3f5d788d5ffa19b Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 22 Feb 2022 13:25:12 +1100 Subject: [PATCH 022/265] Better scaling network setup. --- Networking/Player.cpp | 26 +++++++++----------------- Networking/Server.cpp | 13 +++++++------ Networking/Server.h | 2 +- Tools/octetStream.cpp | 15 +++++++++++++++ Tools/octetStream.h | 5 +++++ 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/Networking/Player.cpp b/Networking/Player.cpp index cd92df54..b4bab177 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -146,25 +146,17 @@ void Names::setup_names(const char *servername, int my_port) #endif // Now get the set of names - int i; - size_t tmp; - receive(socket_num,tmp,4); - nplayers = tmp; + octetStream os; + os.Receive(socket_num); + os.get(names); + os.get(ports); + if (names.size() != ports.size()) + throw runtime_error("invalid network setup"); + nplayers = names.size(); #ifdef VERBOSE - cerr << nplayers << " players\n"; + for (int i = 0; i < nplayers; i++) + cerr << "Player " << i << " is running on machine " << names[i] << endl; #endif - names.resize(nplayers); - ports.resize(nplayers); - for (i=0; i void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive_stream) const { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 676382ea..96af8191 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -207,6 +207,11 @@ class octetStream s.len=l; } + /// Append string + void store(const string& str); + /// Read string + void get(string& str); + /// Send on ``socket_num`` template void Send(T socket_num) const; From 9c3e607068084f8fff716c510a47ae51a854c3fe Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 24 Feb 2022 11:57:21 +1100 Subject: [PATCH 023/265] Bug when inputting to large arrays. --- Compiler/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 1dbe1f90..a81a2562 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5409,7 +5409,7 @@ class Array(_vectorizable): input_from = self.value_type.get_input_from try: self.assign(input_from(player, size=len(self))) - except TypeError: + except (TypeError, CompilerError): @library.for_range_opt(len(self), budget=budget) def _(i): self[i] = input_from(player) From 6664de3f77bd7abdfdb9cb8e9b01f4e082d59a91 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 1 Mar 2022 17:08:25 +1100 Subject: [PATCH 024/265] Multiplication of matrices larger than the maximum register size. --- Compiler/types.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index a81a2562..c090eef9 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5917,7 +5917,22 @@ class SubMultiArray(_vectorizable): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - res_matrix.assign_vector(self.direct_mul(other)) + if res_matrix.total_size() < _register.maximum_size: + res_matrix.assign_vector(self.direct_mul(other)) + else: + slice = _register.maximum_size // res_matrix.sizes[1] + assert slice > 0 + n = res_matrix.sizes[0] // slice + @library.for_range_opt(n) + def _(i): + res_matrix.assign_part_vector( + self.get_part(i * slice, + slice).direct_mul(other), + i * slice) + base = n * slice + rem = self.sizes[0] - base + res_matrix.assign_part_vector( + self.get_part(base, rem).direct_mul(other), base) except AttributeError: if max(res_matrix.sizes) > 1000: raise AttributeError() From 60dd78797e0e3d93f1e2c02388603f56366eaf46 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 2 Mar 2022 19:25:07 +1100 Subject: [PATCH 025/265] Multithreaded matrix multiplication. --- Compiler/types.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index c090eef9..7f87b905 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5885,7 +5885,7 @@ class SubMultiArray(_vectorizable): # legacy function return self.dot(other, res_params) - def dot(self, other, res_params=None): + def dot(self, other, res_params=None, n_threads=None): """ Matrix-matrix and matrix-vector multiplication. :param self: two-dimensional @@ -5917,22 +5917,12 @@ class SubMultiArray(_vectorizable): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - if res_matrix.total_size() < _register.maximum_size: - res_matrix.assign_vector(self.direct_mul(other)) - else: - slice = _register.maximum_size // res_matrix.sizes[1] - assert slice > 0 - n = res_matrix.sizes[0] // slice - @library.for_range_opt(n) - def _(i): - res_matrix.assign_part_vector( - self.get_part(i * slice, - slice).direct_mul(other), - i * slice) - base = n * slice - rem = self.sizes[0] - base + self.value_type.direct_matrix_mul + max_size = _register.maximum_size // res_matrix.sizes[1] + @library.multithread(n_threads, self.sizes[0], max_size) + def _(base, size): res_matrix.assign_part_vector( - self.get_part(base, rem).direct_mul(other), base) + self.get_part(base, size).direct_mul(other), base) except AttributeError: if max(res_matrix.sizes) > 1000: raise AttributeError() From e485aacd37b7e2f8901b1bee5a4ec132a536a043 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 2 Mar 2022 12:32:13 +1100 Subject: [PATCH 026/265] Easier change of domain in SPDZ2k. --- Machines/SPDZ2k.cpp | 10 ++++++++++ Machines/spdz2k-party.cpp | 22 +++++++++++++++++++--- Processor/RingMachine.hpp | 13 +++++++++---- Protocols/Spdz2kShare.h | 4 ++++ 4 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 Machines/SPDZ2k.cpp diff --git a/Machines/SPDZ2k.cpp b/Machines/SPDZ2k.cpp new file mode 100644 index 00000000..62a4324e --- /dev/null +++ b/Machines/SPDZ2k.cpp @@ -0,0 +1,10 @@ +/* + * SPDZ2k.cpp + * + */ + +#include "SPDZ2k.hpp" + +#ifdef RING_SIZE +template class Machine, Share>; +#endif diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index 188b6291..8aba3173 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -10,7 +10,7 @@ #include "Math/gf2n.h" #include "Networking/Server.h" -#include "Processor/OnlineMachine.hpp" +#include "Processor/RingMachine.hpp" #include "Math/Z2k.hpp" int main(int argc, const char** argv) @@ -46,7 +46,23 @@ int main(int argc, const char** argv) Z(72, 64) Z(72, 48) +#ifdef RING_SIZE + Z(RING_SIZE, SPDZ2K_DEFAULT_SECURITY) +#endif + else - throw runtime_error( - "not compiled for k=" + to_string(k) + " and s=" + to_string(s)); + { + if (s == SPDZ2K_DEFAULT_SECURITY) + { + ring_domain_error(k); + } + else + { + cerr << "not compiled for k=" << k << " and s=" << s << "," << endl; + cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line " + << (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+" + << s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl; + } + exit(1); + } } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index e422e0aa..f2bfc6c1 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -30,6 +30,13 @@ HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char* RingMachine(argc, argv, opt, online_opts, nplayers); } +inline void ring_domain_error(int R) +{ + cerr << "not compiled for " << R << "-bit computation, " << endl; + cerr << "compile with -DRING_SIZE=" << R << endl; + exit(1); +} + template class U, template class V, class W> RingMachine::RingMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) @@ -49,8 +56,7 @@ RingMachine::RingMachine(int argc, const char** argv, #endif #undef X default: - cerr << "not compiled for " << to_string(R) + "-bit computation" << endl; - exit(1); + ring_domain_error(R); } } @@ -88,8 +94,7 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri #endif #undef X default: - cerr << "not compiled for " << to_string(R) + "-bit computation" << endl; - exit(1); + ring_domain_error(R); } } diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index d26cde4f..401070f8 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -6,6 +6,10 @@ #ifndef PROTOCOLS_SPDZ2KSHARE_H_ #define PROTOCOLS_SPDZ2KSHARE_H_ +#ifndef SPDZ2K_DEFAULT_SECURITY +#define SPDZ2K_DEFAULT_SECURITY 64 +#endif + #include "Math/Z2k.h" #include "Protocols/Share.h" #include "Protocols/MAC_Check.h" From 0501a2701cc11376e063817c8731871e47eae835 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 8 Mar 2022 17:05:44 +1100 Subject: [PATCH 027/265] Document domain types. --- Math/Z2k.h | 47 ++++++++++++++++++++++++++++++++++- Math/gfp.h | 63 ++++++++++++++++++++++++++++++++++++++++++++--- Math/gfpvar.h | 8 ++++++ doc/Doxyfile | 2 +- doc/low-level.rst | 16 ++++++++++++ 5 files changed, 131 insertions(+), 5 deletions(-) diff --git a/Math/Z2k.h b/Math/Z2k.h index ad32cbf1..586c78c0 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -19,6 +19,12 @@ using namespace std; template class IntBase; template class fixint; +/** + * Type for values in the ring defined by the integers modulo ``2^K`` + * representing `[0, 2^K-1]`. + * It supports arithmetic, bit-wise, and output streaming operations. + * It does not need initialization because ``K`` completely defines the domain. + */ template class Z2 : public ValueInterface { @@ -71,6 +77,9 @@ public: typedef Z2 next; typedef Z2 Scalar; + /** + * Initialize to zero. + */ Z2() { assign_zero(); } Z2(mp_limb_t x) : Z2() { a[0] = x; } Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); } @@ -78,8 +87,14 @@ public: Z2(long x) : Z2(mp_limb_t(x)) { if (K > 64 and x < 0) memset(&a[1], -1, N_BYTES - 8); } template Z2(const IntBase& x); + /** + * Convert from unrestricted integer. + */ Z2(const bigint& x); Z2(const void* buffer) : Z2() { assign(buffer); } + /** + * Convert from different domain via the canonical integer representation. + */ template Z2(const Z2& x) : Z2() { avx_memcpy(a, x.a, min(N_BYTES, x.N_BYTES)); normalize(); } @@ -140,19 +155,38 @@ public: Z2 invert() const; + /** + * Deterministic square root for values with least significate bit 1. + * Raises an exception otherwise. + */ Z2 sqrRoot(); bool is_zero() const { return *this == Z2(); } bool is_one() const { return *this == 1; } bool is_bit() const { return is_zero() or is_one(); } + /** + * Sample with uniform distribution. + * @param G randomness generator + * @param n (unused) + */ void randomize(PRNG& G, int n = -1); void randomize_part(PRNG& G, int n); void almost_randomize(PRNG& G) { randomize(G); } void force_to_bit() { throw runtime_error("impossible"); } + /** + * Append to buffer in native format. + * @param o buffer + * @param n (unused) + */ void pack(octetStream& o, int = -1) const; + /** + * Read from buffer in native format + * @param o buffer + * @param n (unused) + */ void unpack(octetStream& o, int n = -1); void input(istream& s, bool human=true); @@ -162,21 +196,32 @@ public: friend ostream& operator<<(ostream& o, const Z2& x); }; +/** + * Type for values in the ring defined by the integers modulo ``2^K`` + * representing `[-2^(K-1), 2^(K-1)-1]`. + * It supports arithmetic, bit-wise, comparison, and output streaming operations. + * It does not need initialization because ``K`` completely defines the domain. + */ template class SignedZ2 : public Z2 { public: + /** + * Initialization to zero + */ SignedZ2() { } + /** + * Conversion from another domain via the signed representation + */ template SignedZ2(const SignedZ2& other) : Z2(other) { extend(other); } - template void extend(const SignedZ2& other) { diff --git a/Math/gfp.h b/Math/gfp.h index bde43025..3bc23e19 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -40,6 +40,16 @@ template void generate_prime_setup(string, int, int); #error GFP_MOD_SZ must be at most MAX_MOD_SZ #endif +/** + * Type for values in a field defined by integers modulo a prime + * in a specific range for fixed storage. + * It supports basic arithmetic operations and bit-wise operations. + * The latter use the canonical representation in the range `[0, p-1]`. + * ``X`` is a counter to allow several moduli being used at the same time. + * ``L`` is the number of 64-bit limbs, that is, + * the prime has to have bit length in `[64*L-63, 64*L]`. + * See ``gfpvar_`` for a more flexible alternative. + */ template class gfp_ : public ValueInterface { @@ -72,7 +82,17 @@ class gfp_ : public ValueInterface template static void init(bool mont = true) { init_field(T::pr(), mont); } + /** + * Initialize the field. + * @param p: prime modulus + * @param mont: whether to use Montgomery representation + */ static void init_field(const bigint& p,bool mont=true); + /** + * Initialize the field to a prime of a given bit length. + * @param lgp: bit length + * @param mont: whether to use Montgomery representation + */ static void init_default(int lgp, bool mont = true); static void read_or_generate_setup(string dir, const OnlineOptions& opts); template @@ -85,6 +105,9 @@ class gfp_ : public ValueInterface { write_online_setup(dir, pr()); } static void check_setup(string dir); + /** + * Get the prime modulus + */ static const bigint& pr() { return ZpD.pr; } static int t() @@ -126,15 +149,24 @@ class gfp_ : public ValueInterface const void* get_ptr() const { return &a.x; } void* get_ptr() { return &a.x; } + /** + * Initialize to zero. + */ gfp_() { assignZero(a,ZpD); } template gfp_(const modp_& g) { a=g; } + /** + * Convert from integer without range restrictions. + */ gfp_(const mpz_class& x) { to_modp(a, x, ZpD); } gfp_(int x) : gfp_(long(x)) {} gfp_(long x); gfp_(word x) : gfp_(bigint::tmp = x) {} template gfp_(IntBase x) : gfp_(x.get()) {} + /** + * Convert from different domain via canonical integer representation. + */ template gfp_(const gfp_& x); gfp_(const gfpvar& other); @@ -181,9 +213,16 @@ class gfp_ : public ValueInterface void negate() { Negate(a,a,ZpD); } - // deterministic square root + /** + * Deterministic square root. + */ gfp_ sqrRoot(); + /** + * Sample with uniform distribution. + * @param G randomness generator + * @param n (unused) + */ void randomize(PRNG& G, int n = -1) { (void) n; a.randomize(G,ZpD); } // faster randomization, see implementation for explanation @@ -194,10 +233,20 @@ class gfp_ : public ValueInterface void input(istream& s,bool human) { a.input(s,ZpD,human); } + /** + * Human-readable output in the range `[-p/2, p/2]`. + * @param s output stream + * @param x value + */ friend ostream& operator<<(ostream& s,const gfp_& x) { x.output(s,true); return s; } + /** + * Human-readable input without range restrictions + * @param s input stream + * @param x value + */ friend istream& operator>>(istream& s,gfp_& x) { x.input(s,true); return s; @@ -220,10 +269,18 @@ class gfp_ : public ValueInterface void force_to_bit() { throw runtime_error("impossible"); } - // Pack and unpack in native format - // i.e. Dont care about conversion to human readable form + /** + * Append to buffer in native format. + * @param o buffer + * @param n (unused) + */ void pack(octetStream& o, int n = -1) const { (void) n; a.pack(o); } + /** + * Read from buffer in native format + * @param o buffer + * @param n (unused) + */ void unpack(octetStream& o, int n = -1) { (void) n; a.unpack(o); } diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 438a935e..a3b475f8 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -14,6 +14,14 @@ class FFT_Data; template class BitVec_; +/** + * Type for values in a field defined by integers modulo a prime + * up to a certain length for fixed storage. + * ``X`` is a counter to allow several moduli being used at the same time. + * ``L`` is the maximum number of 64-bit limbs, that is, + * the prime has to have bit length at most `64*L`. + * The interface replicates ``gfp_``. + */ template class gfpvar_ { diff --git a/doc/Doxyfile b/doc/Doxyfile index 3dd29940..36837c38 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/low-level.rst b/doc/low-level.rst index 7f5474fd..fd9d2bfc 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -355,3 +355,19 @@ Protocol Interfaces .. doxygenclass:: BufferPrep :members: + + +Domain Reference +---------------- + +.. doxygenclass:: gfp_ + :members: + +.. doxygenclass:: gfpvar_ + :members: + +.. doxygenclass:: Z2 + :members: + +.. doxygenclass:: SignedZ2 + :members: From 6a223a6b99340e269c8f627fb888891d6aa6990a Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 8 Mar 2022 10:01:19 +0100 Subject: [PATCH 028/265] RTD build. --- doc/Doxyfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/Doxyfile b/doc/Doxyfile index 36837c38..9820ba50 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -801,7 +801,7 @@ WARN_NO_PARAMDOC = NO # a warning is encountered. # The default value is: NO. -WARN_AS_ERROR = YES +WARN_AS_ERROR = NO # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which From c040a54f634f141b59046f9742a73af095a30441 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Wed, 9 Mar 2022 20:23:08 -0600 Subject: [PATCH 029/265] Fix typo in docs --- doc/networking.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/networking.rst b/doc/networking.rst index a1c61b98..c7e031f1 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -13,7 +13,7 @@ base port number, which can be changed using the same option. There are two ways of communicating hosts and individually setting ports: -1. All parties first to connect to a coordination server, which +1. All parties first connect to a coordination server, which broadcasts the data for all parties. This is the default with the coordination server being run as a thread of party 0. The hostname of the coordination server has to be given with the command-line From b283fdb385c0777c5faf7a63557a415955a50995 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 9 Mar 2022 18:25:59 +1100 Subject: [PATCH 030/265] Improved multi-threaded tree reduction. --- Compiler/library.py | 10 ++++++++++ Compiler/types.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/Compiler/library.py b/Compiler/library.py index 3f31499b..46d72ec7 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1340,6 +1340,16 @@ def map_sum_simple(n_threads, n_loops, type, size): return map_reduce(n_threads, 1, n_loops, initializer, summer) def tree_reduce_multithread(n_threads, function, vector): + """ Round-efficient reduction in several threads. The following code + computes the maximum of an array in 10 threads:: + + tree_reduce_multithread(10, lambda x, y: x.max(y), a) + + :param n_threads: number of threads (int) + :param function: reduction function taking exactly two arguments + :param vector: register vector or array + + """ inputs = vector.Array(len(vector)) inputs.assign_vector(vector) outputs = vector.Array(len(vector) // 2) diff --git a/Compiler/types.py b/Compiler/types.py index 7f87b905..b6359733 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5554,6 +5554,10 @@ class Array(_vectorizable): """ library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + def Array(self, size): + # compatibility with registers + return Array(size, self.value_type) + def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), self.address) From 1227376ae3620180be956e3846b08c73343f50b1 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 14 Mar 2022 23:41:18 +1100 Subject: [PATCH 031/265] Bug in conversion from secret integer to secret bits. --- Compiler/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 2f7ca81f..23bee219 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -292,7 +292,7 @@ def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True): """ program.curr_tape.require_bit_length(k + kappa) from .types import sint - if program.use_edabit() and m > 1 and not const_rounds: + if program.use_edabit() and not const_rounds: movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0]) tmp, b[:] = sint.get_edabit(m, True) movs(r_prime, tmp) From 5b248ced927b89f2d357c02fa27faee7712e45ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 16 Mar 2022 12:12:14 +1100 Subject: [PATCH 032/265] Bug in negative sbits input. --- GC/ShareSecret.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 23c86cb2..267b6f83 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -162,7 +162,7 @@ void Processor::inputb(typename T::Input& input, ProcessorBase& input_process for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++) { auto& res = S[x.dest + i]; - res.my_input(input, bigint(whole_input >> (i * dl)).get_ui(), + res.my_input(input, bigint(whole_input >> (i * dl)).get_si(), min(dl, x.n_bits - i * dl)); } } From 07292ec09dbfcfc996802e4e2f88e943b253ccb1 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 31 Mar 2022 18:30:12 +0200 Subject: [PATCH 033/265] Bug in integer division. --- Compiler/library.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index 46d72ec7..06503c7b 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1848,7 +1848,8 @@ def sint_cint_division(a, b, k, f, kappa): return (sign_a * sign_b) * A def IntDiv(a, b, k, kappa=None): - return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, + l = 2 * k + 1 + return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k, kappa, nearest=True) @instructions_base.ret_cisc From 565c364cd4204a8d697c7ab3d235774a15ecb29e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 2 Apr 2022 07:41:22 +0200 Subject: [PATCH 034/265] Bug in fixed-point division. --- Compiler/library.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Compiler/library.py b/Compiler/library.py index 06503c7b..35d6f46f 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1876,19 +1876,20 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() - y = a.extend(2 *k) * w - y = y.round(2*k, f, kappa, nearest, signed=True) + l_y = k + 3 * f - res_f + y = a.extend(l_y) * w + y = y.round(l_y, f, kappa, nearest, signed=True) for i in range(theta - 1): x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) + y = y.extend(l_y) * (alpha + x).extend(l_y) x = x * x - y = y.round(2*k, 2*f, kappa, nearest, signed=True) + y = y.round(l_y, 2*f, kappa, nearest, signed=True) x = x.round(2*k, 2*f, kappa, nearest, signed=True) x = x.extend(2 * k) - y = y.extend(2 * k) * (alpha + x).extend(2 * k) - y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) + y = y.extend(l_y) * (alpha + x).extend(l_y) + y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False): From b0e7857cbc29dfa12d5339abf114e83f10b9baf9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 1 Apr 2022 20:28:41 +0200 Subject: [PATCH 035/265] Store MAC keys for persistence. --- Processor/Machine.hpp | 9 +++++---- Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 4 +++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index cd318f1a..e720b2a9 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -6,6 +6,7 @@ #include "Memory.hpp" #include "Online-Thread.hpp" #include "Protocols/Hemi.hpp" +#include "Protocols/fake-stuff.hpp" #include "Tools/Exceptions.h" @@ -60,10 +61,10 @@ Machine::Machine(int my_number, Names& playerNames, sint::LivePrep::basic_setup(*P); } - sint::read_or_generate_mac_key(prep_dir_prefix(), *P, alphapi); - sgf2n::read_or_generate_mac_key(prep_dir_prefix(), *P, alpha2i); - sint::bit_type::part_type::read_or_generate_mac_key( - prep_dir_prefix(), *P, alphabi); + alphapi = read_generate_write_mac_key(*P); + alpha2i = read_generate_write_mac_key(*P); + alphabi = read_generate_write_mac_key(*P); #ifdef DEBUG_MAC cerr << "MAC Key p = " << alphapi << endl; diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 0c209869..d15581eb 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -34,7 +34,7 @@ template void read_mac_key(const string& directory, const Names& N, U& key); template -typename T::mac_key_type read_generate_write_mac_key(const Player& P, +typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory = ""); template diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 45d92613..aeb51611 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -272,7 +272,9 @@ void write_mac_key(const string& directory, int i, int nplayers, U key) ofstream outf; stringstream filename; filename << mac_filename(directory, i); +#ifdef VERBOSE cout << "Writing to " << filename.str().c_str() << endl; +#endif outf.open(filename.str().c_str()); outf << nplayers << endl; key.output(outf,true); @@ -333,7 +335,7 @@ void read_mac_key(const string& directory, int player_num, int nplayers, U& key) } template -inline typename T::mac_key_type read_generate_write_mac_key(const Player& P, +typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory) { if (directory == "") From f42930edc306a5145ba3d80fe10b7a2e0e452925 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 2 Apr 2022 08:06:03 +0200 Subject: [PATCH 036/265] Sufficient offline preprocessing with several threads. --- Processor/OfflineMachine.h | 3 +++ Processor/OfflineMachine.hpp | 15 +++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Processor/OfflineMachine.h b/Processor/OfflineMachine.h index 792c5bda..d1b14256 100644 --- a/Processor/OfflineMachine.h +++ b/Processor/OfflineMachine.h @@ -18,10 +18,13 @@ class OfflineMachine : public W BaseMachine machine; Names& playerNames; Player& P; + int n_threads; template void generate(); + int buffered_total(size_t required, size_t batch); + public: template OfflineMachine(int argc, const char** argv, diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index dcfafe55..6e0bb525 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -22,6 +22,7 @@ OfflineMachine::OfflineMachine(int argc, const char** argv, Program program(playerNames.num_players()); program.parse(machine.bc_filenames[0]); usage = program.get_offline_data_used(); + n_threads = machine.nthreads; machine.ot_setups.push_back({P}); } @@ -52,6 +53,12 @@ int OfflineMachine::run() return 0; } +template +int OfflineMachine::buffered_total(size_t required, size_t batch) +{ + return DIV_CEIL(required, batch) * batch + (n_threads - 1) * batch; +} + template template void OfflineMachine::generate() @@ -79,7 +86,7 @@ void OfflineMachine::generate() if (i == DATA_DABIT) { for (long long j = 0; - j < DIV_CEIL(my_usage, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(my_usage, BUFFER_SIZE); j++) { T a; typename T::bit_type b; @@ -91,7 +98,7 @@ void OfflineMachine::generate() { vector tuple(DataPositions::tuple_size[i]); for (long long j = 0; - j < DIV_CEIL(my_usage, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(my_usage, BUFFER_SIZE); j++) { preprocessing.get(dtype, tuple.data()); for (auto& x : tuple) @@ -113,7 +120,7 @@ void OfflineMachine::generate() file_signature().output(out); InputTuple tuple; for (long long j = 0; - j < DIV_CEIL(n_inputs, BUFFER_SIZE) * BUFFER_SIZE; j++) + j < buffered_total(n_inputs, BUFFER_SIZE); j++) { preprocessing.get_input(tuple.share, tuple.value, i); tuple.share.output(out, false); @@ -142,7 +149,7 @@ void OfflineMachine::generate() { ofstream out(filename, ios::binary); file_signature().output(out); - for (int i = 0; i < DIV_CEIL(total, batch) * batch; i++) + for (int i = 0; i < buffered_total(total, batch); i++) preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits, out); } From 06f3f21cee7ba02af9ed68144cd49be0cb2340bf Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Thu, 10 Feb 2022 23:43:12 -0600 Subject: [PATCH 037/265] Add SSL_DIR env var --- CONFIG | 3 +++ Networking/CryptoPlayer.cpp | 6 +++--- Networking/ssl_sockets.h | 8 ++++++-- README.md | 3 ++- Scripts/setup-ssl.sh | 7 ++++--- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/CONFIG b/CONFIG index 05b3683d..bf92327e 100644 --- a/CONFIG +++ b/CONFIG @@ -8,6 +8,9 @@ GDEBUG = -g # set this to your preferred local storage directory PREP_DIR = '-DPREP_DIR="Player-Data/"' +# directory to store SSL keys +SSL_DIR = '-DSSL_DIR="Player-Data/"' + # set for SHE preprocessing (SPDZ and Overdrive) USE_NTL = 0 diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 9d8da651..43b2ada5 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -19,9 +19,9 @@ void ssl_error(string side, string other, string me) { cerr << side << "-side handshake with " << other << " failed. Make sure both sides " - << " have the necessary certificate (" << PREP_DIR << me + << " have the necessary certificate (" << SSL_DIR << me << ".pem in the default configuration on their side and " - << PREP_DIR << other << ".pem on ours)," + << SSL_DIR << other << ".pem on ours)," << " and run `c_rehash ` on its location." << endl << "The certificates should be the same on every host. " << "Also make sure that it's still valid. Certificates generated " @@ -36,7 +36,7 @@ void ssl_error(string side, string other, string me) cerr << "Signature (should match the other side): "; for (int i = 0; i < 2; i++) { - auto filename = PREP_DIR + ids[i] + ".pem"; + auto filename = SSL_DIR + ids[i] + ".pem"; ifstream cert(filename); stringstream buffer; buffer << cert.rdbuf(); diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 79cb3522..fe9477a8 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -14,6 +14,10 @@ #include #include +#ifndef SSL_DIR +#define SSL_DIR "Player-Data/" +#endif + typedef boost::asio::io_service ssl_service; void check_ssl_file(string filename); @@ -25,7 +29,7 @@ public: ssl_ctx(string me) : boost::asio::ssl::context(boost::asio::ssl::context::tlsv12) { - string prefix = PREP_DIR + me; + string prefix = SSL_DIR + me; string cert_file = prefix + ".pem"; string key_file = prefix + ".key"; check_ssl_file(cert_file); @@ -33,7 +37,7 @@ public: use_certificate_file(cert_file, pem); use_private_key_file(key_file, pem); - add_verify_path(PREP_DIR); + add_verify_path(SSL_DIR); } }; diff --git a/README.md b/README.md index 99d0f076..4f9b4456 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,7 @@ compute the preprocessing time for a particular computation. on available options. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). + - `SSL_DIR` should point to a local, unversioned directory to store ssl keys (the default is `Player-Data` in the current directory). - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. 2) Run `make` to compile all the software (use the flag `-j` for faster @@ -707,7 +708,7 @@ information. MP-SPDZ uses OpenSSL for secure channels. You can generate the necessary certificates and keys as follows: -`Scripts/setup-ssl.sh []` +`Scripts/setup-ssl.sh [ ]` The programs expect the keys and certificates to be in `Player-Data/P.key` and `Player-Data/P.pem`, respectively, and diff --git a/Scripts/setup-ssl.sh b/Scripts/setup-ssl.sh index ffd79bf0..01113f16 100755 --- a/Scripts/setup-ssl.sh +++ b/Scripts/setup-ssl.sh @@ -4,13 +4,14 @@ PATH=/usr/local/opt/openssl/bin:$PATH n=${1:-4} +ssl_dir=${2:-"Player-Data"} -test -e Player-Data || mkdir Player-Data +test -e $ssl_dir || mkdir $ssl_dir echo Setting up SSL for $n parties for i in `seq 0 $[n-1]`; do - openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" + openssl req -newkey rsa -nodes -x509 -out $ssl_dir/P$i.pem -keyout $ssl_dir/P$i.key -subj "/CN=P$i" done -c_rehash Player-Data +c_rehash $ssl_dir From 68d6eb0832af9076e317b046c9f3deb69e06724b Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Fri, 8 Apr 2022 20:54:12 -0500 Subject: [PATCH 038/265] Add .env to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ce68ca4e..9a4dd72e 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,6 @@ Thumbs.db # Sphinx build _build/ + +# environment +.env From 9b4e0447eb5a1a55233970462f4cdc92d9d7ba84 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Tue, 12 Apr 2022 20:54:35 -0500 Subject: [PATCH 039/265] Add missing SSL_DIR cflag in CONFIG --- CONFIG | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONFIG b/CONFIG index bf92327e..cef15e0b 100644 --- a/CONFIG +++ b/CONFIG @@ -87,7 +87,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) From 5773930bbfe085e69a171d8a2cbd128b9ffa0f8a Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Tue, 12 Apr 2022 21:44:08 -0500 Subject: [PATCH 040/265] Emphasize that ssl keys must be under SSL_DIR --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4f9b4456..59a90a11 100644 --- a/README.md +++ b/README.md @@ -708,16 +708,18 @@ information. MP-SPDZ uses OpenSSL for secure channels. You can generate the necessary certificates and keys as follows: -`Scripts/setup-ssl.sh [ ]` +`Scripts/setup-ssl.sh [ ]` The programs expect the keys and certificates to be in -`Player-Data/P.key` and `Player-Data/P.pem`, respectively, and +`SSL_DIR/P.key` and `SSL_DIR/P.pem`, respectively, and the certificates to have the common name `P` for player ``. Furthermore, the relevant root certificates have to be in -`Player-Data` such that OpenSSL can find them (run `c_rehash -Player-Data`). The script above takes care of all this by generating +`SSL_DIR` such that OpenSSL can find them (run `c_rehash +`). The script above takes care of all this by generating self-signed certificates. Therefore, if you are running the programs on different hosts you will need to copy the certificate files. +Note that `` must match `SSL_DIR` set in `CONFIG` or `CONFIG.mine`. +Just like `SSL_DIR`, `` defaults to `Player-Data`. In the following, we will walk through running the tutorial modulo 2^k with three parties. The other programs work similarly. From 30508236a772cacda146e4f518125c438617202c Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Fri, 8 Apr 2022 15:16:34 -0500 Subject: [PATCH 041/265] Add Dockerfile --- Dockerfile | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..e01dd5c5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,149 @@ +############################################################################### +# Build this stage for a build environment, e.g.: # +# # +# docker build --tag mpspdz:buildenv --target buildenv . # +# # +# The above is equivalent to: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=native \ # +# --build-arg cxx=clang++-11 \ # +# --build-arg use_ntl=0 \ # +# --build-arg prep_dir="Player-Data" \ # +# --build-arg ssl_dir="Player-Data" # +# --build-arg cryptoplayers=0 # +# # +# To build for an x86-64 architecture, with g++, NTL (for HE), custom # +# prep_dir & ssl_dir, and to use encrypted channels for 4 players: # +# # +# docker build --tag mpspdz:buildenv \ # +# --target buildenv \ # +# --build-arg arch=x86-64 \ # +# --build-arg cxx=g++ \ # +# --build-arg use_ntl=1 \ # +# --build-arg prep_dir="/opt/prepdata" \ # +# --build-arg ssl_dir="/opt/ssl" # +# --build-arg cryptoplayers=4 . # +# # +# To work in a container to build different machines, and compile programs: # +# # +# docker run --rm -it mpspdz:buildenv bash # +# # +# Once in the container, build a machine and compile a program: # +# # +# $ make replicated-ring-party.x # +# $ ./compile.py -R 64 tutorial # +# # +############################################################################### +FROM python:3.10.3-bullseye as buildenv + +RUN apt-get update && apt-get install -y --no-install-recommends \ + automake \ + build-essential \ + clang-11 \ + git \ + libboost-dev \ + libboost-thread-dev \ + libclang-dev \ + libntl-dev \ + libsodium-dev \ + libssl-dev \ + libtool \ + m4 \ + texinfo \ + yasm \ + vim \ + gdb \ + valgrind \ + && rm -rf /var/lib/apt/lists/* + +# mpir +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/ +COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/ + +ENV MP_SPDZ_HOME /usr/src/MP-SPDZ +WORKDIR $MP_SPDZ_HOME + +RUN pip install --upgrade pip ipython + +COPY . . + +ARG arch=native +ARG cxx=clang++-11 +ARG use_ntl=0 +ARG prep_dir="Player-Data" +ARG ssl_dir="Player-Data" + +RUN echo "ARCH = -march=${arch}" >> CONFIG.mine \ + && echo "CXX = ${cxx}" >> CONFIG.mine \ + && echo "USE_NTL = ${use_ntl}" >> CONFIG.mine \ + && echo "MY_CFLAGS += -I/usr/local/include" >> CONFIG.mine \ + && echo "MY_LDLIBS += -Wl,-rpath -Wl,/usr/local/lib -L/usr/local/lib" \ + >> CONFIG.mine \ + && mkdir -p $prep_dir $ssl_dir \ + && echo "PREP_DIR = '-DPREP_DIR=\"${prep_dir}/\"'" >> CONFIG.mine \ + && echo "SSL_DIR = '-DSSL_DIR=\"${ssl_dir}/\"'" >> CONFIG.mine + +# ssl keys +ARG cryptoplayers=0 +ENV PLAYERS ${cryptoplayers} +RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir} + + +############################################################################### +# Use this stage to a build a specific virtual machine. For example: # +# # +# docker build --tag mpspdz:shamir \ # +# --target machine \ # +# --build-arg machine=shamir-party.x \ # +# --build-arg gfp_mod_sz=4 . # +# # +# The above will build shamir-party.x with 256 bit length. # +# # +# If no build arguments are passed (via --build-arg), mascot-party.x is built # +# with the default 128 bit length. # +############################################################################### +FROM buildenv as machine + +ARG machine="mascot-party.x" + +ARG gfp_mod_sz=2 + +RUN echo "MOD = -DGFP_MOD_SZ=${gfp_mod_sz}" >> CONFIG.mine + +RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/ + + +################################################################################ +# This is the default stage. Use it to compile a high-level program. # +# By default, tutorial.mpc is compiled with --field=64 bits. # +# # +# docker build --tag mpspdz:mascot-tutorial \ # +# --build-arg src=tutorial \ # +# --build-arg compile_options="--field=64" . # +# # +# Note that build arguments from previous stages can also be passed. For # +# instance, building replicated-ring-party.x, for 3 crypto players with custom # +# PREP_DIR and SSL_DIR, and compiling tutorial.mpc with --ring=64: # +# # +# docker build --tag mpspdz:replicated-ring \ # +# --build-arg machine=replicated-ring-party.x \ # +# --build-arg prep_dir=/opt/prep \ # +# --build-arg ssl_dir=/opt/ssl \ # +# --build-arg nparties=3 \ # +# --build-arg compile_options="--ring=64" . # +# # +# Test it: # +# # +# docker run --rm -it mpspdz:replicated-ring ./Scripts/ring.sh tutorial # +################################################################################ +FROM machine as program + +ARG src="tutorial" +ARG compile_options="--field=64" +RUN ./compile.py ${compile_options} ${src} +RUN mkdir -p Player-Data \ + && echo 1 2 3 4 > Player-Data/Input-P0-0 \ + && echo 1 2 3 4 > Player-Data/Input-P1-0 From 4d17b4f38957306e5c105dd4e097a0138a60f889 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Mon, 11 Apr 2022 15:18:07 -0500 Subject: [PATCH 042/265] Add tl;dr for docker in readme --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 4f9b4456..b5b4445e 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,21 @@ echo 1 2 3 4 > Player-Data/Input-P1-0 Scripts/mascot.sh tutorial ``` +#### TL;DR (Docker) +Build a docker image for `mascot-party.x`: + +``` +docker build --tag mpspdz:mascot-party --build-arg machine=mascot-party.x . +``` + +Run the [the tutorial](Programs/Source/tutorial.mpc): + +``` +docker run --rm -it mpspdz:mascot-party ./Scripts/mascot.sh tutorial +``` + +See the [`Dockerfile`](./Dockerfile) for examples of how it can be used. + #### Preface The primary aim of this software is to run the same computation in From 24bf33aba24d55959fb0fb6fa6550c8c2a127bf1 Mon Sep 17 00:00:00 2001 From: Bishakh Ghosh Date: Fri, 15 Apr 2022 15:41:51 +0530 Subject: [PATCH 043/265] Fix typo in ECDSA readme --- ECDSA/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ECDSA/README.md b/ECDSA/README.md index 6307a91e..7cb014ce 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -24,8 +24,8 @@ The following binaries have been used for the paper: All binaries offer the same interface. With MASCOT for example, run the following: ``` -./mascot-ecsda-party.x -p 0 [-N ] [-h ] [-D] [] -./mascot-ecsda-party.x -p 1 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 0 [-N ] [-h ] [-D] [] +./mascot-ecdsa-party.x -p 1 [-N ] [-h ] [-D] [] ... ``` From 9ef15cc2f56d1aa335d9884e3f7bb75be0eed2af Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 19 Apr 2022 15:10:18 +0200 Subject: [PATCH 044/265] Protocol in dealer model. --- CHANGELOG.md | 10 ++ Compiler/GC/instructions.py | 8 +- Compiler/GC/types.py | 15 +- Compiler/allocator.py | 4 +- Compiler/instructions.py | 12 +- Compiler/instructions_base.py | 13 +- Compiler/library.py | 13 ++ Compiler/ml.py | 7 +- Compiler/program.py | 3 + Compiler/types.py | 35 ++++- ECDSA/P256Element.h | 1 + ECDSA/semi-ecdsa-party.cpp | 1 + ExternalIO/Client.h | 4 +- ExternalIO/Client.hpp | 6 +- ExternalIO/bankers-bonus-client.cpp | 18 +-- FHE/FHE_Params.cpp | 23 ++- FHE/FHE_Params.h | 5 +- FHE/Matrix.cpp | 13 +- FHE/NTL-Subs.cpp | 73 ++++++--- FHE/NTL-Subs.h | 9 +- FHE/NoiseBounds.cpp | 18 ++- FHE/Subroutines.cpp | 38 ++--- FHEOffline/DataSetup.cpp | 8 +- FHEOffline/DataSetup.h | 2 +- FHEOffline/Multiplier.cpp | 2 +- FHEOffline/PairwiseMachine.cpp | 4 +- FHEOffline/PairwiseSetup.cpp | 5 +- FHEOffline/Proof.h | 1 + FHEOffline/Sacrificing.cpp | 9 +- FHEOffline/SimpleMachine.cpp | 7 +- FHEOffline/SimpleMachine.h | 1 - FHEOffline/TemiSetup.cpp | 4 +- GC/CcdShare.h | 2 +- GC/DealerPrep.h | 69 +++++++++ GC/FakeSecret.h | 2 +- GC/MaliciousCcdShare.h | 2 +- GC/MaliciousRepSecret.h | 4 +- GC/NoShare.h | 16 +- GC/Processor.hpp | 2 +- GC/SemiSecret.h | 106 +++++++++---- GC/{SemiSecret.cpp => SemiSecret.hpp} | 42 ++++-- GC/ShareSecret.h | 6 +- GC/ThreadMaster.hpp | 9 +- GC/TinyMC.h | 4 +- Machines/SPDZ2k.hpp | 1 + Machines/Semi.hpp | 1 + Machines/ShamirMachine.hpp | 1 - Machines/TripleMachine.cpp | 4 +- Machines/dealer-ring-party.cpp | 22 +++ Machines/emulate.cpp | 5 +- Machines/hemi-party.cpp | 1 + Machines/malicious-ccd-party.cpp | 3 +- Machines/semi-bin-party.cpp | 1 + Machines/soho-party.cpp | 1 + Machines/spdz2k-party.cpp | 8 +- Machines/temi-party.cpp | 1 + Machines/tinier-party.cpp | 4 +- Machines/tiny-party.cpp | 2 +- Makefile | 26 ++-- Math/Integer.h | 2 - Math/Setup.cpp | 15 +- Math/Setup.h | 1 + Math/bigint.h | 4 +- Math/gf2n.cpp | 35 ++++- Math/gf2n.h | 3 +- Networking/AllButLastPlayer.h | 67 +++++++++ Networking/Player.cpp | 32 +++- Networking/Player.h | 30 ++-- OT/NPartyTripleGenerator.hpp | 9 +- Processor/BaseMachine.cpp | 2 + Processor/Data_Files.hpp | 4 +- Processor/Input.h | 4 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 6 +- Processor/Instruction.hpp | 42 +++--- Processor/Machine.h | 19 ++- Processor/Machine.hpp | 135 ++++++++++++----- Processor/OfflineMachine.hpp | 7 +- Processor/Online-Thread.hpp | 5 +- Processor/OnlineMachine.h | 4 +- Processor/OnlineMachine.hpp | 44 +----- Processor/OnlineOptions.cpp | 49 ++++++- Processor/OnlineOptions.h | 9 +- Processor/OnlineOptions.hpp | 42 +++++- Processor/Program.h | 4 +- Processor/RingMachine.hpp | 16 +- Processor/RingOptions.cpp | 19 +-- Processor/RingOptions.h | 4 +- Processor/instructions.h | 4 +- Programs/Source/l2h_comparison.mpc | 3 + Programs/Source/l2h_multiplication.mpc | 1 + Protocols/Beaver.h | 1 + Protocols/ChaiGearPrep.hpp | 9 +- Protocols/CowGearOptions.cpp | 14 +- Protocols/CowGearOptions.h | 1 - Protocols/CowGearPrep.hpp | 7 +- Protocols/DabitSacrifice.h | 6 +- Protocols/DabitSacrifice.hpp | 6 + Protocols/DealerInput.h | 38 +++++ Protocols/DealerInput.hpp | 115 +++++++++++++++ Protocols/DealerMC.h | 42 ++++++ Protocols/DealerMC.hpp | 76 ++++++++++ Protocols/DealerPrep.h | 33 +++++ Protocols/DealerPrep.hpp | 196 +++++++++++++++++++++++++ Protocols/DealerShare.h | 76 ++++++++++ Protocols/FakeMC.h | 2 +- Protocols/FakeProtocol.h | 4 + Protocols/FakeShare.h | 3 +- Protocols/Hemi.h | 1 + Protocols/Hemi.hpp | 3 +- Protocols/HemiMatrixPrep.h | 11 +- Protocols/HemiMatrixPrep.hpp | 1 + Protocols/HemiPrep.hpp | 4 +- Protocols/MAC_Check.h | 3 +- Protocols/MAC_Check.hpp | 39 +---- Protocols/MAC_Check_Base.h | 3 +- Protocols/MAC_Check_Base.hpp | 10 +- Protocols/MalRepRingOptions.cpp | 4 +- Protocols/MalRepRingPrep.hpp | 2 +- Protocols/MaliciousRepMC.h | 6 +- Protocols/MaliciousRepMC.hpp | 4 +- Protocols/MaliciousRepPrep.hpp | 4 +- Protocols/MaliciousShamirMC.h | 2 +- Protocols/MaliciousShamirMC.hpp | 2 +- Protocols/MaliciousShamirShare.h | 3 + Protocols/MamaPrep.hpp | 12 +- Protocols/MamaShare.h | 5 + Protocols/NoLivePrep.h | 5 + Protocols/NoProtocol.h | 1 + Protocols/NoShare.h | 3 - Protocols/Rep3Share.h | 7 +- Protocols/Replicated.h | 2 + Protocols/Replicated.hpp | 13 +- Protocols/ReplicatedMC.h | 6 +- Protocols/ReplicatedMC.hpp | 2 +- Protocols/SPDZ.h | 3 +- Protocols/SPDZ2k.h | 28 ++++ Protocols/Semi2kShare.h | 2 - Protocols/SemiInput.h | 4 +- Protocols/SemiInput.hpp | 2 +- Protocols/SemiShare.h | 11 +- Protocols/ShamirMC.h | 2 +- Protocols/ShamirMC.hpp | 4 +- Protocols/ShamirShare.h | 37 +---- Protocols/Share.h | 10 +- Protocols/Share.hpp | 7 + Protocols/ShareInterface.h | 5 + Protocols/ShareMatrix.h | 2 +- Protocols/ShuffleSacrifice.hpp | 5 +- Protocols/SohoPrep.hpp | 1 + Protocols/Spdz2kPrep.h | 3 +- Protocols/Spdz2kShare.h | 3 +- Protocols/SpdzWise.hpp | 3 + Protocols/SpdzWiseMC.h | 4 +- Protocols/SpdzWiseShare.hpp | 14 +- Protocols/TemiPrep.h | 2 + Protocols/TemiPrep.hpp | 8 + Protocols/config.h | 13 ++ Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 57 ++++++- Protocols/mac_key.hpp | 8 + README.md | 24 ++- Scripts/dealer-ring.sh | 10 ++ Scripts/memory-usage.py | 9 +- Scripts/test_tutorial.sh | 6 +- Scripts/tldr.sh | 6 +- Tools/Buffer.cpp | 5 + Tools/Exceptions.cpp | 5 + Tools/Exceptions.h | 6 + Tools/avx_memcpy.h | 9 +- Tools/benchmarking.cpp | 19 +++ Tools/benchmarking.h | 15 +- Tools/intrinsics.h | 2 + Tools/parse.h | 8 + Utils/Check-Offline.cpp | 12 +- Utils/Fake-Offline.cpp | 34 ++--- Utils/binary-example.cpp | 1 + Utils/l2h-example.cpp | 54 +++++++ Utils/mixed-example.cpp | 1 + Utils/paper-example.cpp | 5 +- Yao/YaoEvalWire.cpp | 2 +- Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoPlayer.cpp | 2 +- Yao/YaoWire.hpp | 2 +- doc/Compiler.rst | 10 +- doc/non-linear.rst | 10 ++ 186 files changed, 2008 insertions(+), 618 deletions(-) create mode 100644 GC/DealerPrep.h rename GC/{SemiSecret.cpp => SemiSecret.hpp} (61%) create mode 100644 Machines/dealer-ring-party.cpp create mode 100644 Networking/AllButLastPlayer.h create mode 100644 Programs/Source/l2h_comparison.mpc create mode 100644 Programs/Source/l2h_multiplication.mpc create mode 100644 Protocols/DealerInput.h create mode 100644 Protocols/DealerInput.hpp create mode 100644 Protocols/DealerMC.h create mode 100644 Protocols/DealerMC.hpp create mode 100644 Protocols/DealerPrep.h create mode 100644 Protocols/DealerPrep.hpp create mode 100644 Protocols/DealerShare.h create mode 100644 Protocols/SPDZ2k.h create mode 100644 Protocols/config.h create mode 100755 Scripts/dealer-ring.sh create mode 100644 Utils/l2h-example.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a0406a8..744d0ff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.1 (Apr 19, 2022) + +- Protocol in dealer model +- Command-line option for security parameter +- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482)) +- Ability to run high-level (Python) code from C++ +- More memory capacity due to 64-bit addressing +- Homomorphic encryption for more fields of characteristic two +- Docker container + ## 0.3.0 (Feb 17, 2022) - Semi-honest computation based on threshold semi-homomorphic encryption diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index ef9c14a3..e53b7187 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -305,7 +305,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMSB'] - arg_format = ['sbw','int'] + arg_format = ['sbw','long'] class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy secret bit register to secret bit memory cell with compile-time @@ -315,7 +315,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMSB'] - arg_format = ['sb','int'] + arg_format = ['sb','long'] # def __init__(self, *args, **kwargs): # super(type(self), self).__init__(*args, **kwargs) # import inspect @@ -330,7 +330,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction, :param: memory address (int) """ code = opcodes['LDMCB'] - arg_format = ['cbw','int'] + arg_format = ['cbw','long'] class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): """ Copy clear bit register to clear bit memory cell with compile-time @@ -340,7 +340,7 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): :param: memory address (int) """ code = opcodes['STMCB'] - arg_format = ['cb','int'] + arg_format = ['cb','long'] class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ Copy secret bit memory cell with run-time address to secret bit diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 38c37a26..6c3abad0 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -3,6 +3,9 @@ This modules contains basic types for binary circuits. The fixed-length types obtained by :py:obj:`get_type(n)` are the preferred way of using them, and in some cases required in connection with container types. + +Computation using these types will always be executed as a binary +circuit. See :ref:`protocol-pairs` for the exact protocols. """ from Compiler.types import MemValue, read_mem_value, regint, Array, cint @@ -17,7 +20,6 @@ import math from functools import reduce class bits(Tape.Register, _structure, _bit): - """ Base class for binary registers. """ n = 40 unit = 64 PreOp = staticmethod(floatingpoint.PreOpN) @@ -400,12 +402,18 @@ class sbits(bits): res = sbit() inst.bitb(res) return res + @staticmethod + def _check_input_player(player): + if not util.is_constant(player): + raise CompilerError('player must be known at compile time ' + 'for binary circuit inputs') @classmethod def get_input_from(cls, player, n_bits=None): """ Secret input from :py:obj:`player`. :param: player (int) """ + cls._check_input_player(player) if n_bits is None: n_bits = cls.n res = cls() @@ -653,6 +661,7 @@ class sbitvec(_vec): :param: player (int) """ + sbits._check_input_player(player) res = cls.from_vec(sbit() for i in range(n)) inst.inputbvec(n + 3, 0, player, *res.v) return res @@ -780,6 +789,8 @@ class sbitvec(_vec): size = other.size return (other.get_vector(base, min(64, size - base)) \ for base in range(0, size, 64)) + if not isinstance(other, type(self)): + return type(self)(other) return other def __xor__(self, other): other = self.coerce(other) @@ -1222,6 +1233,7 @@ class sbitfix(_fix): :param: player (int) """ + sbits._check_input_player(player) v = cls.int_type() inst.inputb(player, cls.k, cls.f, v) return cls._new(v) @@ -1287,6 +1299,7 @@ class sbitfixvec(_fix): :param: player (int) """ v = [sbit() for i in range(sbitfix.k)] + sbits._check_input_player(player) inst.inputbvec(len(v) + 3, sbitfix.f, player, *v) return cls._new(cls.int_type.from_vec(v)) def __init__(self, value=None, *args, **kwargs): diff --git a/Compiler/allocator.py b/Compiler/allocator.py index cf2f13ef..bf431ca3 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -15,11 +15,11 @@ from functools import reduce class BlockAllocator: """ Manages freed memory blocks. """ def __init__(self): - self.by_logsize = [defaultdict(set) for i in range(32)] + self.by_logsize = [defaultdict(set) for i in range(64)] self.by_address = {} def by_size(self, size): - if size >= 2 ** 32: + if size >= 2 ** 64: raise CompilerError('size exceeds addressing capability') return self.by_logsize[int(math.log(size, 2))][size] diff --git a/Compiler/instructions.py b/Compiler/instructions.py index e0679768..8a10ee58 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -69,7 +69,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMC'] - arg_format = ['cw','int'] + arg_format = ['cw','long'] @base.gf2n @base.vectorize @@ -84,7 +84,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMS'] - arg_format = ['sw','int'] + arg_format = ['sw','long'] @base.gf2n @base.vectorize @@ -99,7 +99,7 @@ class stmc(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMC'] - arg_format = ['c','int'] + arg_format = ['c','long'] @base.gf2n @base.vectorize @@ -114,7 +114,7 @@ class stms(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMS'] - arg_format = ['s','int'] + arg_format = ['s','long'] @base.vectorize class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @@ -128,7 +128,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): """ __slots__ = [] code = base.opcodes['LDMINT'] - arg_format = ['ciw','int'] + arg_format = ['ciw','long'] @base.vectorize class stmint(base.DirectMemoryWriteInstruction): @@ -142,7 +142,7 @@ class stmint(base.DirectMemoryWriteInstruction): """ __slots__ = [] code = base.opcodes['STMINT'] - arg_format = ['ci','int'] + arg_format = ['ci','long'] @base.vectorize class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d6c647ad..8ae0b86f 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -337,7 +337,7 @@ def gf2n(instruction): if isinstance(arg_format, list): __format = [] for __f in arg_format: - if __f in ('int', 'p', 'ci', 'str'): + if __f in ('int', 'long', 'p', 'ci', 'str'): __format.append(__f) else: __format.append(__f[0] + 'g' + __f[1:]) @@ -360,7 +360,7 @@ def gf2n(instruction): arg_format = instruction_cls.gf2n_arg_format elif isinstance(instruction_cls.arg_format, itertools.repeat): __f = next(instruction_cls.arg_format) - if __f != 'int' and __f != 'p': + if __f not in ('int', 'long', 'p'): arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) else: arg_format = copy.deepcopy(instruction_cls.arg_format) @@ -711,6 +711,14 @@ class IntArgFormat(ArgFormat): def __str__(self): return str(self.i) +class LongArgFormat(IntArgFormat): + @classmethod + def encode(cls, arg): + return struct.pack('>Q', arg) + + def __init__(self, f): + self.i = struct.unpack('>Q', f.read(8))[0] + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): @@ -768,6 +776,7 @@ ArgFormats = { 'i': ImmediateModpAF, 'ig': ImmediateGF2NAF, 'int': IntArgFormat, + 'long': LongArgFormat, 'p': PlayerNoAF, 'str': String, } diff --git a/Compiler/library.py b/Compiler/library.py index 35d6f46f..ef2fe1ab 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1149,6 +1149,7 @@ def multithread(n_threads, n_items=None, max_size=None): :param n_threads: compile-time (int) :param n_items: regint/cint/int (default: :py:obj:`n_threads`) + :param max_size: maximum size to be processed at once (default: no limit) The following executes ``f(0, 8)``, ``f(8, 8)``, and ``f(16, 9)`` in three different threads: @@ -1366,6 +1367,18 @@ def tree_reduce_multithread(n_threads, function, vector): left = (left + 1) // 2 return inputs[0] +def tree_reduce(function, sequence): + """ Round-efficient reduction. The following computes the maximum + of the list :py:obj:`l`:: + + m = tree_reduce(lambda x, y: x.max(y), l) + + :param function: reduction function taking two arguments + :param sequence: list, vector, or array + + """ + return util.tree_reduce(function, sequence) + def foreach_enumerate(a): """ Run-time loop over public data. This uses ``Player-Data/Public-Input/``. Example: diff --git a/Compiler/ml.py b/Compiler/ml.py index c521934f..02f0f04e 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1079,14 +1079,17 @@ class MaxPool(NoVariableLayer): (type(self).__name__, self.X.sizes, self.strides, self.ksize, self.padding) - def _forward(self, batch): + def forward(self, batch=None, training=False): + if batch is None: + batch = Array.create_from(regint(0)) def process(pool, bi, k, i, j): def m(a, b): c = a[0] > b[0] l = [c * x for x in a[1]] l += [(1 - c) * x for x in b[1]] return c.if_else(a[0], b[0]), l - red = util.tree_reduce(m, [(x[0], [1]) for x in pool]) + red = util.tree_reduce(m, [(x[0], [1] if training else []) + for x in pool]) self.Y[bi][i][j][k] = red[0] for i, x in enumerate(red[1]): self.comparisons[bi][k][i] = x diff --git a/Compiler/program.py b/Compiler/program.py index 36672330..78b802e1 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -400,6 +400,9 @@ class Program(object): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) + if addr + size >= 2 ** 32: + raise CompilerError("allocation exceeded for type '%s'" % + mem_type) self.allocated_mem_blocks[addr,mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if diff --git a/Compiler/types.py b/Compiler/types.py index b6359733..99ca6a8c 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1710,7 +1710,12 @@ class personal(Tape._no_truth): res = Array.create_from(res) return personal(player, res) - def bit_decompose(self, length): + def bit_decompose(self, length=None): + """ Bit decomposition. + + :param length: number of bits + + """ return [personal(self.player, x) for x in self._v.bit_decompose(length)] def _san(self, other): @@ -2144,7 +2149,7 @@ class sint(_secret, _int): the bit length. :param val: initialization (sint/cint/regint/int/cgf2n or list - thereof or sbits/sbitvec/sfix) + thereof, sbits/sbitvec/sfix, or :py:class:`personal`) :param size: vector size (int), defaults to 1 or size of list When converting :py:class:`~Compiler.GC.types.sbits`, the result is a @@ -2152,6 +2157,9 @@ class sint(_secret, _int): :py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values with bit length equal the length of the input. + Initializing from a :py:class:`personal` value implies the + relevant party inputting their value securely. + """ __slots__ = [] instruction_type = 'modp' @@ -4285,6 +4293,7 @@ class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` for security considerations of the underlying integer operations. + The secret integer is stored as the :py:obj:`v` member. It supports basic arithmetic (``+, -, *, /``), returning :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), @@ -5121,7 +5130,8 @@ class Array(_vectorizable): array ``a`` and ``i`` being a :py:class:`regint`, :py:class:`cint`, or a Python integer. - :param length: compile-time integer (int) or :py:obj:`None` for unknown length + :param length: compile-time integer (int) or :py:obj:`None` + for unknown length (need to specify :py:obj:`address`) :param value_type: basic type :param address: if given (regint/int), the array will not be allocated @@ -5178,6 +5188,8 @@ class Array(_vectorizable): self.address = None def get_address(self, index): + if isinstance(index, (_secret, _single)): + raise CompilerError('need cleartext index') key = str(index) if self.length is not None: from .GC.types import cbits @@ -5211,6 +5223,7 @@ class Array(_vectorizable): if index.step == 0: raise CompilerError('slice step cannot be zero') return index.start or 0, \ + index.stop if self.length is None else \ min(index.stop or self.length, self.length), index.step or 1 def __getitem__(self, index): @@ -5517,7 +5530,15 @@ class Array(_vectorizable): :param end: string to print after (default: line break) """ - library.print_str('%s' + end, self.get_vector().reveal()) + if util.is_constant(self.length): + library.print_str('%s' + end, self.get_vector().reveal()) + else: + library.print_str('[') + @library.for_range(self.length - 1) + def _(i): + library.print_str('%s, ', self[i].reveal()) + library.print_str('%s', self[self.length - 1].reveal()) + library.print_str(']' + end) def reveal_to_binary_output(self, player=None): """ Reveal to binary output if supported by type. @@ -5893,7 +5914,8 @@ class SubMultiArray(_vectorizable): """ Matrix-matrix and matrix-vector multiplication. :param self: two-dimensional - :param other: Matrix or Array of matching size and type """ + :param other: Matrix or Array of matching size and type + :param n_threads: number of threads (default: all in same thread) """ assert len(self.sizes) == 2 if isinstance(other, Array): assert len(other) == self.sizes[1] @@ -5928,6 +5950,7 @@ class SubMultiArray(_vectorizable): res_matrix.assign_part_vector( self.get_part(base, size).direct_mul(other), base) except AttributeError: + assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() A = self.get_vector() @@ -5937,7 +5960,7 @@ class SubMultiArray(_vectorizable): res_params)) except (AttributeError, AssertionError): # fallback for sfloat etc. - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt_multithread(n_threads, self.sizes[0]) def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index e426bade..4657b5d8 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -28,6 +28,7 @@ public: static const true_type invertible; static int size() { return 0; } + static int length() { return 256; } static string type_string() { return "P256"; } static void init(); diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp index 6bdcec28..d7a4d883 100644 --- a/ECDSA/semi-ecdsa-party.cpp +++ b/ECDSA/semi-ecdsa-party.cpp @@ -10,6 +10,7 @@ #include "Protocols/SemiPrep.hpp" #include "Protocols/SemiInput.hpp" #include "Protocols/MAC_Check_Base.hpp" +#include "GC/SemiSecret.hpp" #include "ot-ecdsa-party.hpp" #include diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 5f8e76fd..de9e9cad 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -49,8 +49,8 @@ public: * @param n number of values * @returns vector of integer-like values */ - template - vector receive_outputs(int n); + template + vector receive_outputs(int n); }; #endif /* EXTERNALIO_CLIENT_H_ */ diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 601d9a48..3af40f2f 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -91,8 +91,8 @@ void Client::send_private_inputs(const vector& values) // Receive shares of the result and sum together. // Also receive authenticating values. -template -vector Client::receive_outputs(int n) +template +vector Client::receive_outputs(int n) { vector triples(3 * n); octetStream os; @@ -111,7 +111,7 @@ vector Client::receive_outputs(int n) } } - vector output_values; + vector output_values; for (int i = 0; i < 3 * n; i += 3) { if (T(triples[i] * triples[i + 1]) != triples[i + 2]) diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index f68384c0..b040dd5e 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -46,7 +46,7 @@ #include #include -template +template void one_run(T salary_value, Client& client) { // Run the computation @@ -54,18 +54,18 @@ void one_run(T salary_value, Client& client) cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; // Get the result back (client_id of winning client) - T result = client.receive_outputs(1)[0]; + U result = client.receive_outputs(1)[0]; cout << "Winning client id is : " << result << endl; } -template +template void run(double salary_value, Client& client) { // sint - one_run(long(round(salary_value)), client); + one_run(long(round(salary_value)), client); // sfix with f = 16 - one_run(long(round(salary_value * exp2(16))), client); + one_run(long(round(salary_value * exp2(16))), client); } int main(int argc, char** argv) @@ -125,7 +125,7 @@ int main(int argc, char** argv) { gfp::init_field(specification.get()); cerr << "using prime " << gfp::pr() << endl; - run(salary_value, client); + run(salary_value, client); break; } case 'R': @@ -134,13 +134,13 @@ int main(int argc, char** argv) switch (R) { case 64: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 104: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; case 128: - run>(salary_value, client); + run, Z2<64>>(salary_value, client); break; default: cerr << R << "-bit ring not implemented"; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 0de8bb1e..5fb07f23 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -2,9 +2,11 @@ #include "FHE_Params.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" +#include "Protocols/HemiOptions.h" +#include "Processor/OnlineOptions.h" -FHE_Params::FHE_Params(int n_mults) : - FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1) +FHE_Params::FHE_Params(int n_mults, int drown_sec) : + FFTData(n_mults + 1), Chi(0.7), sec_p(drown_sec), matrix_dim(1) { } @@ -17,16 +19,20 @@ void FHE_Params::set(const Ring& R, for (size_t i = 0; i < FFTData.size(); i++) FFTData[i].init(R,primes[i]); - set_sec(40); + set_sec(sec_p); } void FHE_Params::set_sec(int sec) { + assert(sec >= 0); sec_p=sec; Bval=1; Bval=Bval<matrix_dim = matrix_dim; } +void FHE_Params::set_matrix_dim_from_options() +{ + set_matrix_dim( + HemiOptions::singleton.plain_matmul ? + 1 : OnlineOptions::singleton.batch_size); +} + bigint FHE_Params::Q() const { bigint res = FFTData[0].get_prime(); diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 9407b0ba..8821e2e2 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -13,6 +13,7 @@ #include "FHE/FFT_Data.h" #include "FHE/DiscreteGauss.h" #include "Tools/random.h" +#include "Protocols/config.h" class FHE_Params { @@ -30,15 +31,17 @@ class FHE_Params public: - FHE_Params(int n_mults = 1); + FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); + void set_min_sec(int sec); void set_matrix_dim(int matrix_dim); + void set_matrix_dim_from_options(); int get_matrix_dim() const { return matrix_dim; } const vector& FFTD() const { return FFTData; } diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp index c9c23aab..dcec137e 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -68,6 +68,10 @@ void HNF(matrix& H,matrix& U,const matrix& A) { int m=A.size(),n=A[0].size(),r,i,j,k; +#ifdef VERBOSE + cerr << "HNF m=" << m << ", n=" << n << endl; +#endif + H=A; ident(U,n); r=min(m,n); @@ -79,9 +83,9 @@ void HNF(matrix& H,matrix& U,const matrix& A) { if (step==2) { // Step 2 k=-1; - mn=bigint(1)<<256; + mn=bigint(0); for (j=i; j 0) @@ -428,6 +429,16 @@ GF2X Subs_PowX_Mod(const GF2X& a,int pow,int m,const GF2X& c) +GF2X get_F(const Ring& Rg) +{ + GF2X F; + for (int i=0; i<=Rg.phi_m(); i++) + { if (((Rg.Phi()[i])%2)!=0) + { SetCoeff(F,i,1); } + } + //cout << "F = " << F << endl; + return F; +} void init(P2Data& P2D,const Ring& Rg) { @@ -438,16 +449,12 @@ void init(P2Data& P2D,const Ring& Rg) { SetCoeff(G,gf2n_short::get_t(i),1); } //cout << "G = " << G << endl; - for (int i=0; i<=Rg.phi_m(); i++) - { if (((Rg.Phi()[i])%2)!=0) - { SetCoeff(F,i,1); } - } - //cout << "F = " << F << endl; + F = get_F(Rg); // seed randomness to achieve same result for all players // randomness is used in SFCanZass and FindRoot SetSeed(ZZ(0)); - + // Now factor F modulo 2 vec_GF2X facts=SFCanZass(F); @@ -459,17 +466,34 @@ void init(P2Data& P2D,const Ring& Rg) // Compute the quotient group QGroup QGrp; int Gord=-1,e=Rg.phi_m()/d; // e = # of plaintext slots, phi(m)/degree - int seed=1; - while (Gord!=e) + + if ((e*gf2n_short::degree())!=Rg.phi_m()) + { cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl; + cout << e << " * " << gf2n_short::degree() << " != " << Rg.phi_m() << endl; + throw invalid_params(); + } + + int max_tries = 10; + for (int seed = 0;; seed++) { QGrp.assign(Rg.m(),seed); // QGrp encodes the the quotient group Z_m^*/<2> - Gord=QGrp.order(); - if (Gord!=e) { cout << "Group order wrong, need to repeat the Haf-Mc algorithm" << endl; seed++; } + Gord = QGrp.order(); + if (Gord == e) + { + break; + } + else + { + if (seed == max_tries) + { + cerr << "abort after " << max_tries << " tries" << endl; + throw invalid_params(); + } + else + cout << "Group order wrong, need to repeat the Haf-Mc algorithm" + << endl; + } } //cout << " l = " << Gord << " , d = " << d << endl; - if ((Gord*gf2n_short::degree())!=Rg.phi_m()) - { cout << "Plaintext type requires Gord*gf2n_short::degree == phi_m" << endl; - throw not_implemented(); - } vector Fi(Gord); vector Rts(Gord); @@ -590,8 +614,23 @@ void char_2_dimension(int& m, int& lg2) m=5797; lg2=40; break; + case 64: + m = 9615; + break; + case 63: + m = 9271; + break; + case 28: + m = 3277; + break; case 16: - m = 13107; + m = 4369; + break; + case 12: + m = 4095; + break; + case 11: + m = 2047; break; default: throw runtime_error("field size not supported"); @@ -628,7 +667,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D) finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[0], round_up, params); } - if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) > phi_N(m)) + if (NoiseBounds::min_phi_m(lg2p0 + lg2p1, params) * 2 > m) throw runtime_error("number of slots too small"); cout << "m = " << m << endl; diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index acaba70b..7d7d13de 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -55,12 +55,19 @@ int generate_semi_setup(int plaintext_length, int sec, FHE_Params& params, FD& FieldD, bool round_up, int n = 1); // field-independent semi-homomorphic setup -int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, +int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up); void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); +namespace NTL +{ +class GF2X; +} + +NTL::GF2X get_F(const Ring& Rg); + // For use when we want p to be a specific value void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, bigint& pr1, int n, int sec, bigint& p, FHE_Params& params); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index f2e151c4..a1fe3e03 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -36,11 +36,12 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); // unify parameters by taking maximum over TopGear or not bigint B_clean_top_gear = B_clean * 2; - bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); + bigint B_clean_not_top_gear = B_clean << max(slack - sec, 0); B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); int matrix_dim = params.get_matrix_dim(); #ifdef NOISY + cout << "phi(m): " << phi_m << endl; cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "V_s: " << V_s << endl; cout << "c1: " << c1 << endl; @@ -50,10 +51,13 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, cout << "B_clean: " << B_clean << endl; cout << "B_scale: " << B_scale << endl; cout << "matrix dimension: " << matrix_dim << endl; + cout << "drown sec: " << params.secp() << endl; + cout << "sec: " << sec << endl; #endif assert(matrix_dim > 0); - drown = 1 + matrix_dim * n * (bigint(1) << sec); + assert(params.secp() >= 0); + drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp()); } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) @@ -71,8 +75,14 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma) { if (sigma <= 0) sigma = FHE_Params().get_R(); - // the constant was updated using Martin Albrecht's LWE estimator in Sep 2019 - return 37.8 * (log_q - log2(sigma)); + // the constant was updated using Martin Albrecht's LWE estimator in Mar 2022 + // found the following pairs for 128-bit security + // and alpha = 0.7 * sqrt(2*pi) / q + // m = 2048, log_2(q) = 68 + // m = 4096, log_2(q) = 138 + // m = 8192, log_2(q) = 302 + // m = 16384, log_2(q) = 560 + return 15.1 * log_q; } double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params) diff --git a/FHE/Subroutines.cpp b/FHE/Subroutines.cpp index 8ec6a7ce..a688b169 100644 --- a/FHE/Subroutines.cpp +++ b/FHE/Subroutines.cpp @@ -11,35 +11,15 @@ void Subs(modp& ans,const vector& poly,const modp& x,const Zp_Data& ZpD) assignZero(ans,ZpD); for (int i=poly.size()-1; i>=0; i--) { Mul(ans,ans,x,ZpD); - switch (poly[i]) - { case 0: - break; - case 1: - Add(ans,ans,one,ZpD); - break; - case -1: - Sub(ans,ans,one,ZpD); - break; - case 2: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -2: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - case 3: - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - Add(ans,ans,one,ZpD); - break; - case -3: - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - Sub(ans,ans,one,ZpD); - break; - default: - throw not_implemented(); + if (poly[i] > 0) + { + for (int j = 0; j < poly[i]; j++) + Add(ans, ans, one, ZpD); + } + if (poly[i] < 0) + { + for (int j = 0; j < -poly[i]; j++) + Sub(ans, ans, one, ZpD); } } } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 48a8a6ef..0dc8d967 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -40,10 +40,9 @@ template void PartSetup::generate_setup(int n_parties, int plaintext_length, int sec, int slack, bool round_up) { - sec = max(sec, 40); + params.set_min_sec(sec); Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup( params, FieldD); - params.set_sec(sec); pk = FHE_PK(params, FieldD.get_prime()); sk = FHE_SK(params, FieldD.get_prime()); calpha = Ciphertext(params); @@ -180,11 +179,8 @@ void PartSetup::init_field() } template -void PartSetup::check(int sec) const +void PartSetup::check() const { - sec = max(sec, 40); - if (abs(sec - params.secp()) > 2) - throw runtime_error("security parameters vary too much between protocol and distributed decryption"); sk.check(params, pk, FieldD.get_prime()); } diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h index 88c9e05f..1c606506 100644 --- a/FHEOffline/DataSetup.h +++ b/FHEOffline/DataSetup.h @@ -57,7 +57,7 @@ public: void init_field(); - void check(int sec) const; + void check() const; bool operator!=(const PartSetup& other); void secure_init(Player& P, MachineBase& machine, int plaintext_length, diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 92632002..43ad7e84 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -69,7 +69,7 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, product_share.randomize(G); bigint B = 6 * machine.setup().params.get_R(); B *= machine.setup().FieldD.get_prime(); - B <<= machine.drown_sec; + B <<= machine.setup().params.secp(); // slack B *= NonInteractiveProof::slack(machine.sec, machine.setup().params.phi_m()); diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index e41fe183..b19dd62c 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -29,7 +29,7 @@ void PairwiseMachine::init() { if (use_gf2n) { - field_size = 40; + field_size = gf2n_short::DEFAULT_LENGTH; gf2n_short::init_field(field_size); setup_keys(); } @@ -67,7 +67,7 @@ void PairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); - s.init(P, drown_sec, field_size, extra_slack); + s.init(P, sec, field_size, extra_slack); if (output) write_mac_key(get_prep_dir(P), P.my_num(), P.num_players(), s.alphai); for (auto& x : other_pks) diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 047c84f2..59223ad0 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -65,11 +65,12 @@ template void secure_init(T& setup, Player& P, U& machine, int plaintext_length, int sec, FHE_Params& params) { + assert(sec >= 0); machine.sec = sec; - sec = max(sec, 40); - machine.drown_sec = sec; + params.set_min_sec(sec); string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + + to_string(params.secp()) + "-" + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index 5e690b67..6059ef3b 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -78,6 +78,7 @@ class Proof diagonal(diagonal), B_plain_length(0), B_rand_length(0), pk(&pk), n_proofs(n_proofs) { sec=sc; + assert(sec > 0); tau=Tau; rho=Rho; phim=(pk.get_params()).phi_m(); diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index dab3e507..63d559d5 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -10,6 +10,8 @@ #include "Tools/Subroutines.h" +#include "Protocols/mac_key.hpp" + // The number of sacrifices to amortize at one time #define amortize 512 @@ -19,12 +21,7 @@ void Triple_Checking(const Player& P, MAC_Check& MC, int nm, int output_thread, TripleSacriFactory< Share >& factory, bool write_output, bool clear, string dir) { - if (T::length() < 40) - { - cerr << "Field too small for reasonable security" << endl; - cerr << "Use a larger field or remove this warning from " << __FILE__ << endl; - exit(1); - } + check_field_size(); ofstream outf; if (write_output) diff --git a/FHEOffline/SimpleMachine.cpp b/FHEOffline/SimpleMachine.cpp index 04d190f1..a7be00a4 100644 --- a/FHEOffline/SimpleMachine.cpp +++ b/FHEOffline/SimpleMachine.cpp @@ -27,7 +27,7 @@ void* run_generator(void* generator) MachineBase::MachineBase() : throughput_loop_thread(0),portnum_base(0), data_type(DATA_TRIPLE), - sec(0), drown_sec(0), field_size(0), extra_slack(0), + sec(0), field_size(0), extra_slack(0), produce_inputs(false), use_gf2n(false) { @@ -91,7 +91,6 @@ void MachineBase::parse_options(int argc, const char** argv) opt.get("-h")->getString(hostname); opt.get("-pn")->getInt(portnum_base); opt.get("-s")->getInt(sec); - drown_sec = max(40, sec); opt.get("-f")->getInt(field_size); use_gf2n = opt.isSet("-2"); if (use_gf2n) @@ -221,7 +220,7 @@ void MultiplicativeMachine::fake_keys(int slack) PartSetup& part_setup = setup.part(); if (P.my_num() == 0) { - part_setup.generate_setup(N.num_players(), field_size, drown_sec, slack, true); + part_setup.generate_setup(N.num_players(), field_size, sec, slack, true); vector > setups; part_setup.fake(setups, P.num_players(), false); for (int i = 1; i < P.num_players(); i++) @@ -238,7 +237,7 @@ void MultiplicativeMachine::fake_keys(int slack) P.receive_player(0, os); } part_setup.unpack(os); - part_setup.check(drown_sec); + part_setup.check(); part_setup.alphai = read_or_generate_mac_key>(P); Plaintext_ m(part_setup.FieldD); diff --git a/FHEOffline/SimpleMachine.h b/FHEOffline/SimpleMachine.h index 8ca37bbf..e8d07170 100644 --- a/FHEOffline/SimpleMachine.h +++ b/FHEOffline/SimpleMachine.h @@ -26,7 +26,6 @@ protected: public: int sec; - int drown_sec; int field_size; int extra_slack; bool produce_inputs; diff --git a/FHEOffline/TemiSetup.cpp b/FHEOffline/TemiSetup.cpp index fc222ed5..fd922d4c 100644 --- a/FHEOffline/TemiSetup.cpp +++ b/FHEOffline/TemiSetup.cpp @@ -15,9 +15,7 @@ TemiSetup::TemiSetup() this->pk = {this->params, 0}; this->sk = {this->params, 0}; this->calpha = this->params; - this->params.set_matrix_dim( - HemiOptions::singleton.plain_matmul ? - 1 : OnlineOptions::singleton.batch_size); + this->params.set_matrix_dim_from_options(); } template diff --git a/GC/CcdShare.h b/GC/CcdShare.h index e890ce63..894d3ae7 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -40,7 +40,7 @@ public: return "CCD"; } - static MAC_Check* new_mc(T) + static MAC_Check* new_mc(typename super::mac_key_type) { return new MAC_Check; } diff --git a/GC/DealerPrep.h b/GC/DealerPrep.h new file mode 100644 index 00000000..a3bd4bcc --- /dev/null +++ b/GC/DealerPrep.h @@ -0,0 +1,69 @@ +/* + * DealerPrep.h + * + */ + +#ifndef GC_DEALERPREP_H_ +#define GC_DEALERPREP_H_ + +#include "Protocols/DealerPrep.h" +#include "Protocols/ProtocolSet.h" +#include "ShiftableTripleBuffer.h" +#include "SemiSecret.h" + +namespace GC +{ +class DealerPrep : public BufferPrep, ShiftableTripleBuffer +{ + Player* P; + +public: + DealerPrep(DataPositions& usage, int = -1) : + BufferPrep(usage), P(0) + { + } + + void set_protocol(DealerSecret::Protocol& protocol) + { + P = &protocol.P; + } + + void buffer_triples() + { + ProtocolSetup> setup(*P); + ProtocolSet> set(*P, setup); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + auto triple = set.preprocessing.get_triple( + DealerSecret::default_length); + this->triples.push_back({{triple[0], triple[1], triple[2]}}); + } + } + + void buffer_bits() + { + SeededPRNG G; + if (P->my_num() != 0) + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->bits.push_back(G.get_bit()); + else + this->bits.resize( + this->bits.size() + OnlineOptions::singleton.batch_size); + } + + void get(Dtype type, DealerSecret* data) + { + BufferPrep::get(type, data); + } + + array get_triple_no_count(int n_bits) + { + if (n_bits == -1) + n_bits = DealerSecret::default_length; + return ShiftableTripleBuffer::get_triple_no_count(n_bits); + } +}; + +} + +#endif /* GC_DEALERPREP_H_ */ diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 00e6c52c..ee7a8446 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -10,6 +10,7 @@ #include "GC/Memory.h" #include "GC/Access.h" #include "GC/ArgTuples.h" +#include "GC/NoShare.h" #include "Math/gf2nlong.h" #include "Tools/SwitchableOutput.h" @@ -40,7 +41,6 @@ public: typedef FakeSecret DynamicType; typedef Memory DynamicMemory; - typedef BitVec mac_key_type; typedef BitVec clear; typedef BitVec open_type; diff --git a/GC/MaliciousCcdShare.h b/GC/MaliciousCcdShare.h index 9dc63fc6..fbc66ea1 100644 --- a/GC/MaliciousCcdShare.h +++ b/GC/MaliciousCcdShare.h @@ -44,7 +44,7 @@ public: return "Malicious CCD"; } - static MAC_Check* new_mc(T) + static MAC_Check* new_mc(typename super::mac_key_type) { return new MAC_Check; } diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index 500bbb5a..9f941d51 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -30,7 +30,7 @@ public: typedef MaliciousRepMC MC; typedef BitVec_ open_type; typedef open_type clear; - typedef BitVec mac_key_type; + typedef NoValue mac_key_type; static MC* new_mc(mac_key_type) { @@ -71,7 +71,7 @@ public: static const bool expensive_triples = true; - static MC* new_mc(BitVec) + static MC* new_mc(typename super::mac_key_type) { try { diff --git a/GC/NoShare.h b/GC/NoShare.h index c435ec3f..49f93ac4 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -60,35 +60,43 @@ public: throw not_implemented(); } + static void init_minimum(int) + { + } + static void fail() { throw runtime_error("VM does not support binary circuits"); } NoValue() {} - NoValue(int) { fail(); } + NoValue(bool) {} + NoValue(ValueInterface) {} + NoValue(int128) {} void assign(const char*) { fail(); } + const char* get_ptr() const { return (char*) this; } + int get() const { fail(); return 0; } int operator<<(int) const { fail(); return 0; } void operator+=(int) { fail(); } - bool operator!=(NoValue) const { fail(); return 0; } + bool operator!=(NoValue) const { return false; } bool operator==(int) { fail(); return false; } bool get_bit(int) { fail(); return 0; } - void randomize(PRNG&) { fail(); } + void randomize(PRNG&) {} void invert() { fail(); } void mask(int) { fail(); } void input(istream&, bool) { fail(); } - void output(ostream&, bool) { fail(); } + void output(ostream&, bool) {} }; inline ostream& operator<<(ostream& o, NoValue) diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 663d55fc..96b2d62d 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -340,7 +340,7 @@ void Processor::convcbit2s(const BaseInstruction& instruction) for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) S[instruction.get_r(0) + i] = T::constant(C[instruction.get_r(1) + i], share_thread.P->my_num(), share_thread.MC->get_alphai(), - min(unsigned(unit), instruction.get_n() - i * unit)); + min(size_t(unit), instruction.get_n() - i * unit)); } template diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index ae10b522..e95554bf 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -8,6 +8,7 @@ #include "Protocols/SemiMC.h" #include "Protocols/SemiShare.h" +#include "Protocols/DealerShare.h" #include "Processor/DummyProtocol.h" #include "ShareSecret.h" @@ -17,71 +18,116 @@ namespace GC { class SemiPrep; +class DealerPrep; -class SemiSecret : public SemiShare, public ShareSecret +template +class SemiSecretBase : public V, public ShareSecret { + typedef V super; + public: - typedef Memory DynamicMemory; + typedef Memory DynamicMemory; - typedef SemiMC MC; - typedef DirectSemiMC Direct_MC; - typedef Beaver Protocol; - typedef MC MAC_Check; - typedef SemiPrep LivePrep; - typedef SemiInput Input; + typedef Beaver Protocol; - typedef SemiSecret part_type; - typedef SemiSecret small_type; + typedef T part_type; + typedef T small_type; static const int default_length = sizeof(BitVec) * 8; static string type_string() { return "binary secret"; } static string phase_name() { return "Binary computation"; } - static MC* new_mc(mac_key_type); - - template - static void generate_mac_key(mac_key_type, T) - { - } - - static void trans(Processor& processor, int n_outputs, + static void trans(Processor& processor, int n_outputs, const vector& args); - SemiSecret() + SemiSecretBase() { } - SemiSecret(long other) : - SemiShare(other) + SemiSecretBase(long other) : + V(other) { } - SemiSecret(const IntBase& other) : - SemiShare(other) + template + SemiSecretBase(const IntBase& other) : + V(other) { } template - SemiSecret(const Z2& other) : - SemiShare(other) + SemiSecretBase(const Z2& other) : + V(other) { } void load_clear(int n, const Integer& x); - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(Memory& S, const vector& regs); + void bitdec(Memory& S, const vector& regs) const; - void xor_(int n, const SemiSecret& x, const SemiSecret& y) + void xor_(int n, const T& x, const T& y) { *this = BitVec(x ^ y).mask(n); } - void xor_bit(int i, const SemiSecret& bit) + void xor_bit(int i, const T& bit) { *this ^= bit << i; } void reveal(size_t n_bits, Clear& x); - SemiSecret lsb() + T lsb() { return *this & 1; } }; +class SemiSecret: public SemiSecretBase> +{ + typedef SemiSecret This; + +public: + typedef SemiSecretBase> super; + + typedef SemiMC MC; + typedef DirectSemiMC Direct_MC; + typedef MC MAC_Check; + typedef SemiInput Input; + typedef SemiPrep LivePrep; + + static MC* new_mc(typename SemiShare::mac_key_type); + + SemiSecret() + { + } + + template + SemiSecret(const T& other) : + super(other) + { + } +}; + +class DealerSecret : public SemiSecretBase> +{ + typedef DealerSecret This; + +public: + typedef SemiSecretBase> super; + + typedef DealerMC MC; + typedef DirectDealerMC Direct_MC; + typedef MC MAC_Check; + typedef DealerInput Input; + typedef DealerPrep LivePrep; + + static MC* new_mc(typename super::mac_key_type); + + DealerSecret() + { + } + + template + DealerSecret(const T& other) : + super(other) + { + } +}; + } /* namespace GC */ #endif /* GC_SEMISECRET_H_ */ diff --git a/GC/SemiSecret.cpp b/GC/SemiSecret.hpp similarity index 61% rename from GC/SemiSecret.cpp rename to GC/SemiSecret.hpp index 704e2a2f..f6a4d398 100644 --- a/GC/SemiSecret.cpp +++ b/GC/SemiSecret.hpp @@ -4,17 +4,20 @@ */ #include "GC/ShareParty.h" -#include "SemiSecret.h" - #include "GC/ShareSecret.hpp" #include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/DealerMC.h" +#include "SemiSecret.h" namespace GC { -const int SemiSecret::default_length; +template +const int SemiSecretBase::default_length; -SemiSecret::MC* SemiSecret::new_mc(mac_key_type) +inline +SemiSecret::MC* SemiSecret::new_mc( + typename super::mac_key_type) { if (OnlineOptions::singleton.direct) return new Direct_MC; @@ -22,7 +25,18 @@ SemiSecret::MC* SemiSecret::new_mc(mac_key_type) return new MC; } -void SemiSecret::trans(Processor& processor, int n_outputs, +inline +DealerSecret::MC* DealerSecret::new_mc( + typename super::mac_key_type) +{ + if (OnlineOptions::singleton.direct) + return new Direct_MC; + else + return new MC; +} + +template +void SemiSecretBase::trans(Processor& processor, int n_outputs, const vector& args) { int N_BITS = default_length; @@ -46,29 +60,33 @@ void SemiSecret::trans(Processor& processor, int n_outputs, } } -void SemiSecret::load_clear(int n, const Integer& x) +template +void SemiSecretBase::load_clear(int n, const Integer& x) { - check_length(n, x); - *this = constant(x, ShareThread::s().P->my_num()); + this->check_length(n, x); + *this = this->constant(x, ShareThread::s().P->my_num()); } -void SemiSecret::bitcom(Memory& S, const vector& regs) +template +void SemiSecretBase::bitcom(Memory& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) *this ^= (S[regs[i]] << i); } -void SemiSecret::bitdec(Memory& S, +template +void SemiSecretBase::bitdec(Memory& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; } -void SemiSecret::reveal(size_t n_bits, Clear& x) +template +void SemiSecretBase::reveal(size_t n_bits, Clear& x) { - auto& thread = ShareThread::s(); + auto& thread = ShareThread::s(); x = thread.MC->POpen(*this, *thread.P).mask(n_bits); } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 6d9f2652..fb254486 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -112,8 +112,8 @@ public: typedef BitVec clear; typedef BitVec open_type; - typedef BitVec mac_type; - typedef BitVec mac_key_type; + typedef NoShare mac_type; + typedef NoValue mac_key_type; typedef NoShare bit_type; @@ -213,7 +213,7 @@ public: typedef ReplicatedMC MC; typedef BitVec_ open_type; typedef open_type clear; - typedef BitVec mac_key_type; + typedef NoValue mac_key_type; static MC* new_mc(mac_key_type) { diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index c6c9dcaa..a426eea2 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -11,6 +11,8 @@ #include "instructions.h" +#include "Tools/benchmarking.h" + #include "Machine.hpp" namespace GC @@ -58,15 +60,10 @@ Thread* ThreadMaster::new_thread(int i) template void ThreadMaster::run() { -#ifndef INSECURE if (not opts.live_prep) { - cerr - << "Preprocessing from file not supported by binary virtual machines" - << endl; - exit(1); + insecure("preprocessing from file in binary virtual machines"); } -#endif P = new PlainPlayer(N, "main"); diff --git a/GC/TinyMC.h b/GC/TinyMC.h index e0a0b948..ac3a29ab 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -48,12 +48,12 @@ public: part_MC.exchange(P); } - typename T::open_type finalize_open() + typename T::open_type finalize_raw() { int n = sizes.next(); typename T::open_type opened = 0; for (int i = 0; i < n; i++) - opened += typename T::open_type(part_MC.finalize_open().get_bit(0)) << i; + opened += typename T::open_type(part_MC.finalize_raw().get_bit(0)) << i; return opened; } diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 6cb02779..04508c09 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -5,6 +5,7 @@ #include "Protocols/Spdz2kShare.h" #include "Protocols/Spdz2kPrep.h" +#include "Protocols/SPDZ2k.h" #include "GC/TinySecret.h" #include "GC/TinyMC.h" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 1a093146..02686df7 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -19,3 +19,4 @@ #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "GC/SemiSecret.hpp" diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 9f18d3a6..d8cc3014 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -97,6 +97,5 @@ ShamirMachineSpec::ShamirMachineSpec(int argc, const char** argv) auto& opts = ShamirOptions::singleton; ez::ezOptionParser opt; opts = {opt, argc, argv}; - T::bit_type::part_type::open_type::init_field(); HonestMajorityFieldMachine(argc, argv, opt, opts.nparties); } diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index 82cde5e8..e4a876fe 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -142,8 +142,8 @@ TripleMachine::TripleMachine(int argc, const char** argv) : gfpvar1::init_field(prime, false); else gfpvar1::init_default(128, false); - gf2n_long::init_field(128); - gf2n_short::init_field(40); + gf2n_long::init_field(); + gf2n_short::init_field(); PRNG G; G.ReSeed(); diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp new file mode 100644 index 00000000..4bc8fab1 --- /dev/null +++ b/Machines/dealer-ring-party.cpp @@ -0,0 +1,22 @@ +/* + * dealer-ring-party.cpp + * + */ + +#include "Protocols/DealerShare.h" +#include "Protocols/DealerInput.h" + +#include "Processor/RingMachine.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/DealerPrep.hpp" +#include "Protocols/DealerInput.hpp" +#include "Protocols/DealerMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Semi.hpp" +#include "GC/DealerPrep.h" + +int main(int argc, const char** argv) +{ + HonestMajorityRingMachine(argc, argv, 0); +} diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index f26f5f32..5999050c 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -54,9 +54,8 @@ int main(int argc, const char** argv) { #define X(L) \ case L: \ - Machine>, FakeShare>(0, N, progname, \ - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \ - online_opts.live_prep, online_opts).run(); \ + Machine>, FakeShare>(N, false, online_opts, \ + gf2n::default_degree()).run(progname); \ break; X(64) X(128) X(256) X(192) X(384) X(512) #ifdef RING_SIZE diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 934c15dc..60e0d6e4 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -27,6 +27,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/malicious-ccd-party.cpp b/Machines/malicious-ccd-party.cpp index 4ce84aea..55ec8b99 100644 --- a/Machines/malicious-ccd-party.cpp +++ b/Machines/malicious-ccd-party.cpp @@ -18,8 +18,9 @@ int main(int argc, const char** argv) { - gf2n_short::init_field(40); ez::ezOptionParser opt; ShamirOptions::singleton = {opt, argc, argv}; + OnlineOptions opts(opt, argc, argv); + gf2n_short::init_minimum(opts.security_parameter); GC::ShareParty>(argc, argv, opt); } diff --git a/Machines/semi-bin-party.cpp b/Machines/semi-bin-party.cpp index fbd0a634..6c99ebf9 100644 --- a/Machines/semi-bin-party.cpp +++ b/Machines/semi-bin-party.cpp @@ -14,6 +14,7 @@ #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" #include "GC/Processor.hpp" +#include "GC/SemiSecret.hpp" #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/SemiInput.hpp" diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index 7ecc450d..ced64919 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -25,6 +25,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index 8aba3173..9e826737 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -22,12 +22,12 @@ int main(int argc, const char** argv) 1, // Number of args expected. 0, // Delimiter if expecting multiple args. "SPDZ2k security parameter (default: 64)", // Help description. - "-S", // Flag token. - "--security" // Flag token. + "-SP", // Flag token. + "--spdz2k-security" // Flag token. ); opt.parse(argc, argv); int s; - opt.get("-S")->getInt(s); + opt.get("-SP")->getInt(s); opt.resetArgs(); RingOptions ring_options(opt, argc, argv); int k = ring_options.R; @@ -62,6 +62,8 @@ int main(int argc, const char** argv) cerr << "add Z(" << k << ", " << s << ") to " << __FILE__ << " at line " << (__LINE__ - 11) << " and create Machines/SPDZ2^" << k << "+" << s << ".cpp based on Machines/SPDZ2^72+64.cpp" << endl; + cerr << "Alternatively, compile with -DRING_SIZE=" << k + << " and -DSPDZ2K_DEFAULT_SECURITY=" << s << endl; } exit(1); } diff --git a/Machines/temi-party.cpp b/Machines/temi-party.cpp index 12e99dc2..f8abd35d 100644 --- a/Machines/temi-party.cpp +++ b/Machines/temi-party.cpp @@ -26,6 +26,7 @@ #include "Protocols/Hemi.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" +#include "GC/SemiSecret.hpp" #include "Math/gfp.hpp" int main(int argc, const char** argv) diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 82122bd1..67234a8a 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -28,6 +28,8 @@ int main(int argc, const char** argv) { - gf2n_short::init_field(40); + ez::ezOptionParser opt; + OnlineOptions opts(opt, argc, argv); + gf2n_short::init_minimum(opts.security_parameter); GC::simple_binary_main>(argc, argv, 1000); } diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp index 7e72361e..f83f839f 100644 --- a/Machines/tiny-party.cpp +++ b/Machines/tiny-party.cpp @@ -29,5 +29,5 @@ int main(int argc, const char** argv) { - GC::simple_binary_main>(argc, argv, 1000); + GC::simple_binary_main>(argc, argv, 1000); } diff --git a/Makefile b/Makefile index 4f558e1d..3c2be009 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR) -GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +GC_SEMI = GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) OT_EXE = ot.x ot-offline.x @@ -57,7 +57,7 @@ vm: arithmetic binary doc: cd doc; $(MAKE) html -arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy +arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy dealer-ring-party.x binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr all: overdrive she-offline @@ -162,7 +162,7 @@ bmr-%.x: $(BMR) $(VM) Machines/bmr-%.cpp $(LIBSIMPLEOT) bmr-clean: -rm BMR/*.o BMR/*/*.o GC/*.o -bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) +bankers-bonus-client.x: ExternalIO/bankers-bonus-client.o $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) simple-offline.x: $(FHEOFFLINE) @@ -203,13 +203,13 @@ replicated-field-party.x: GC/square64.o brain-party.x: GC/square64.o malicious-rep-bin-party.x: GC/square64.o ps-rep-bin-party.x: GC/PostSacriBin.o -semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +semi-bin-party.x: $(OT) GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) tinier-party.x: $(OT) spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) -semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o -semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +semi-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi2k-party.x: $(OT) GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) @@ -234,15 +234,16 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o rep4-ring-party.x: GC/Rep4Secret.o no-party.x: Protocols/ShareInterface.o -semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o +semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) emulate.x: GC/FakeSecret.o -semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT) +semi-bmr-party.x: GC/SemiPrep.o $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o -mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o +l2h-example.x: $(VM) $(OT) Machines/Tinier.o mascot-offline.x: $(VM) $(TINIER) cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) @@ -253,6 +254,7 @@ static/semi-bmr-party.x: $(BMR) static/real-bmr-party.x: $(BMR) static/bmr-program-party.x: $(BMR) static/no-party.x: Protocols/ShareInterface.o +Test/failure.x: Protocols/MalRepRingOptions.o ifeq ($(AVX_OT), 1) $(LIBSIMPLEOT): SimpleOT/Makefile @@ -270,7 +272,7 @@ Programs/Circuits: .PHONY: mpir-setup mpir-global mpir mpir-setup: - git submodule update --init mpir + git submodule update --init mpir || git clone https://github.com/wbhart/mpir cd mpir; \ autoreconf -i; \ autoreconf -i @@ -306,7 +308,7 @@ linux-machine-setup: endif simde/simde: - git submodule update --init simde + git submodule update --init simde || git clone https://github.com/simd-everywhere/simde clean: -rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so diff --git a/Math/Integer.h b/Math/Integer.h index 8104724c..1fbb257f 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -37,8 +37,6 @@ public: static void specification(octetStream& os); - static void init_default(int lgp) { (void)lgp; } - static bool allows(Dtype type) { return type <= DATA_BIT; } IntBase() { a = 0; } diff --git a/Math/Setup.cpp b/Math/Setup.cpp index b4800017..dc76e47d 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -27,6 +27,13 @@ void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m) cerr << "Setting up parameters" << endl; #endif + m = default_m(lgp, idx); + generate_prime(p, lgp, m); +} + +int default_m(int& lgp, int& idx) +{ + int m; switch (lgp) { case -1: m=16; @@ -56,15 +63,12 @@ void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m) default: m=1; idx=0; -#ifdef VERBOSE - cerr << "no precomputed parameters, trying anyway" << endl; -#endif break; } #ifdef VERBOSE cerr << "m = " << m << endl; #endif - generate_prime(p, lgp, m); + return m; } bigint generate_prime(int lgp, int m) @@ -95,6 +99,9 @@ void generate_prime(bigint& p, int lgp, int m) return; } + int idx; + m = max(m, default_m(lgp, idx)); + bigint u; int ex; ex = lgp - numBits(m); diff --git a/Math/Setup.h b/Math/Setup.h index f8405ba3..8c599198 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -35,6 +35,7 @@ bigint SPDZ_Data_Setup_Primes(int lgp); void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m); void generate_prime(bigint& p, int lgp, int m); bigint generate_prime(int lgp, int m); +int default_m(int& lgp, int& idx); string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, const string& type_short); diff --git a/Math/bigint.h b/Math/bigint.h index 5cd31981..cb79f242 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -12,6 +12,7 @@ using namespace std; #include "Tools/random.h" #include "Tools/octetStream.h" #include "Tools/avx_memcpy.h" +#include "Protocols/config.h" enum ReportType { @@ -270,7 +271,8 @@ inline int probPrime(const bigint& x) { gmp_randstate_t rand_state; gmp_randinit_default(rand_state); - int ans=mpz_probable_prime_p(x.get_mpz_t(),rand_state,40,0); + int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state, + max(40, DEFAULT_SECURITY), 0); gmp_randclear(rand_state); return ans; } diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index f9491fb7..44e42479 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -18,7 +18,7 @@ bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 7 +#define num_2_fields 17 /* Require * 2*(n-1)-64+t1<64 @@ -26,11 +26,21 @@ word gf2n_short_table[256][256]; int fields_2[num_2_fields][4] = { { 4, 1, 0, 0 }, + { 5, 2, 0, 0 }, + { 6, 1, 0, 0 }, + { 7, 1, 0, 0 }, { 8, 4, 3, 1 }, + { 9, 1, 0, 0 }, + { 10, 3, 0, 0}, + { 11, 2, 0, 0}, + { 12, 3, 0, 0}, + { 14, 5, 0, 0}, + { 15, 1, 0, 0}, { 16, 5, 3, 1 }, { 28, 1, 0, 0 }, { 40, 20, 15, 10 }, { 63, 1, 0, 0 }, + { 64, 4, 3, 1}, { 128, 7, 2, 1 }, }; @@ -55,6 +65,21 @@ void gf2n_::init_tables() } } +template +void gf2n_::init_minimum(int lower) +{ + if (lower <= n) + return; + + for (int i = 0; i < num_2_fields; i++) + { + int n = fields_2[i][0]; + if (lower <= n and n <= MAX_N_BITS) + return init_field(n); + } + throw runtime_error("no suitable field for minimum degree " + to_string(lower)); +} + void gf2n_short::init_field(int nn) { super::init_field(nn == 0 ? DEFAULT_LENGTH : nn); @@ -88,7 +113,7 @@ void gf2n_::init_field(int nn) if (j==-1) { - throw runtime_error("field size not supported"); + throw gf2n_not_supported(nn); } n=nn; @@ -332,7 +357,11 @@ gf2n_ gf2n_::invert() const if (n < 64) return U(invert(a)); else - return invert>(a).get_lower(); + { + gf2n_ res; + res.a = invert(a).get_lower(); + return res; + } } template<> diff --git a/Math/gf2n.h b/Math/gf2n.h index add8627c..485d8430 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -65,6 +65,7 @@ protected: static void init_field(int nn = 0); static void init_default(int, bool = false) { init_field(); } + static void init_minimum(int lower); static void reset() { n = 0; } static int degree() { return n; } @@ -213,7 +214,7 @@ public: static const int DEFAULT_LENGTH = 40; static int length() { return n == 0 ? DEFAULT_LENGTH : n; } - static int default_degree() { return 40; } + static int default_degree() { return DEFAULT_LENGTH; } static void init_field(int nn = 0); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h new file mode 100644 index 00000000..22482c48 --- /dev/null +++ b/Networking/AllButLastPlayer.h @@ -0,0 +1,67 @@ +/* + * AllButZeroPlayer.h + * + */ + +#ifndef NETWORKING_ALLBUTLASTPLAYER_H_ +#define NETWORKING_ALLBUTLASTPLAYER_H_ + +#include "Player.h" + +class AllButLastPlayer : public Player +{ + const Player& P; + Names* N; + +public: + AllButLastPlayer(const Player& P) : + Player(*(N = new Names(P.my_num(), P.num_players() - 1))), P(P) + { + } + + ~AllButLastPlayer() + { + delete N; + } + + void send_to_no_stats(int player, const octetStream& o) const + { + P.send_to(player, o); + } + + void receive_player_no_stats(int i, octetStream& o) const + { + P.receive_player(i, o); + } + + void send_receive_all_no_stats(const vector>& channels, + const vector& to_send, + vector& to_receive) const + { + auto my_channels = channels; + my_channels.resize(P.num_players()); + for (auto& x : my_channels) + x.resize(P.num_players()); + auto my_to_send = to_send; + if (P.my_num() != P.num_players() - 1) + P.send_receive_all(my_channels, my_to_send, to_receive); + to_receive.resize(P.num_players() - 1); + } + + void Broadcast_Receive_no_stats(vector& os) const + { + vector to_send(P.num_players(), os[P.my_num()]); + vector> channels(P.num_players(), + vector(P.num_players(), true)); + for (auto& x: channels) + x.back() = false; + channels.back() = vector(P.num_players(), false); + vector to_receive; + P.send_receive_all(channels, to_send, to_receive); + for (int i = 0; i < P.num_players() - 1; i++) + if (i != P.my_num()) + os[i] = to_receive[i]; + } +}; + +#endif diff --git a/Networking/Player.cpp b/Networking/Player.cpp index b4bab177..a7935f30 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -146,10 +146,18 @@ void Names::setup_names(const char *servername, int my_port) #endif // Now get the set of names - octetStream os; - os.Receive(socket_num); - os.get(names); - os.get(ports); + try + { + octetStream os; + os.Receive(socket_num); + os.get(names); + os.get(ports); + } + catch (exception& e) + { + throw runtime_error(string("error in network setup: ") + e.what()); + } + if (names.size() != ports.size()) throw runtime_error("invalid network setup"); nplayers = names.size(); @@ -186,6 +194,11 @@ Names::Names(const Names& other) server = 0; } +Names::Names(int my_num, int num_players) : + nplayers(num_players), portnum_base(-1), player_no(my_num), server(0) +{ +} + Names::~Names() { if (server != 0) @@ -817,6 +830,17 @@ void NamedCommStats::print(bool newline) cerr << endl; } +void NamedCommStats::reset() +{ + clear(); + sent = 0; +} + +void PlayerBase::reset_stats() +{ + comm_stats.reset(); +} + NamedCommStats Player::total_comm() const { auto res = comm_stats; diff --git a/Networking/Player.h b/Networking/Player.h index ff4bdcd1..a547d479 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -116,7 +116,7 @@ class Names Names(ez::ezOptionParser& opt, int argc, const char** argv, int default_nplayers = 2); - Names() : nplayers(1), portnum_base(-1), player_no(0), server(0) { ; } + Names(int my_num = 0, int num_players = 1); Names(const Names& other); ~Names(); @@ -159,6 +159,7 @@ public: NamedCommStats operator-(const NamedCommStats& other) const; size_t total_data(); void print(bool newline = false); + void reset(); #ifdef VERBOSE_COMM CommStats& operator[](const string& name) { @@ -190,10 +191,19 @@ public: virtual int my_num() const = 0; virtual int num_players() const = 0; - virtual void pass_around(octetStream& o, int offset = 1) const = 0; - virtual void Broadcast_Receive(vector& o) const = 0; + virtual void receive_player(int, octetStream&) const + { throw not_implemented(); } + virtual void pass_around(octetStream&, int = 1) const + { throw not_implemented(); } + virtual void Broadcast_Receive(vector&) const + { throw not_implemented(); } virtual void unchecked_broadcast(vector& o) const { Broadcast_Receive(o); } + virtual void send_receive_all(const vector&, + vector&) const + { throw not_implemented(); } + + void reset_stats(); }; /** @@ -230,8 +240,8 @@ public: virtual bool is_encrypted() { return false; } - virtual void send_long(int i, long a) const = 0; - virtual long receive_long(int i) const = 0; + virtual void send_long(int, long) const { throw not_implemented(); } + virtual long receive_long(int) const { throw not_implemented(); } // The following functions generally update the statistics // and then call the *_no_stats equivalent specified by a subclass. @@ -283,7 +293,8 @@ public: * reusing the buffer if possible. */ void exchange(int other, const octetStream& to_send, octetStream& ot_receive) const; - virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0; + virtual void exchange_no_stats(int, const octetStream&, octetStream&) const + { throw runtime_error("implement exchange"); } /** * Exchange information with one other party, reusing the buffer. */ @@ -304,8 +315,8 @@ public: * The default is to send to the next party while receiving from the previous. */ void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const; - virtual void pass_around_no_stats(const octetStream& to_send, - octetStream& to_receive, int offset) const = 0; + virtual void pass_around_no_stats(const octetStream&, octetStream&, + int) const { throw runtime_error("implement passing around"); } /** * Broadcast and receive data to/from all players. @@ -317,7 +328,8 @@ public: * Assumes o[player_no] contains the data to be broadcast by me. */ virtual void Broadcast_Receive(vector& o) const; - virtual void Broadcast_Receive_no_stats(vector& o) const = 0; + virtual void Broadcast_Receive_no_stats(vector&) const + { throw runtime_error("implement broadcast"); } /** * Run protocol to verify broadcast is correct diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 732850bc..5fdbb3d6 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -15,6 +15,7 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiInput.hpp" #include "Protocols/SemiMC.hpp" +#include "Protocols/mac_key.hpp" #include #include @@ -274,9 +275,9 @@ void NPartyTripleGenerator::generateInputs(int player) inputs.resize(nTriplesPerLoop); typename W::input_check_type::MAC_Check MC(mac_key); - MC.POpen(check_sum, globalPlayer); // use zero element because all is perfectly randomized MC.set_random_element({}); + MC.POpen(check_sum, globalPlayer); MC.Check(globalPlayer); } @@ -673,7 +674,7 @@ void MascotTripleGenerator::sacrifice(typename T::MAC_Check& MC, PRNG& G) auto& outputFile = this->outputFile; auto& uncheckedTriples = this->uncheckedTriples; - assert(T::clear::length() >= 40); + check_field_size(); vector maskedAs(nTriplesPerLoop); vector > maskedTriples(nTriplesPerLoop); @@ -744,6 +745,8 @@ void Spdz2kTripleGenerator::sacrificeZ2k(U& MC, PRNG& G) // and first part of [sigma], i.e., t * [c] - [chat] maskedTriples[j].template prepare_sacrifice(uncheckedTriples[j], G); maskedAs[j] = maskedTriples[j].a[0]; + // enough randomness in values + MC.set_random_element({}); } vector openedAs(nTriplesPerLoop); @@ -754,6 +757,8 @@ void Spdz2kTripleGenerator::sacrificeZ2k(U& MC, PRNG& G) for (int j = 0; j < nTriplesPerLoop; j++) { // compute t * [c] - [chat] - [b] * p sigmas.push_back(maskedTriples[j].computeCheckShare(V(openedAs[j]))); + // enough randomness in values + MC.set_random_element({}); } vector open_sigmas; diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 019fc6f2..33d5f441 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -59,6 +59,8 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) cerr << "Number of program sequences I need to load = " << nprogs << endl; #endif + bc_filenames.clear(); + // Load in the programs string threadname; for (int i=0; i void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, false_type) { -#ifndef INSECURE - throw runtime_error("no secure implementation of reading edaBits from files"); -#endif + insecure("reading edaBits from files"); if (edabit_buffers.find(n_bits) == edabit_buffers.end()) { string filename = PrepBase::get_edabit_filename(prep_data_dir, diff --git a/Processor/Input.h b/Processor/Input.h index 728c81f6..0a84a55f 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -26,7 +26,7 @@ class InputBase typedef typename T::clear clear; protected: - Player* P; + PlayerBase* P; int my_num; Buffer buffer; @@ -63,7 +63,7 @@ public: /// Initialize input round for ``player`` virtual void reset(int player) = 0; /// Initialize input round for all players - void reset_all(Player& P); + void reset_all(PlayerBase& P); /// Schedule input from me virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 246c9eb1..09c6e056 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -81,7 +81,7 @@ void InputBase::reset(int player) } template -void InputBase::reset_all(Player& P) +void InputBase::reset_all(PlayerBase& P) { this->P = &P; my_num = P.my_num(); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index a7e1e318..5279b258 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -328,14 +328,14 @@ protected: int opcode; // The code int size; // Vector size int r[4]; // Fixed parameter registers - unsigned int n; // Possible immediate value + size_t n; // Possible immediate value vector start; // Values for a start/stop open public: virtual ~BaseInstruction() {}; int get_r(int i) const { return r[i]; } - unsigned int get_n() const { return n; } + size_t get_n() const { return n; } const vector& get_start() const { return start; } int get_opcode() const { return opcode; } int get_size() const { return size; } @@ -350,7 +350,7 @@ public: bool is_direct_memory_access() const; // Returns the memory size used if applicable and known - unsigned get_mem(RegType reg_type) const; + size_t get_mem(RegType reg_type) const; // Returns the maximal register used unsigned get_max_reg(int reg_type) const; diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 1bc46f94..2a5dce70 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -218,24 +218,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 1 register + 1 integer operand case LDI: case LDSI: - case LDMC: - case LDMS: - case STMC: - case STMS: - case LDMSB: - case STMSB: - case LDMCB: - case STMCB: - case LDMINT: - case STMINT: case JMPNZ: case JMPEQZ: case GLDI: case GLDSI: - case GLDMC: - case GLDMS: - case GSTMC: - case GSTMS: case PRINTREG: case PRINTREGB: case GPRINTREG: @@ -247,6 +233,24 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) r[0]=get_int(s); n = get_int(s); break; + // instructions with 1 register + 1 long operand + case LDMC: + case LDMS: + case STMC: + case STMS: + case LDMSB: + case STMSB: + case LDMCB: + case STMCB: + case LDMINT: + case STMINT: + case GLDMC: + case GLDMS: + case GSTMC: + case GSTMS: + r[0] = get_int(s); + n = get_long(s); + break; // instructions with 1 integer operand case PRINTSTR: case PRINTCHR: @@ -783,7 +787,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const } inline -unsigned BaseInstruction::get_mem(RegType reg_type) const +size_t BaseInstruction::get_mem(RegType reg_type) const { if (get_reg_type() == reg_type and is_direct_memory_access()) return n + size; @@ -843,7 +847,7 @@ inline void Instruction::execute(Processor& Proc) const } int r[3] = {this->r[0], this->r[1], this->r[2]}; - int n = this->n; + int64_t n = this->n; for (int i = 0; i < size; i++) { switch (opcode) { @@ -1065,7 +1069,7 @@ inline void Instruction::execute(Processor& Proc) const case PRINTREG: { Proc.out << "Reg[" << r[0] << "] = " << Proc.read_Cp(r[0]) - << " # " << string((char*)&n,sizeof(n)) << endl; + << " # " << string((char*)&n, 4) << endl; } break; case PRINTREGPLAIN: @@ -1085,7 +1089,7 @@ inline void Instruction::execute(Processor& Proc) const case CONDPRINTSTR: if (not Proc.read_Cp(r[0]).is_zero()) { - string str = {(char*)&n, sizeof(n)}; + string str = {(char*)&n, 4}; size_t n = str.find('\0'); if (n < 4) str.erase(n); @@ -1313,7 +1317,7 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c out << "["; for (int i = 0; i < size; i++) { - if (p == 0) + if (p == 0 or (*p == 0 and s == 0)) out << v[i]; else if (s == 0) out << bigint::get_float(v[i], p[i], {}, {}); diff --git a/Processor/Machine.h b/Processor/Machine.h index 331a9a22..8b3d018c 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -46,6 +46,8 @@ class Machine : public BaseMachine void load_program(const string& threadname, const string& filename); + void prepare(const string& progname_str); + void suggest_optimizations(); public: @@ -71,10 +73,10 @@ class Machine : public BaseMachine ExecutionStats stats; - Machine(int my_number, Names& playerNames, const string& progname, - const string& memtype, int lg2, bool direct, int opening_sum, - bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep, - OnlineOptions opts); + static void init_binary_domains(int security_parameter, int lg2); + + Machine(Names& playerNames, bool use_encryption = true, + const OnlineOptions opts = sint(), int lg2 = 0); ~Machine(); const Names& get_N() { return N; } @@ -92,7 +94,11 @@ class Machine : public BaseMachine DataPositions run_tape(int thread_number, int tape_number, int arg, const DataPositions& pos); DataPositions join_tape(int thread_number); - void run(); + + void run(const string& progname); + + void run_step(const string& progname); + pair stop_threads(); string memory_filename(); @@ -102,6 +108,9 @@ class Machine : public BaseMachine void reqbl(int n); typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } + typename sint::mac_key_type get_sint_mac_key() { return alphapi; } + + Player& get_player() { return *P; } }; #endif /* MACHINE_H_ */ diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e720b2a9..e0299c2f 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -24,28 +24,52 @@ using namespace std; template -Machine::Machine(int my_number, Names& playerNames, - const string& progname_str, const string& memtype, - int lg2, bool direct, - int opening_sum, bool receive_threads, int max_broadcast, - bool use_encryption, bool live_prep, OnlineOptions opts) - : my_number(my_number), N(playerNames), - direct(direct), opening_sum(opening_sum), - receive_threads(receive_threads), max_broadcast(max_broadcast), - use_encryption(use_encryption), live_prep(live_prep), opts(opts) +void Machine::init_binary_domains(int security_parameter, int lg2) { + sgf2n::clear::init_field(lg2); + + if (not is_same()) + { + if (sgf2n::clear::degree() < security_parameter) + { + cerr << "Security parameter needs to be at most n in GF(2^n)." + << endl; + cerr << "Increase the latter (-lg2) or decrease the former (-S)." + << endl; + exit(1); + } + } + + if (not is_same()) + { + sint::bit_type::mac_key_type::init_minimum(security_parameter); + } + else + { + // Initialize field for CCD + sint::bit_type::part_type::open_type::init_field(); + } +} + +template +Machine::Machine(Names& playerNames, bool use_encryption, + const OnlineOptions opts, int lg2) + : my_number(playerNames.my_num()), N(playerNames), + direct(opts.direct), opening_sum(opts.opening_sum), + receive_threads(opts.receive_threads), max_broadcast(opts.max_broadcast), + use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts) +{ + OnlineOptions::singleton = opts; + if (opening_sum < 2) this->opening_sum = N.num_players(); if (max_broadcast < 2) this->max_broadcast = N.num_players(); // Set up the fields - sgf2n::clear::init_field(lg2); sint::clear::read_or_generate_setup(prep_dir_prefix(), opts); - sint::bit_type::mac_key_type::init_field(); - // Initialize gf2n_short for CCD - sint::bit_type::part_type::open_type::init_field(); + init_binary_domains(opts.security_parameter, lg2); // make directory for outputs if necessary mkdir_p(PREP_DIR); @@ -75,6 +99,7 @@ Machine::Machine(int my_number, Names& playerNames, sint::clear::next::template init(false); // Initialize the global memory + auto memtype = opts.memtype; if (memtype.compare("old")==0) { ifstream inpf; @@ -92,9 +117,18 @@ Machine::Machine(int my_number, Names& playerNames, { cerr << "Invalid memory argument" << endl; exit(1); } +} +template +void Machine::prepare(const string& progname_str) +{ + int old_n_threads = nthreads; + progs.clear(); load_schedule(progname_str); + // keep preprocessing + nthreads = max(old_n_threads, nthreads); + // initialize persistence if necessary for (auto& prog : progs) { @@ -122,7 +156,7 @@ Machine::Machine(int my_number, Names& playerNames, if (live_prep and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot)) { - for (int i = 0; i < nthreads; i++) + for (int i = old_n_threads; i < nthreads; i++) ot_setups.push_back({ *P, true }); } @@ -132,7 +166,7 @@ Machine::Machine(int my_number, Names& playerNames, queues.resize(nthreads); join_timer.resize(nthreads); - for (int i=0; i::Machine(int my_number, Names& playerNames, } // synchronize with clients before starting timer - for (int i=0; iresult(); } @@ -155,6 +189,9 @@ Machine::Machine(int my_number, Names& playerNames, template Machine::~Machine() { + sint::LivePrep::teardown(); + sgf2n::LivePrep::teardown(); + delete P; for (auto& queue : queues) delete queue; @@ -308,14 +345,12 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, //printf("Running line %d\n",exec); if (progs[tape_number].usage_unknown()) { -#ifndef INSECURE if (not opts.live_prep and thread_number != 0) { - cerr << "Internally called tape " << tape_number << - " has unknown offline data usage" << endl; - throw invalid_program(); + insecure( + "Internally called tape " + to_string(tape_number) + + " has unknown offline data usage"); } -#endif return DataPositions(N.num_players()); } else @@ -336,23 +371,20 @@ DataPositions Machine::join_tape(int i) } template -void Machine::run() +void Machine::run_step(const string& progname) { - Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); - proc_timer.start(); - timer[0].start({}); - - // run main tape + prepare(progname); run_tape(0, 0, 0, N.num_players()); join_tape(0); +} - print_compiler(); - - finish_timer.start(); +template +pair Machine::stop_threads() +{ // Tell all C-threads to stop for (int i=0; ischedule(-1); } @@ -369,6 +401,40 @@ void Machine::run() pos.increase(queues[i]->result().pos); pthread_join(threads[i],NULL); } + + auto comm_stats = total_comm(); + + for (auto& queue : queues) + delete queue; + + queues.clear(); + + nthreads = 0; + + return {pos, comm_stats}; +} + +template +void Machine::run(const string& progname) +{ + prepare(progname); + + Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); + proc_timer.start(); + timer[0].start({}); + + // run main tape + run_tape(0, 0, 0, N.num_players()); + join_tape(0); + + print_compiler(); + + finish_timer.start(); + + // actual usage + auto res = stop_threads(); + DataPositions& pos = res.first; + finish_timer.stop(); #ifdef VERBOSE @@ -387,7 +453,7 @@ void Machine::run() cerr << "Finish timer: " << finish_timer.elapsed() << endl; #endif - NamedCommStats comm_stats = total_comm(); + NamedCommStats& comm_stats = res.second; if (opts.verbose) { @@ -475,17 +541,12 @@ void Machine::run() stats.print(); } -#ifndef INSECURE if (not opts.file_prep_per_thread) { Data_Files df(*this); df.seekg(pos); df.prune(); } -#endif - - sint::LivePrep::teardown(); - sgf2n::LivePrep::teardown(); suggest_optimizations(); diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index 6e0bb525..b04ea6a1 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -37,13 +37,16 @@ template int OfflineMachine::run() { T::clear::init_default(this->online_opts.prime_length()); - U::clear::init_field(U::clear::default_degree()); - T::bit_type::mac_key_type::init_field(); + Machine::init_binary_domains(this->online_opts.security_parameter, + this->lg2); auto binary_mac_key = read_generate_write_mac_key< typename T::bit_type::part_type>(P); typename T::bit_type::LivePrep bit_prep(usage); GC::ShareThread thread(bit_prep, P, binary_mac_key); + // setup before generation to fix prime + T::LivePrep::basic_setup(P); + generate(); generate(); generate(); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index e98f1a3a..b0c5e579 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -100,6 +100,9 @@ void thread_info::Sub_Main_Func() processor = new Processor(tinfo->thread_num,P,*MC2,*MCp,machine,progs.at(thread_num > 0)); auto& Proc = *processor; + // don't count communication for initialization + P.reset_stats(); + bool flag=true; int program=-3; // int exec=0; @@ -287,10 +290,8 @@ void thread_info::Sub_Main_Func() // final check Proc.check(); -#ifndef INSECURE if (machine.opts.file_prep_per_thread) Proc.DataF.prune(); -#endif wait_timer.start(); queues->next(); diff --git a/Processor/OnlineMachine.h b/Processor/OnlineMachine.h index 9804828e..68e38e37 100644 --- a/Processor/OnlineMachine.h +++ b/Processor/OnlineMachine.h @@ -17,11 +17,11 @@ protected: const char** argv; OnlineOptions& online_opts; - int lg2, opening_sum, max_broadcast; + int lg2; Names playerNames; - bool use_encryption, receive_threads; + bool use_encryption; ez::ezOptionParser& opt; diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 4e944d62..d4c66e9a 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -18,7 +18,7 @@ template int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_prep_default = true) { OnlineOptions& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, 1000, live_prep_default, T::clear::invertible}; + online_opts = {opt, argc, argv, T(), live_prep_default}; DishonestMajorityMachine machine(argc, argv, opt, online_opts, typename U::clear()); return machine.run(); @@ -28,8 +28,7 @@ template OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers, V) : argc(argc), argv(argv), online_opts(online_opts), lg2(0), - opening_sum(0), max_broadcast(0), - use_encryption(false), receive_threads(false), + use_encryption(false), opt(opt), nplayers(nplayers) { opt.add( @@ -125,33 +124,6 @@ DishonestMajorityMachine::DishonestMajorityMachine(int argc, const char** argv, opt.example = string() + argv[0] + " -p 0 -N 2 sample-prog\n" + argv[0] + " -h localhost -p 1 -N 2 sample-prog\n"; - opt.add( - "0", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Sum at most n shares at once when using indirect communication", // Help description. - "-s", // Flag token. - "--opening-sum" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Use player-specific threads for communication", // Help description. - "-t", // Flag token. - "--threads" // Flag token. - ); - opt.add( - "0", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Maximum number of parties to send to at once", // Help description. - "-mb", // Flag token. - "--max-broadcast" // Flag token. - ); opt.add( "", // Default. 0, // Required? @@ -163,11 +135,7 @@ DishonestMajorityMachine::DishonestMajorityMachine(int argc, const char** argv, ); online_opts.finalize(opt, argc, argv); - opt.get("--opening-sum")->getInt(opening_sum); - opt.get("--max-broadcast")->getInt(max_broadcast); - use_encryption = opt.isSet("--encrypted"); - receive_threads = opt.isSet("--threads"); start_networking(); } @@ -230,12 +198,8 @@ int OnlineMachine::run() try #endif { - Machine(online_opts.playerno, playerNames, online_opts.progname, - online_opts.memtype, lg2, - online_opts.direct, opening_sum, - receive_threads, max_broadcast, - use_encryption, online_opts.live_prep, - online_opts).run(); + Machine(playerNames, use_encryption, online_opts, lg2).run( + online_opts.progname); if (online_opts.verbose) { diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 2a5e090b..d404f642 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -8,6 +8,7 @@ #include "Math/gfp.h" #include "Math/gfpvar.h" #include "Protocols/HemiOptions.h" +#include "Protocols/config.h" #include "Math/gfp.hpp" @@ -26,10 +27,14 @@ OnlineOptions::OnlineOptions() : playerno(-1) bits_from_squares = false; direct = false; bucket_size = 4; + security_parameter = DEFAULT_SECURITY; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; - trunc_error = 40; + trunc_error = DEFAULT_SECURITY; + opening_sum = 0; + max_broadcast = 0; + receive_threads = false; #ifdef VERBOSE verbose = true; #else @@ -38,7 +43,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, - const char** argv, false_type) : + const char** argv, bool security) : OnlineOptions() { opt.syntax = std::string(argv[0]) + " [OPTIONS] [] "; @@ -104,6 +109,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--bucket-size" // Flag token. ); + if (security) + opt.add( + to_string(security_parameter).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + ("Security parameter (default: " + to_string(security_parameter) + + ")").c_str(), // Help description. + "-S", // Flag token. + "--security" // Flag token. + ); + opt.parse(argc, argv); interactive = opt.isSet("-I"); @@ -117,13 +134,24 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, verbose = opt.isSet("--verbose"); #endif + if (security) + { + opt.get("-S")->getInt(security_parameter); + cerr << "Using security parameter " << security_parameter << endl; + if (security_parameter <= 0) + { + cerr << "Invalid security parameter: " << security_parameter << endl; + exit(1); + } + } + opt.resetArgs(); } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size, bool default_live_prep, - bool variable_prime_length) : - OnlineOptions(opt, argc, argv, false_type()) + bool variable_prime_length, bool security) : + OnlineOptions(opt, argc, argv, security) { if (default_batch_size <= 0) default_batch_size = batch_size; @@ -263,6 +291,9 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, vector badOptions; unsigned int i; + opt.footer += "\nSee also https://mp-spdz.readthedocs.io/en/latest/networking.html " + "for documentation on the networking setup.\n"; + if (allArgs.size() != 3u - opt.isSet("-p")) { cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; @@ -329,6 +360,16 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, } set_trunc_error(opt); + + auto o = opt.get("--opening-sum"); + if (o) + o->getInt(opening_sum); + + o = opt.get("--max-broadcast"); + if (o) + o->getInt(max_broadcast); + + receive_threads = opt.isSet("--threads"); } void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 4b2fe4f8..61c1352b 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -26,21 +26,26 @@ public: bool bits_from_squares; bool direct; int bucket_size; + int security_parameter; std::string cmd_private_input_file; std::string cmd_private_output_file; bool verbose; bool file_prep_per_thread; int trunc_error; + int opening_sum, max_broadcast; + bool receive_threads; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, - false_type); + bool security); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, - bool variable_prime_length = false); + bool variable_prime_length = false, bool security = true); template OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, bool default_live_prep = true); + template + OnlineOptions(T); ~OnlineOptions() {} void finalize(ez::ezOptionParser& opt, int argc, const char** argv); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp index 8961853e..d8b71cea 100644 --- a/Processor/OnlineOptions.hpp +++ b/Processor/OnlineOptions.hpp @@ -20,11 +20,49 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Probabilistic truncation error " - "(2^-x, default: 40)", // Help description. + ("Probabilistic truncation error (2^-x, default: " + + to_string(trunc_error) + ")").c_str(), // Help description. "-E", // Flag token. "--trunc-error" // Flag token. ); + + if (T::dishonest_majority) + { + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Sum at most n shares at once when using indirect communication", // Help description. + "-s", // Flag token. + "--opening-sum" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use player-specific threads for communication", // Help description. + "-t", // Flag token. + "--threads" // Flag token. + ); + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Maximum number of parties to send to at once", // Help description. + "-mb", // Flag token. + "--max-broadcast" // Flag token. + ); + } +} + +template +OnlineOptions::OnlineOptions(T) : OnlineOptions() +{ + if (T::dishonest_majority) + batch_size = 1000; } #endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/Program.h b/Processor/Program.h index 87a263f0..8fb3df14 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -21,7 +21,7 @@ class Program unsigned max_reg[MAX_REG_TYPE]; // Memory size used directly - unsigned max_mem[MAX_REG_TYPE]; + size_t max_mem[MAX_REG_TYPE]; // True if program contains variable-sized loop bool unknown_usage; @@ -48,7 +48,7 @@ class Program unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } - unsigned direct_mem(RegType reg_type) const + size_t direct_mem(RegType reg_type) const { return max_mem[reg_type]; } friend ostream& operator<<(ostream& s,const Program& P); diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index f2bfc6c1..62694221 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -65,7 +65,7 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri int argc, const char** argv, ez::ezOptionParser& opt) { OnlineOptions online_opts(opt, argc, argv); - RingOptions opts(opt, argc, argv, true); + RingOptions opts(opt, argc, argv); HonestMajorityMachine machine(argc, argv, opt, online_opts); int R = opts.ring_size_from_opts_or_schedule(online_opts.progname); switch (R) @@ -76,15 +76,19 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri break; #define X(K) \ case K: \ - switch (opts.S) \ + { \ + int S = online_opts.security_parameter; \ + switch (S) \ { \ - Y(K, 40) \ + Y(K, DEFAULT_SECURITY) \ default: \ - cerr << "not compiled for security parameter " << to_string(opts.S) << endl; \ - cerr << "add 'Y(K, " << opts.S << ")' to " __FILE__ ", line 76" << endl; \ + cerr << "not compiled for security parameter " << to_string(S) << endl; \ + cerr << "add 'Y(K, " << S << ")' to " __FILE__ ", line 76" << endl; \ + cerr << "or compile with -DDEFAULT_SECURITY=" << S << endl; \ exit(1); \ } \ - break; + break; \ + } X(64) #ifdef RING_SIZE X(RING_SIZE) diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp index d59a709a..ec9e9f06 100644 --- a/Processor/RingOptions.cpp +++ b/Processor/RingOptions.cpp @@ -9,8 +9,7 @@ #include using namespace std; -RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, - bool security) +RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) { opt.add( "64", // Default. @@ -21,28 +20,12 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, "-R", // Flag token. "--ring" // Flag token. ); - if (security) - opt.add( - "40", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Security parameter (default: 40)", // Help description. - "-S", // Flag token. - "--security" // Flag token. - ); opt.parse(argc, argv); opt.get("-R")->getInt(R); - if (security) - opt.get("-S")->getInt(S); - else - S = -1; R_is_set = opt.isSet("-R"); opt.resetArgs(); if (R_is_set) cerr << "Trying to run " << R << "-bit computation" << endl; - if (security) - cerr << "Using security parameter " << S << endl; } int RingOptions::ring_size_from_opts_or_schedule(string progname) diff --git a/Processor/RingOptions.h b/Processor/RingOptions.h index 899c7021..8f5361f6 100644 --- a/Processor/RingOptions.h +++ b/Processor/RingOptions.h @@ -16,10 +16,8 @@ class RingOptions public: int R; - int S; - RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, - bool security = false); + RingOptions(ez::ezOptionParser& opt, int argc, const char** argv); int ring_size_from_opts_or_schedule(string progname); }; diff --git a/Processor/instructions.h b/Processor/instructions.h index 5928fdab..bf443b0f 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -203,7 +203,7 @@ *dest++ = *op1++ == *op2++) \ X(PRINTINT, Proc.out << Proc.read_Ci(r[0]) << flush,) \ X(PRINTFLOATPREC, Proc.out << setprecision(n),) \ - X(PRINTSTR, Proc.out << string((char*)&n,sizeof(n)) << flush,) \ + X(PRINTSTR, Proc.out << string((char*)&n,4) << flush,) \ X(PRINTCHR, Proc.out << string((char*)&n,1) << flush,) \ X(SHUFFLE, shuffle(Proc),) \ X(BITDECINT, bitdecint(Proc),) \ @@ -270,7 +270,7 @@ *dest++ = *op1++ >> n) \ X(GPRINTREG, auto source = &C2[r[0]], \ Proc.out << "Reg[" << r[0] << "] = " << *source++ \ - << " # " << string((char*)&n,sizeof(n)) << endl) \ + << " # " << string((char*)&n, 4) << endl) \ X(GPRINTREGPLAIN, auto source = &C2[r[0]], \ Proc.out << *source++ << flush) \ X(GBITDEC, gbitdec(C2),) \ diff --git a/Programs/Source/l2h_comparison.mpc b/Programs/Source/l2h_comparison.mpc new file mode 100644 index 00000000..c233caa7 --- /dev/null +++ b/Programs/Source/l2h_comparison.mpc @@ -0,0 +1,3 @@ +res = sint.load_mem(0) < sint.load_mem(1) +res.store_in_mem(3) +print_ln('comparison in VM: %s', res.reveal()) diff --git a/Programs/Source/l2h_multiplication.mpc b/Programs/Source/l2h_multiplication.mpc new file mode 100644 index 00000000..aecbca65 --- /dev/null +++ b/Programs/Source/l2h_multiplication.mpc @@ -0,0 +1 @@ +(sint.load_mem(0) * sint.load_mem(1)).store_in_mem(2) diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 2d28127c..9b695d0d 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -23,6 +23,7 @@ class Player; template class Beaver : public ProtocolBase { +protected: vector shares; vector opened; vector> triples; diff --git a/Protocols/ChaiGearPrep.hpp b/Protocols/ChaiGearPrep.hpp index 69b16fcf..07693178 100644 --- a/Protocols/ChaiGearPrep.hpp +++ b/Protocols/ChaiGearPrep.hpp @@ -43,15 +43,16 @@ void ChaiGearPrep::basic_setup(Player& P) assert(machine == 0); machine = new MultiplicativeMachine; auto& setup = machine->setup.part(); - auto& options = CowGearOptions::singleton; + int lowgear_security = OnlineOptions::singleton.security_parameter; #ifdef VERBOSE + auto& options = CowGearOptions::singleton; cerr << "Covert security parameter for key and MAC generation: " << options.covert_security << endl; cerr << "Triple generation security parameter: " - << options.lowgear_security << endl; + << lowgear_security << endl; #endif - machine->sec = options.lowgear_security; - setup.secure_init(P, *machine, T::clear::length(), options.lowgear_security); + machine->sec = lowgear_security; + setup.secure_init(P, *machine, T::clear::length(), lowgear_security); T::clear::template init(); #ifdef VERBOSE cerr << T::type_string() << " parameter setup took " << timer.elapsed() diff --git a/Protocols/CowGearOptions.cpp b/Protocols/CowGearOptions.cpp index 9212a7e0..e018dd8b 100644 --- a/Protocols/CowGearOptions.cpp +++ b/Protocols/CowGearOptions.cpp @@ -23,7 +23,6 @@ CowGearOptions::CowGearOptions(bool covert) covert_security = -1; } - lowgear_security = 40; use_top_gear = false; } @@ -49,7 +48,7 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "LowGear security parameter (default: 40)", // Help description. + "DEPRECATED: use -S/--security", // Help description. "-l", // Flag token. "--lowgear-security" // Flag token. ); @@ -76,15 +75,8 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, opt.get("-c")->getInt(covert_security); if (opt.isSet("-l")) { - opt.get("-l")->getInt(lowgear_security); - if (lowgear_security <= 0) - { - throw exception(); - cerr << "Invalid LowGear Security parameter: " << lowgear_security << endl; - exit(1); - } - if (covert_security > (1LL << lowgear_security)) - insecure(", LowGear security less than key generation security"); + cerr << "Deprecated parameter, use -S/--security" << endl; + exit(1); } use_top_gear = not opt.isSet("-J"); if (opt.isSet("-T")) diff --git a/Protocols/CowGearOptions.h b/Protocols/CowGearOptions.h index f79bd521..af9006dc 100644 --- a/Protocols/CowGearOptions.h +++ b/Protocols/CowGearOptions.h @@ -16,7 +16,6 @@ public: static CowGearOptions singleton; int covert_security; - int lowgear_security; CowGearOptions(bool covert = true); CowGearOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index 3b9daae1..36b36d66 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -38,14 +38,15 @@ void CowGearPrep::basic_setup(Player& P) pairwise_machine = new PairwiseMachine(P); auto& machine = *pairwise_machine; auto& setup = machine.setup(); - auto& options = CowGearOptions::singleton; + int lowgear_security = OnlineOptions::singleton.security_parameter; #ifdef VERBOSE + auto& options = CowGearOptions::singleton; if (T::covert) cerr << "Covert security parameter for key and MAC generation: " << options.covert_security << endl; - cerr << "LowGear security parameter: " << options.lowgear_security << endl; + cerr << "LowGear security parameter: " << lowgear_security << endl; #endif - setup.secure_init(P, machine, T::clear::length(), options.lowgear_security); + setup.secure_init(P, machine, T::clear::length(), lowgear_security); T::clear::template init(); #ifdef VERBOSE cerr << T::type_string() << " parameter setup took " << timer.elapsed() diff --git a/Protocols/DabitSacrifice.h b/Protocols/DabitSacrifice.h index 3b436547..6da8cc23 100644 --- a/Protocols/DabitSacrifice.h +++ b/Protocols/DabitSacrifice.h @@ -9,10 +9,12 @@ template class DabitSacrifice { - static const int S = 40; + const int S; public: - static int minimum_n_inputs(int n_outputs = 0) + DabitSacrifice(); + + int minimum_n_inputs(int n_outputs = 0) { if (n_outputs < 1) n_outputs = OnlineOptions::singleton.batch_size; diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index aa2abf61..74d9f026 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -11,6 +11,12 @@ #include +template +DabitSacrifice::DabitSacrifice() : + S(OnlineOptions::singleton.security_parameter) +{ +} + template dabit& operator+=(dabit& x, const dabit& y) { diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h new file mode 100644 index 00000000..7d0699da --- /dev/null +++ b/Protocols/DealerInput.h @@ -0,0 +1,38 @@ +/* + * DealerInput.h + * + */ + +#ifndef PROTOCOLS_DEALERINPUT_H_ +#define PROTOCOLS_DEALERINPUT_H_ + +#include "../Networking/AllButLastPlayer.h" +#include "Processor/Input.h" + +template +class DealerInput : public InputBase +{ + Player& P; + octetStreams to_send, to_receive; + SeededPRNG G; + vector> shares; + bool from_dealer; + AllButLastPlayer sub_player; + SemiInput>* internal; + +public: + DealerInput(SubProcessor& proc, typename T::MAC_Check&); + DealerInput(typename T::MAC_Check&, Preprocessing&, Player& P); + DealerInput(Player& P); + ~DealerInput(); + + bool is_dealer(int player = -1); + + void reset(int player); + void add_mine(const typename T::open_type& input, int n_bits = -1); + void add_other(int player, int n_bits = -1); + void exchange(); + T finalize(int player, int n_bits = -1); +}; + +#endif /* PROTOCOLS_DEALERINPUT_H_ */ diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp new file mode 100644 index 00000000..26bfb9a1 --- /dev/null +++ b/Protocols/DealerInput.hpp @@ -0,0 +1,115 @@ +/* + * DealerInput.hpp + * + */ + +#ifndef PROTOCOLS_DEALERINPUT_HPP_ +#define PROTOCOLS_DEALERINPUT_HPP_ + +#include "DealerInput.h" + +template +DealerInput::DealerInput(SubProcessor& proc, typename T::MAC_Check&) : + DealerInput(proc.P) +{ +} + +template +DealerInput::DealerInput(typename T::MAC_Check&, Preprocessing&, + Player& P) : + DealerInput(P) +{ +} + +template +DealerInput::DealerInput(Player& P) : + P(P), to_send(P), shares(P.num_players()), from_dealer(false), + sub_player(P) +{ + if (is_dealer()) + internal = 0; + else + internal = new SemiInput>(0, sub_player); +} + +template +DealerInput::~DealerInput() +{ + if (internal) + delete internal; +} + +template +bool DealerInput::is_dealer(int player) +{ + int dealer_player = P.num_players() - 1; + if (player == -1) + return P.my_num() == dealer_player; + else + return player == dealer_player; +} + +template +void DealerInput::reset(int player) +{ + if (player == 0) + { + to_send.reset(P); + from_dealer = false; + } + else if (not is_dealer()) + internal->reset(player - 1); +} + +template +void DealerInput::add_mine(const typename T::open_type& input, + int) +{ + if (is_dealer()) + { + make_share(shares.data(), input, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(to_send[i]); + from_dealer = true; + } + else + internal->add_mine(input); +} + +template +void DealerInput::add_other(int player, int) +{ + if (is_dealer(player)) + from_dealer = true; + else if (not is_dealer()) + internal->add_other(player); +} + +template +void DealerInput::exchange() +{ + if (from_dealer) + { + vector senders(P.num_players()); + senders.back() = true; + P.send_receive_all(senders, to_send, to_receive); + } + else if (not is_dealer()) + internal->exchange(); +} + +template +T DealerInput::finalize(int player, int) +{ + if (is_dealer()) + return {}; + else + { + if (is_dealer(player)) + return to_receive.back().template get(); + else + return internal->finalize(player); + } +} + +#endif /* PROTOCOLS_DEALERINPUT_HPP_ */ diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h new file mode 100644 index 00000000..5311f813 --- /dev/null +++ b/Protocols/DealerMC.h @@ -0,0 +1,42 @@ +/* + * DealerMC.h + * + */ + +#ifndef PROTOCOLS_DEALERMC_H_ +#define PROTOCOLS_DEALERMC_H_ + +#include "MAC_Check_Base.h" +#include "Networking/AllButLastPlayer.h" + +template +class DealerMC : public MAC_Check_Base +{ + typedef SemiMC> internal_type; + internal_type& internal; + AllButLastPlayer* sub_player; + +public: + DealerMC(typename T::mac_key_type = {}, int = 0, int = 0); + DealerMC(internal_type& internal); + ~DealerMC(); + + void init_open(const Player& P, int n = 0); + void prepare_open(const T& secret); + void exchange(const Player& P); + typename T::open_type finalize_raw(); + + DealerMC& get_part_MC() + { + return *this; + } +}; + +template +class DirectDealerMC : public DealerMC +{ +public: + DirectDealerMC(typename T::mac_key_type = {}); +}; + +#endif /* PROTOCOLS_DEALERMC_H_ */ diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp new file mode 100644 index 00000000..a9ddc035 --- /dev/null +++ b/Protocols/DealerMC.hpp @@ -0,0 +1,76 @@ +/* + * DealerMC.hpp + * + */ + +#ifndef PROTOCOLS_DEALERMC_HPP_ +#define PROTOCOLS_DEALERMC_HPP_ + +#include "DealerMC.h" + +template +DealerMC::DealerMC(typename T::mac_key_type, int, int) : + DealerMC(*(new internal_type)) +{ +} + +template +DirectDealerMC::DirectDealerMC(typename T::mac_key_type) : + DealerMC(*(new DirectSemiMC>)) +{ +} + +template +DealerMC::DealerMC(internal_type& internal) : + internal(internal), sub_player(0) +{ +} + +template +DealerMC::~DealerMC() +{ + delete &internal; + if (sub_player) + delete sub_player; +} + +template +void DealerMC::init_open(const Player& P, int n) +{ + if (P.my_num() != P.num_players() - 1) + { + if (not sub_player) + sub_player = new AllButLastPlayer(P); + internal.init_open(P, n); + } +} + +template +void DealerMC::prepare_open(const T& secret) +{ + if (sub_player) + internal.prepare_open(secret); + else + { + if (secret != T()) + throw runtime_error("share for dealer should be 0"); + } +} + +template +void DealerMC::exchange(const Player&) +{ + if (sub_player) + internal.exchange(*sub_player); +} + +template +typename T::open_type DealerMC::finalize_raw() +{ + if (sub_player) + return internal.finalize_raw(); + else + return {}; +} + +#endif /* PROTOCOLS_DEALERMC_HPP_ */ diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h new file mode 100644 index 00000000..ae28ec69 --- /dev/null +++ b/Protocols/DealerPrep.h @@ -0,0 +1,33 @@ +/* + * DealerPrep.h + * + */ + +#ifndef PROTOCOLS_DEALERPREP_H_ +#define PROTOCOLS_DEALERPREP_H_ + +#include "ReplicatedPrep.h" + +template +class DealerPrep : virtual public BitPrep +{ + template + void buffer_edabits(int n_bits, true_type); + template + void buffer_edabits(int n_bits, false_type); + +public: + DealerPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage) + { + } + + void buffer_triples(); + void buffer_bits(); + + void buffer_dabits(ThreadQueues* = 0); + void buffer_edabits(int n_bits, ThreadQueues*); + void buffer_sedabits(int n_bits, ThreadQueues*); +}; + +#endif /* PROTOCOLS_DEALERPREP_H_ */ diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp new file mode 100644 index 00000000..d4a0a91d --- /dev/null +++ b/Protocols/DealerPrep.hpp @@ -0,0 +1,196 @@ +/* + * DealerPrep.hpp + * + */ + +#ifndef PROTOCOLS_DEALERPREP_HPP_ +#define PROTOCOLS_DEALERPREP_HPP_ + +#include "DealerPrep.h" + +template +void DealerPrep::buffer_triples() +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T triples[3]; + for (int i = 0; i < 2; i++) + triples[i] = G.get(); + triples[2] = triples[0] * triples[1]; + for (auto& value : triples) + { + make_share(shares.data(), typename T::clear(value), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + this->triples.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->triples.push_back(to_receive.back().get>().get()); + } +} + +template +void DealerPrep::buffer_bits() +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T bit = G.get_bit(); + make_share(shares.data(), typename T::clear(bit), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + this->bits.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->bits.push_back(to_receive.back().get()); + } +} + +template +void DealerPrep::buffer_dabits(ThreadQueues*) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + vector bit_shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + auto bit = G.get_bit(); + make_share(shares.data(), typename T::clear(bit), + P.num_players() - 1, 0, G); + make_share(bit_shares.data(), typename T::bit_type::clear(bit), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + { + shares.at(i - 1).pack(os[i - 1]); + bit_shares.at(i - 1).pack(os[i - 1]); + } + this->dabits.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + this->dabits.push_back({to_receive.back().get(), + to_receive.back().get()}); + } + } +} + +template +void DealerPrep::buffer_sedabits(int length, ThreadQueues*) +{ + auto& buffer = this->edabits[{false, length}]; + if (buffer.empty()) + buffer_edabits(length, 0); + this->edabits[{true, length}].push_back(buffer.back()); + buffer.pop_back(); +} + +template +void DealerPrep::buffer_edabits(int length, ThreadQueues*) +{ + buffer_edabits(length, T::clear::characteristic_two); +} + +template +template +void DealerPrep::buffer_edabits(int, true_type) +{ + throw not_implemented(); +} + +template +template +void DealerPrep::buffer_edabits(int length, false_type) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + int n_vecs = OnlineOptions::singleton.batch_size / edabitvec::MAX_SIZE; + auto& buffer = this->edabits[{false, length}]; + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + vector bit_shares(P.num_players() - 1); + for (int i = 0; i < n_vecs; i++) + { + vector as; + vector bs; + plain_edabits(as, bs, length, G); + for (auto& a : as) + { + make_share(shares.data(), a, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + for (auto& b : bs) + { + make_share(bit_shares.data(), b, P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + bit_shares.at(i - 1).pack(os[i - 1]); + } + buffer.push_back({}); + buffer.back().a.resize(edabitvec::MAX_SIZE); + buffer.back().b.resize(length); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < n_vecs; i++) + { + buffer.push_back({}); + for (int j = 0; j < edabitvec::MAX_SIZE; j++) + buffer.back().a.push_back(to_receive.back().get()); + for (int j = 0; j < length; j++) + buffer.back().b.push_back( + to_receive.back().get()); + } + } +} + +#endif /* PROTOCOLS_DEALERPREP_HPP_ */ diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h new file mode 100644 index 00000000..38900ff3 --- /dev/null +++ b/Protocols/DealerShare.h @@ -0,0 +1,76 @@ +/* + * DealerShare.h + * + */ + +#ifndef PROTOCOLS_DEALERSHARE_H_ +#define PROTOCOLS_DEALERSHARE_H_ + +#include "Math/Z2k.h" +#include "SemiShare.h" + +template class DealerPrep; +template class DealerInput; +template class DealerMC; +template class DirectDealerMC; + +namespace GC +{ +class DealerSecret; +} + +template +class DealerShare : public SemiShare +{ + typedef DealerShare This; + typedef SemiShare super; + +public: + typedef GC::DealerSecret bit_type; + + typedef DealerMC MAC_Check; + typedef DirectDealerMC Direct_MC; + typedef Beaver Protocol; + typedef DealerInput Input; + typedef DealerPrep LivePrep; + typedef ::PrivateOutput PrivateOutput; + + static false_type dishonest_majority; + const static bool needs_ot = false; + + static string type_short() + { + return "DD" + string(1, T::type_char()); + } + + static int threshold(int) + { + throw runtime_error("undefined threshold"); + } + + static This constant(const T& other, int my_num, + const typename super::mac_key_type& = {}, int = -1) + { + if (my_num == 1) + return other; + else + return {}; + } + + DealerShare() + { + } + + template + DealerShare(const U& other) : super(other) + { + } +}; + +template +using DealerRingShare = DealerShare>; + +template +false_type DealerShare::dishonest_majority; + +#endif /* PROTOCOLS_DEALERSHARE_H_ */ diff --git a/Protocols/FakeMC.h b/Protocols/FakeMC.h index d16dda1c..b5876ec0 100644 --- a/Protocols/FakeMC.h +++ b/Protocols/FakeMC.h @@ -12,7 +12,7 @@ template class FakeMC : public MAC_Check_Base { public: - FakeMC(T, int = 0, int = 0) + FakeMC(typename T::mac_key_type, int = 0, int = 0) { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index fb55f0cf..018ac338 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -28,6 +28,7 @@ class FakeProtocol : public ProtocolBase vector trunc_stats; map cisc_stats; + map ltz_stats; public: Player& P; @@ -54,6 +55,8 @@ public: { cerr << x.second << " " << x.first << endl; } + for (auto& x : ltz_stats) + cerr << "LTZ " << x.first << ": " << x.second << endl; } template @@ -219,6 +222,7 @@ public: { for (size_t i = 0; i < args.size(); i += args[i]) { + ltz_stats[args[i + 4]] += args[i + 1]; assert(i + args[i] <= args.size()); assert(args[i] == 6); for (int j = 0; j < args[i + 1]; j++) diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index 569c136e..c0a269d1 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -19,7 +19,6 @@ class FakeShare : public T, public ShareInterface typedef FakeShare This; public: - typedef T mac_key_type; typedef T open_type; typedef T clear; @@ -45,7 +44,7 @@ public: return 0; } - static T constant(T value, int = 0, T = 0) + static T constant(T value, int = 0, mac_key_type = {}) { return value; } diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 8a00c793..f43260ea 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -16,6 +16,7 @@ template class Hemi : public Semi { map, HemiMatrixPrep*> matrix_preps; + DataPositions matrix_usage; ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 1eebd3b7..1b3d8f5b 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -27,7 +27,8 @@ HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, new HemiMatrixPrep(dims[0], dims[1], dims[2], - dynamic_cast(processor.DataF))}); + dynamic_cast(processor.DataF), + matrix_usage)}); return *matrix_preps.at(dims); } diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index ea5a7211..8038e8ef 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -22,15 +22,15 @@ class HemiMatrixPrep : public BufferPrep> int n_rows, n_inner, n_cols; bool swapped; - DataPositions* usage; LivePrep* prep; HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: - HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) : - super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep, + DataPositions& usage) : + super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { swapped = n_rows > n_cols; @@ -39,11 +39,6 @@ public: assert(this->n_cols >= this->n_rows); } - ~HemiMatrixPrep() - { - delete usage; - } - void set_protocol(typename ShareMatrix::Protocol&) { } diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index f4221299..b2dd92d2 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -5,6 +5,7 @@ #include "HemiMatrixPrep.h" #include "FHE/Diagonalizer.h" +#include "Tools/Bundle.h" class CipherPlainMultJob : public ThreadJob { diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index c456424e..ce55bce7 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -30,7 +30,9 @@ void HemiPrep::basic_setup(Player& P) pairwise_machine = new PairwiseMachine(P); auto& machine = *pairwise_machine; auto& setup = machine.setup(); - setup.secure_init(P, machine, T::clear::length(), 40); + setup.params.set_matrix_dim_from_options(); + setup.params.set_sec(OnlineOptions::singleton.security_parameter); + setup.secure_init(P, machine, T::clear::length(), 0); T::clear::template init(); } diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 2250417d..19d5e72d 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -122,7 +122,6 @@ template class MAC_Check_Z2k : public Tree_MAC_Check { protected: - vector shares; Preprocessing* prep; W get_random_element(); @@ -130,11 +129,11 @@ protected: public: vector random_elements; - void AddToCheck(const W& share, const T& value, const Player& P); MAC_Check_Z2k(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0); MAC_Check_Z2k(const T& ai, Names& Nms, int thread_num); void prepare_open(const W& secret); + void prepare_open_no_mask(const W& secret); virtual void Check(const Player& P); void set_random_element(const W& random_element); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index fd71d526..ca607fd7 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -14,6 +14,7 @@ #include #include "Protocols/MAC_Check_Base.hpp" +#include "mac_key.hpp" template const char* TreeSum::mc_timer_names[] = { @@ -118,6 +119,7 @@ template void MAC_Check_::Check(const Player& P) { assert(U::mac_type::invertible); + check_field_size(); if (this->WaitingForCheck() == 0) return; @@ -214,17 +216,15 @@ MAC_Check_Z2k::MAC_Check_Z2k(const T& ai, Names& Nms, } template -void MAC_Check_Z2k::AddToCheck(const W& share, const T& value, const Player& P) +void MAC_Check_Z2k::prepare_open(const W& secret) { - shares.push_back(share.get_share()); - Tree_MAC_Check::AddToCheck(share, value, P); + prepare_open_no_mask(secret + (get_random_element() << W::clear::N_BITS)); } template -void MAC_Check_Z2k::prepare_open(const W& secret) +void MAC_Check_Z2k::prepare_open_no_mask(const W& secret) { - shares.push_back(secret.get_share()); - this->values.push_back(V(secret.get_share())); + this->values.push_back(secret.get_share()); this->macs.push_back(secret.get_mac()); } @@ -269,7 +269,6 @@ void MAC_Check_Z2k::Check(const Player& P) cout << "Checking " << shares[0] << " " << this->vals[0] << " " << this->macs[0] << endl; #endif - int k = V::N_BITS; octet seed[SEED_SIZE]; Create_Random_Seed(seed,P,SEED_SIZE); PRNG G; @@ -290,30 +289,7 @@ void MAC_Check_Z2k::Check(const Player& P) chi.push_back(temp_chi); } - W r = get_random_element(); - T lj = r.get_mac(); - U pj; - pj.assign_zero(); - for (int i = 0; i < this->popen_cnt; ++i) - { - T xji = shares[i]; - V xbarji = xji; - U pji = U((xji - xbarji) >> k); - pj += chi[i] * pji; - } - pj += U(r.get_share()); - - U pbar(pj); - vector pj_stream(P.num_players()); - pj.pack(pj_stream[P.my_num()]); - P.unchecked_broadcast(pj_stream); - for (int j=0; jalphai * y) - (((this->alphai * pbar)) << k) + (lj << k); + T zj = mj - this->alphai * y; vector zjs(P.num_players()); zjs[P.my_num()] = zj; Commit_And_Open(zjs, P); @@ -325,7 +301,6 @@ void MAC_Check_Z2k::Check(const Player& P) this->vals.erase(this->vals.begin(), this->vals.begin() + this->popen_cnt); this->macs.erase(this->macs.begin(), this->macs.begin() + this->popen_cnt); - this->shares.erase(this->shares.begin(), this->shares.begin() + this->popen_cnt); this->popen_cnt=0; if (!zj_sum.is_zero()) { throw mac_fail(); } } diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index 5a60281c..e855214f 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -57,7 +57,8 @@ public: /// Run opening protocol virtual void exchange(const Player& P) = 0; /// Get next opened value - virtual typename T::open_type finalize_open(); + virtual typename T::clear finalize_open(); + virtual typename T::open_type finalize_raw(); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 91e2ea86..59c6c5de 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -25,7 +25,7 @@ void MAC_Check_Base::POpen_End(vector& values, values.clear(); values.reserve(S.size()); for (size_t i = 0; i < S.size(); i++) - values.push_back(finalize_open()); + values.push_back(finalize_raw()); } template @@ -59,7 +59,13 @@ void MAC_Check_Base::prepare_open(const T& secret) } template -typename T::open_type MAC_Check_Base::finalize_open() +typename T::clear MAC_Check_Base::finalize_open() +{ + return finalize_raw(); +} + +template +typename T::open_type MAC_Check_Base::finalize_raw() { return values.next(); } diff --git a/Protocols/MalRepRingOptions.cpp b/Protocols/MalRepRingOptions.cpp index a2537da6..c5aafc18 100644 --- a/Protocols/MalRepRingOptions.cpp +++ b/Protocols/MalRepRingOptions.cpp @@ -21,10 +21,10 @@ MalRepRingOptions::MalRepRingOptions(ez::ezOptionParser& opt, int argc, 0, // Number of args expected. 0, // Delimiter if expecting multiple args. "Shuffle sacrifice (default: disabled)", // Help description. - "-S", // Flag token. + "-SH", // Flag token. "--shuffle" // Flag token. ); opt.parse(argc, argv); - shuffle = opt.isSet("-S"); + shuffle = opt.isSet("-SH"); opt.resetArgs(); } diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 96f2c813..6ce2e244 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -89,7 +89,7 @@ void MalRepRingPrep::simple_buffer_triples() template void MalRepRingPrep::shuffle_buffer_triples() { - assert(T::SECURITY <= 40); + assert(T::SECURITY <= OnlineOptions::singleton.security_parameter); assert(this->proc != 0); typename T::MAC_Check MC; shuffle_triple_generation(this->triples, this->proc->P, MC); diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 87deaaa3..e023945b 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -49,11 +49,11 @@ class HashMaliciousRepMC : public MaliciousRepMC public: // emulate MAC_Check - HashMaliciousRepMC(const typename T::value_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + HashMaliciousRepMC(const typename T::mac_key_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() { (void)_; (void)__; (void)___; } // emulate Direct_MAC_Check - HashMaliciousRepMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + HashMaliciousRepMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) : HashMaliciousRepMC() { (void)_; (void)__; (void)___; (void)____; } HashMaliciousRepMC(); @@ -62,7 +62,7 @@ public: void POpen(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index e64db21a..17eec6f1 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -84,9 +84,9 @@ void HashMaliciousRepMC::POpen_End(vector& values, } template -typename T::open_type HashMaliciousRepMC::finalize_open() +typename T::open_type HashMaliciousRepMC::finalize_raw() { - auto res = ReplicatedMC::finalize_open(); + auto res = ReplicatedMC::finalize_raw(); os.reset_write_head(); res.pack(os); update(); diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 8ffbff7b..b4d83d1b 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -7,6 +7,8 @@ #include "Tools/Subroutines.h" #include "Processor/OnlineOptions.h" +#include "mac_key.hpp" + template MaliciousBitOnlyRepPrep::MaliciousBitOnlyRepPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), @@ -69,7 +71,7 @@ void MaliciousBitOnlyRepPrep::init_honest(Player& P) template void MaliciousRepPrep::buffer_triples() { - assert(T::open_type::length() >= 40); + check_field_size(); auto& triples = this->triples; auto buffer_size = this->buffer_size; auto& honest_proc = this->honest_proc; diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index a72c36b0..4efe0711 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -38,7 +38,7 @@ public: { (void)_; (void)__; (void)___; (void)____; } void init_open(const Player& P, int n = 0); - typename T::open_type finalize_open(); + typename T::open_type finalize_raw(); typename T::open_type reconstruct(const vector& shares); }; diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 41c6f208..7f66215d 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -33,7 +33,7 @@ void MaliciousShamirMC::init_open(const Player& P, int n) } template -typename T::open_type MaliciousShamirMC::finalize_open() +typename T::open_type MaliciousShamirMC::finalize_raw() { int threshold = ShamirMachine::s().threshold; shares.resize(2 * threshold + 1); diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index fee8e829..ceedc915 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -38,6 +38,9 @@ public: typedef MaliciousRepPrep TriplePrep; typedef T random_type; + // indicate security relevance of field size + typedef T mac_key_type; + #ifndef NO_MIXED_CIRCUITS typedef GC::MaliciousCcdSecret bit_type; #endif diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index c9eb63cf..11942825 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -25,13 +25,15 @@ template void MamaPrep::buffer_triples() { int mac_security = T::N_MACS * T::clear::length(); + int sec = OnlineOptions::singleton.security_parameter; - if (mac_security < 40) + if (mac_security < sec) { - cerr << T::N_MACS << " MACs are not enough for 40-bit security with " - << T::clear::length() << "-bit primes." << endl; + cerr << T::N_MACS << " MACs are not enough for " << sec + << "-bit security with " << T::clear::length() << "-bit primes." + << endl; cerr << "Compile with -DN_MAMA_MACS=" - << DIV_CEIL(40, T::clear::length()) + << DIV_CEIL(sec, T::clear::length()) << " or remove this check in " << __FILE__ << endl; exit(1); } @@ -45,7 +47,7 @@ void MamaPrep::buffer_triples() size_t required = OnlineOptions::singleton.batch_size; // prefer shuffling if not loosing much security and bucket size is smaller - bool use_shuffling = mac_security <= 42 + bool use_shuffling = mac_security <= (sec + 2) and OnlineOptions::singleton.bucket_size < T::N_MACS; if (use_shuffling) required = sacrifice.minimum_n_inputs(); diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index c90a5e27..f2751518 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -23,6 +23,11 @@ class MamaMac : public FixedVec, N> public: static const true_type invertible; + static int length() + { + return N * T::length(); + } + MamaMac() { } diff --git a/Protocols/NoLivePrep.h b/Protocols/NoLivePrep.h index c53ec7e8..a1b89f9d 100644 --- a/Protocols/NoLivePrep.h +++ b/Protocols/NoLivePrep.h @@ -32,6 +32,11 @@ public: { } + NoLivePrep(DataPositions& usage, int = -1) : + BufferPrep(usage) + { + } + // access to protocol instance if needed void set_protocol(typename T::Protocol&) { diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h index d8259eb0..f1ef3c02 100644 --- a/Protocols/NoProtocol.h +++ b/Protocols/NoProtocol.h @@ -8,6 +8,7 @@ #include "Protocols/Replicated.h" #include "Protocols/MAC_Check_Base.h" +#include "Processor/Input.h" // opening facility template diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h index d966f586..0532200b 100644 --- a/Protocols/NoShare.h +++ b/Protocols/NoShare.h @@ -25,9 +25,6 @@ public: typedef T clear; typedef clear open_type; - // needs to be defined even if protocol doesn't use MACs - typedef clear mac_key_type; - // disable binary computation typedef GC::NoShare bit_type; diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 44853b79..afb45662 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -27,8 +27,6 @@ class RepShare : public FixedVec, public ShareInterface public: typedef T clear; typedef T open_type; - typedef T mac_type; - typedef T mac_key_type; const static bool needs_ot = false; const static bool dishonest_majority = false; @@ -138,9 +136,10 @@ public: return T::type_char(); } - static Rep3Share constant(T value, int my_num, const T& alphai = {}) + static Rep3Share constant(T value, int my_num, + typename super::mac_key_type = {}) { - return Rep3Share(value, my_num, alphai); + return Rep3Share(value, my_num); } Rep3Share() diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 2357d0f5..ba5b85c8 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -53,6 +53,8 @@ protected: int trunc_pr_counter; int rounds, trunc_rounds; + int dot_counter; + int bit_counter; public: typedef T share_type; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 1a8a66b9..2d9eba57 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -20,7 +20,8 @@ template ProtocolBase::ProtocolBase() : - trunc_pr_counter(0), rounds(0), trunc_rounds(0), counter(0) + trunc_pr_counter(0), rounds(0), trunc_rounds(0), dot_counter(0), + bit_counter(0), counter(0) { } @@ -67,7 +68,11 @@ ProtocolBase::~ProtocolBase() { #ifdef VERBOSE_COUNT if (counter or rounds) - cerr << "Number of " << T::type_string() << " multiplications: " << counter << " in " << rounds << " rounds" << endl; + cerr << "Number of " << T::type_string() << " multiplications: " + << counter << " (" << bit_counter << " bits) in " << rounds + << " rounds" << endl; + if (counter or rounds) + cerr << "Number of " << T::type_string() << " dot products: " << dot_counter << endl; if (trunc_pr_counter or trunc_rounds) cerr << "Number of probabilistic truncations: " << trunc_pr_counter << " in " << trunc_rounds << " rounds" << endl; #endif @@ -126,6 +131,7 @@ template T ProtocolBase::finalize_dotprod(int length) { counter += length; + dot_counter++; T res; for (int i = 0; i < length; i++) res += finalize_mul(); @@ -199,6 +205,7 @@ template inline T Replicated::finalize_mul(int n) { this->counter++; + this->bit_counter += n; T result; result[0] = add_shares.next(); result[1].unpack(os[1], n); @@ -230,6 +237,7 @@ template inline T Replicated::finalize_dotprod(int length) { (void) length; + this->dot_counter++; return finalize_mul(); } @@ -316,6 +324,7 @@ void Replicated::trunc_pr(const vector& regs, int size, U& proc, for (auto info : infos) for (int i = 0; i < size; i++) { + this->trunc_pr_counter++; auto c_prime = input.finalize(comp_player); auto r_prime = input.finalize(gen_player); S[info.dest_base + i] = c_prime - r_prime; diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index bb6f36a2..17916a2e 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -22,11 +22,11 @@ class ReplicatedMC : public MAC_Check_Base public: // emulate MAC_Check - ReplicatedMC(const typename T::value_type& _ = {}, int __ = 0, int ___ = 0) + ReplicatedMC(const typename T::mac_key_type& _ = {}, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; } // emulate Direct_MAC_Check - ReplicatedMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) + ReplicatedMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } void POpen(vector& values,const vector& S,const Player& P); @@ -34,7 +34,7 @@ public: void POpen_End(vector& values,const vector& S,const Player& P); virtual void exchange(const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index bbcf73e5..e72c0d83 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -65,7 +65,7 @@ void ReplicatedMC::finalize(vector& values, } template -typename T::open_type ReplicatedMC::finalize_open() +typename T::open_type ReplicatedMC::finalize_raw() { auto a = this->secrets.next().sum(); return a + o.get(); diff --git a/Protocols/SPDZ.h b/Protocols/SPDZ.h index fb2888c0..bd804ea0 100644 --- a/Protocols/SPDZ.h +++ b/Protocols/SPDZ.h @@ -26,7 +26,8 @@ public: { } - static void assign(typename T::clear& share, const typename T::clear& clear, int my_num) + static void assign(typename T::open_type& share, + const typename T::open_type& clear, int my_num) { if (my_num == 0) share = clear; diff --git a/Protocols/SPDZ2k.h b/Protocols/SPDZ2k.h new file mode 100644 index 00000000..da128fee --- /dev/null +++ b/Protocols/SPDZ2k.h @@ -0,0 +1,28 @@ +/* + * SPDZ2k.h + * + */ + +#ifndef PROTOCOLS_SPDZ2K_H_ +#define PROTOCOLS_SPDZ2K_H_ + +#include "SPDZ.h" + +template +class SPDZ2k : public SPDZ +{ +public: + SPDZ2k(Player& P) : + SPDZ(P) + { + } + + void exchange() + { + for (size_t i = 0; i < this->shares.size(); i++) + this->MC->set_random_element({}); + SPDZ::exchange(); + } +}; + +#endif /* PROTOCOLS_SPDZ2K_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index ee5e8320..cc41d023 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -21,8 +21,6 @@ class Semi2kShare : public SemiShare> typedef SignedZ2 T; public: - typedef Z2<64> mac_key_type; - typedef SemiMC MAC_Check; typedef DirectSemiMC Direct_MC; typedef SemiInput Input; diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 4fc265b7..c40d0c17 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -18,7 +18,7 @@ class SemiInput : public InputBase { vector send_prngs; vector recv_prngs; - Player& P; + PlayerBase& P; vector> shares; public: @@ -27,7 +27,7 @@ public: { } - SemiInput(SubProcessor* proc, Player& P); + SemiInput(SubProcessor* proc, PlayerBase& P); SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : SemiInput(0, P) diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 3ed1feef..f0fefe13 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -11,7 +11,7 @@ #include "ShamirInput.hpp" template -SemiInput::SemiInput(SubProcessor* proc, Player& P) : +SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : InputBase(proc), P(P) { shares.resize(P.num_players()); diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index c2dd9085..b306d5c3 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -9,6 +9,7 @@ #include "Protocols/Beaver.h" #include "Protocols/Semi.h" #include "Processor/DummyProtocol.h" +#include "GC/NoShare.h" #include "ShareInterface.h" #include @@ -51,8 +52,6 @@ class SemiShare : public T, public ShareInterface typedef T super; public: - typedef T mac_key_type; - typedef T mac_type; typedef T open_type; typedef T clear; @@ -87,10 +86,10 @@ public: return nplayers - 1; } - static SemiShare constant(const clear& other, int my_num, - const T& alphai = {}, int = -1) + static SemiShare constant(const open_type& other, int my_num, + mac_key_type = {}, int = -1) { - return SemiShare(other, my_num, alphai); + return SemiShare(other, my_num); } SemiShare() @@ -100,7 +99,7 @@ public: SemiShare(const U& other) : T(other) { } - SemiShare(const clear& other, int my_num, const T& alphai = {}) + SemiShare(const open_type& other, int my_num, const T& alphai = {}) { (void) alphai; Protocol::assign(*this, other, my_num); diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 6bda92df..c6a88f0a 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -69,7 +69,7 @@ public: virtual void init_open(const Player& P, int n = 0); virtual void prepare_open(const T& secret); virtual void exchange(const Player& P); - virtual typename T::open_type finalize_open(); + virtual typename T::open_type finalize_raw(); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index e3e7cd3a..7238aa5e 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -112,11 +112,11 @@ void ShamirMC::finalize(vector& values, { values.clear(); for (size_t i = 0; i < S.size(); i++) - values.push_back(finalize_open()); + values.push_back(finalize_raw()); } template -typename T::open_type ShamirMC::finalize_open() +typename T::open_type ShamirMC::finalize_raw() { assert(reconstruction.size()); typename T::open_type res; diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index e7daabfc..aea0bb97 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -29,10 +29,7 @@ class ShamirShare : public T, public ShareInterface public: typedef T clear; typedef T open_type; - typedef T mac_key_type; typedef void sacri_type; - typedef GC::NoShare mac_type; - typedef GC::NoShare mac_share_type; typedef Shamir Protocol; typedef IndirectShamirMC MAC_Check; @@ -76,9 +73,9 @@ public: return Protocol::get_rec_factor(i, n); } - static ShamirShare constant(T value, int my_num, const T& alphai = {}) + static ShamirShare constant(T value, int, const mac_key_type& = {}) { - return ShamirShare(value, my_num, alphai); + return ShamirShare(value); } ShamirShare() @@ -89,42 +86,12 @@ public: { T::operator=(other); } - template - ShamirShare(const U& other, int my_num, T alphai = {}) : ShamirShare(other) - { - (void) my_num, (void) alphai; - } - // Share compatibility - void assign(clear other, int my_num, const T& alphai) - { - (void)alphai, (void)my_num; - *this = other; - } void assign(const char* buffer) { T::assign(buffer); } - void add(const ShamirShare& S, const clear aa, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = S + aa; - } - void sub(const ShamirShare& S, const clear& aa, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = S - aa; - } - void sub(const clear& aa, const ShamirShare& S, int my_num, - const T& alphai) - { - (void) my_num, (void) alphai; - *this = aa - S; - } - ShamirShare operator<<(int i) { return *this * (T(1) << i); diff --git a/Protocols/Share.h b/Protocols/Share.h index 92be4f14..e2a9f0bb 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -73,7 +73,7 @@ class Share_ : public ShareInterface static void specification(octetStream& os) { T::specification(os); } - static Share_ constant(const clear& aa, int my_num, const typename V::Scalar& alphai) + static Share_ constant(const open_type& aa, int my_num, const typename V::Scalar& alphai) { return Share_(aa, my_num, alphai); } template @@ -85,12 +85,12 @@ class Share_ : public ShareInterface { a.assign_zero(); mac.assign_zero(); } - void assign(const clear& aa, int my_num, const typename V::Scalar& alphai); + void assign(const open_type& aa, int my_num, const typename V::Scalar& alphai); Share_() { assign_zero(); } template Share_(const Share_& S) { assign(S); } - Share_(const clear& aa, int my_num, const typename V::Scalar& alphai) + Share_(const open_type& aa, int my_num, const typename V::Scalar& alphai) { assign(aa, my_num, alphai); } Share_(const T& share, const V& mac) : a(share), mac(mac) {} @@ -128,6 +128,8 @@ class Share_ : public ShareInterface void force_to_bit() { a.force_to_bit(); } + void randomize(PRNG& G); + // Input and output from a stream // - Can do in human or machine only format (later should be faster) void output(ostream& s,bool human) const @@ -235,7 +237,7 @@ inline void Share_::mul(const Share_& S,const clear& aa) } template -inline void Share_::assign(const clear& aa, int my_num, +inline void Share_::assign(const open_type& aa, int my_num, const typename V::Scalar& alphai) { a = T::constant(aa, my_num); diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index 62f90e91..c6f675f7 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -23,6 +23,13 @@ void Share_::read_or_generate_mac_key(string directory, const Player& P, } } +template +void Share_::randomize(PRNG& G) +{ + a.randomize(G); + mac.randomize(G); +} + template inline void Share_::pack(octetStream& os, bool full) const { diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 444214e4..e5af8ddd 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -20,6 +20,7 @@ class ValueInterface; namespace GC { class NoShare; +class NoValue; } class ShareInterface @@ -28,6 +29,10 @@ public: typedef GC::NoShare part_type; typedef GC::NoShare bit_type; + typedef GC::NoValue mac_key_type; + typedef GC::NoShare mac_type; + typedef GC::NoShare mac_share_type; + static const bool needs_ot = false; static const bool expensive = false; static const bool expensive_triples = false; diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index b7fdf50b..7f84213e 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -128,7 +128,7 @@ public: typedef ValueMatrix clear; typedef clear open_type; - typedef typename T::clear mac_key_type; + typedef typename T::mac_key_type mac_key_type; static string type_string() { diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 81e85931..4d03dd67 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -14,7 +14,7 @@ inline ShuffleSacrifice::ShuffleSacrifice() : - B(OnlineOptions::singleton.bucket_size), C(this->B) + ShuffleSacrifice(OnlineOptions::singleton.bucket_size, 3) { } @@ -22,6 +22,9 @@ inline ShuffleSacrifice::ShuffleSacrifice(int B, int C) : B(B), C(C) { + if (OnlineOptions::singleton.security_parameter > 40) + throw runtime_error("shuffle sacrifice not implemented for more than " + "40-bit security"); } template diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index 48deeadc..1dfd3ecb 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -21,6 +21,7 @@ void SohoPrep::basic_setup(Player& P) assert(not setup); setup = new PartSetup; MachineBase machine; + setup->params.set_sec(OnlineOptions::singleton.security_parameter); setup->secure_init(P, machine, T::clear::length(), 0); read_or_generate_secrets(*setup, P, machine, 1, true_type()); T::clear::template init(); diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 03a91ff2..d95a713d 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -8,7 +8,8 @@ #include "MascotPrep.h" #include "RingOnlyPrep.h" -#include "Spdz2kShare.h" + +template class Spdz2kShare; template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index 401070f8..762cd34d 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -18,6 +18,7 @@ template class Spdz2kMultiplier; template class Spdz2kTripleGenerator; +template class SPDZ2k; namespace GC { @@ -48,7 +49,7 @@ public: typedef MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef SPDZ2k Protocol; typedef Spdz2kPrep LivePrep; #ifndef NO_MIXED_CIRCUITS diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 2ea08ba4..b7a8c741 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -5,6 +5,8 @@ #include "SpdzWise.h" +#include "mac_key.hpp" + template SpdzWise::SpdzWise(Player& P) : internal(P), internal2(P), P(P) @@ -142,6 +144,7 @@ template void SpdzWise::zero_check(check_type t) { assert(T::clear::invertible); + check_field_size(); auto r = internal.get_random(); internal.init_mul(); internal.prepare_mul(t, r); diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9991dafb..9ad76198 100644 --- a/Protocols/SpdzWiseMC.h +++ b/Protocols/SpdzWiseMC.h @@ -44,9 +44,9 @@ public: { inner_MC.exchange(P); } - typename T::open_type finalize_open() + typename T::open_type finalize_raw() { - return inner_MC.finalize_open(); + return inner_MC.finalize_raw(); } void Check(const Player& P) { diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp index 6401c083..acd6fa29 100644 --- a/Protocols/SpdzWiseShare.hpp +++ b/Protocols/SpdzWiseShare.hpp @@ -13,14 +13,15 @@ template void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& mac_key) { + bool fresh = false; + try { read_mac_key(directory, P.N, mac_key); } catch (mac_key_error&) { - SeededPRNG G; - mac_key.randomize(G); + fresh = true; } try @@ -33,11 +34,12 @@ void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& } catch (mac_fail&) { -#ifdef VERBOSE - cerr << "Generating fresh MAC key for " << type_string() << endl; -#endif - mac_key = typename T::Honest::Protocol(P).get_random(); + fresh = true; + cerr << "Invalid " << type_string() << " MAC key, generating fresh one" << endl; } + + if (fresh) + mac_key = typename T::Honest::Protocol(P).get_random(); } template diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h index de7406bb..ad12837a 100644 --- a/Protocols/TemiPrep.h +++ b/Protocols/TemiPrep.h @@ -64,6 +64,8 @@ public: { } + ~TemiPrep(); + void buffer_triples(); vector*>& get_multipliers(); diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp index 1088a99c..1f2a415c 100644 --- a/Protocols/TemiPrep.hpp +++ b/Protocols/TemiPrep.hpp @@ -24,6 +24,7 @@ void TemiPrep::basic_setup(Player& P) assert(not setup); setup = new TemiSetup; MachineBase machine; + setup->params.set_sec(OnlineOptions::singleton.security_parameter); setup->secure_init(P, T::clear::length()); read_or_generate_secrets(*setup, P, machine, 1, true_type()); T::clear::template init(); @@ -104,6 +105,13 @@ TemiMultiplier::TemiMultiplier(Player& P) : P(P) { } +template +TemiPrep::~TemiPrep() +{ + for (auto& x : multipliers) + delete x; +} + template vector& TemiMultiplier::get_multiplicands( vector >& ciphertexts, const FHE_PK& pk) diff --git a/Protocols/config.h b/Protocols/config.h new file mode 100644 index 00000000..f88c3aa8 --- /dev/null +++ b/Protocols/config.h @@ -0,0 +1,13 @@ +/* + * config.h + * + */ + +#ifndef PROTOCOLS_CONFIG_H_ +#define PROTOCOLS_CONFIG_H_ + +#ifndef DEFAULT_SECURITY +#define DEFAULT_SECURITY 40 +#endif + +#endif /* PROTOCOLS_CONFIG_H_ */ diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index d15581eb..dc941427 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -11,7 +11,7 @@ using namespace std; template void check_share(vector& Sa, typename T::clear& value, - typename T::value_type& mac, int N, const typename T::value_type& key); + typename T::mac_type& mac, int N, const typename T::mac_key_type& key); template class Share; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index aeb51611..bae415c4 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -20,6 +20,7 @@ template class FixedVec; template class Share_; template class SpdzWiseShare; template class MaliciousRep3Share; +template class DealerShare; namespace GC { @@ -115,7 +116,6 @@ void make_share(GC::TinierSecret* Sa, const U& a, int N, const V& key, PRNG& template void make_share(SemiShare* Sa,const T& a,int N,const U&,PRNG& G) { - insecure("share generation", false); T x, S = a; for (int i=0; i* Sa,const T& a,int N,const U&,PRNG& G) Sa[N-1]=S; } +template +void make_share(DealerShare* Sa, const T& a, int N, const U&, PRNG& G) +{ + make_share((SemiShare*) Sa, a, N - 1, U(), G); + Sa[N - 1] = {}; +} + template void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G); @@ -234,7 +241,7 @@ void check_share(vector >& Sa, template void check_share(vector& Sa, typename T::clear& value, - typename T::value_type& mac, int N, const typename T::value_type& key) + typename T::mac_type& mac, int N, const typename T::mac_key_type& key) { assert(N == 3); value = 0; @@ -340,23 +347,27 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, { if (directory == "") directory = get_prep_sub_dir(P.num_players()); - typename T::mac_key_type res; + typename T::mac_key_type res, tmp; try { - read_mac_key(directory, P.my_num(), P.num_players(), res); + read_mac_key(directory, P.my_num(), P.num_players(), tmp); } catch (mac_key_error&) { - T::read_or_generate_mac_key(directory, P, res); - write_mac_key(directory, P.my_num(), P.num_players(), res); } + T::read_or_generate_mac_key(directory, P, res); + + // only write if changed + if (tmp != res) + write_mac_key(directory, P.my_num(), P.num_players(), res); + return res; } template -void read_global_mac_key(const string& directory, int nparties, U& key) +void read_global_mac_key(const string& directory, int nparties, U& key, false_type) { U pp; key.assign_zero(); @@ -372,6 +383,17 @@ void read_global_mac_key(const string& directory, int nparties, U& key) cout << "Final Keys : " << key << endl; } +template +void read_global_mac_key(const string&, int, U&, true_type) +{ +} + +template +void read_global_mac_key(const string& directory, int nparties, U& key) +{ + read_global_mac_key(directory, nparties, key, is_same()); +} + template T reconstruct(vector& shares) { @@ -548,4 +570,25 @@ void make_inverse(const typename T::mac_type& key, int N, int ntrip, bool zero, check_files(files.outf, N); } +template +void plain_edabits(vector& as, + vector& bs, int length, PRNG& G, + bool zero = false) +{ + int max_size = edabitvec::MAX_SIZE; + as.resize(max_size); + bs.clear(); + bs.resize(length); + bigint value; + for (int j = 0; j < max_size; j++) + { + if (not zero) + G.get_bigint(value, length, true); + as[j] = value; + for (int k = 0; k < length; k++) + bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + } + +} + #endif diff --git a/Protocols/mac_key.hpp b/Protocols/mac_key.hpp index 843e6d47..3f1000b8 100644 --- a/Protocols/mac_key.hpp +++ b/Protocols/mac_key.hpp @@ -22,4 +22,12 @@ typename T::mac_key_type read_or_generate_mac_key(const Player& P, return res; } +template +void check_field_size() +{ + if (T::length() < OnlineOptions::singleton.security_parameter) + throw runtime_error("Field too small for chosen security. " + "Increase size with -lgp or decrease security with -S"); +} + #endif /* PROTOCOLS_MAC_KEY_HPP_ */ diff --git a/README.md b/README.md index 96aa0fd5..d44190e1 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ parties and malicious security. On Linux, this requires a working toolchain and [all requirements](#requirements). On Ubuntu, the following might suffice: ``` -apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm +sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm ``` On MacOS, this requires [brew](https://brew.sh) to be installed, which will be used for all dependencies. @@ -103,6 +103,7 @@ The following table lists all protocols that are fully supported. | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | Modulo prime and modulo 2^k are the two settings that allow integer-like computation. For k = 64, the latter corresponds to the @@ -174,6 +175,9 @@ there are a few things to consider: adding `program.use_trunc_pr = True` at the beginning of your high-level program. +- Larger number of parties: ATLAS scales better than the plain Shamir + protocol, and Temi scale better than Hemi or Semi. + - Minor variants: Some command-line options change aspects of the protocols such as: @@ -771,7 +775,23 @@ the number of parties with `-N` and the maximum number of corrupted parties with `-T`. The latter can be at most half the number of parties. -### BMR +## Dealer model + +This security model defines a special party that generates correlated +randomness such as multiplication triples, which is then used by all +other parties. MP-SPDZ implements the canonical protocol where the +other parties run the online phase of the semi-honest protocol in +Semi(2k/Bin) and the dealer provides all preprocessing. The security +assumption is that dealer doesn't collude with any other party, but +all but one of the other parties are allowed to collude. In our +implementation, the dealer is the party with the highest number, so +with three parties overall, Party 0 and 1 run the online phase. + +| Program | Sharing | Domain | Malicious | \# parties | Script | +| --- | --- | --- | --- | --- | --- | +| `dealer-ring-party.x` | Additive | Mod 2^k | N | 3+ | `dealer-ring.sh` | + +## BMR BMR (Bellare-Micali-Rogaway) is a method of generating a garbled circuit using another secure computation protocol. We have implemented BMR diff --git a/Scripts/dealer-ring.sh b/Scripts/dealer-ring.sh new file mode 100755 index 00000000..b0d6692a --- /dev/null +++ b/Scripts/dealer-ring.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player dealer-ring-party.x $* || exit 1 diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index 15959ee6..eaec677f 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -12,7 +12,7 @@ if len(sys.argv) <= 1: print('Usage: %s ' % sys.argv[0]) res = collections.defaultdict(lambda: 0) -m = 0 +regs = collections.defaultdict(lambda: 0) for tapename in Program.read_tapes(sys.argv[1]): for inst in Tape.read_instructions(tapename): @@ -22,8 +22,9 @@ for tapename in Program.read_tapes(sys.argv[1]): res[t.arg_format[0]]) for arg in inst.args: if isinstance(arg, RegisterArgFormat): - m = max(m, arg.i + inst.size) + regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size) -print (res) -print (m) +reverse_formats = dict((v, k) for k, v in ArgFormats.items()) +print ('Memory:', dict(res)) +print ('Registers:', dict((reverse_formats[t], n) for t, n in regs.items())) diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index e8c02f6c..3771383b 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -52,7 +52,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py -R 64 $compile_opts tutorial for i in ring rep4-ring semi2k brain mal-rep-ring ps-rep-ring sy-rep-ring \ - spdz2k; do + spdz2k dealer-ring; do test_vm $i $run_opts done @@ -65,7 +65,7 @@ for dabit in ${dabit:-0 1 2}; do done for i in cowgear chaigear; do - test_vm $i $run_opts -l 3 -c 2 + test_vm $i $run_opts -S 3 -c 2 done done @@ -83,7 +83,7 @@ fi ./compile.py tutorial for i in cowgear chaigear; do - test_vm $i $run_opts -l 3 -c 2 -J + test_vm $i $run_opts -S 3 -c 2 -J done if test $skip_binary; then diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index ed6c0144..ce906a3a 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -27,7 +27,11 @@ if test "$flags"; then cpu=amd64 fi - cp -av bin/`uname`-$cpu/* . || { echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2; exit 1; } + if ! cp -av bin/`uname`-$cpu/* .; then + echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2 + echo Make sure NOT to download a source code only file 1>&2 + exit 1 + fi fi mkdir Player-Data 2> /dev/null diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 9dd15804..f3e67c82 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -85,6 +85,11 @@ void BufferBase::try_rewind() void BufferBase::prune() { + // only prune in secure mode +#ifdef INSECURE + return; +#endif + if (is_pipe()) return; diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index f6f4ba2e..c7f8c371 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -83,3 +83,8 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil "adding -DINSECURE to the compiler options.") { } + +gf2n_not_supported::gf2n_not_supported(int n) : + runtime_error("GF(2^" + to_string(n) + ") not supported") +{ +} diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index fff8b2de..bb347c6a 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -278,4 +278,10 @@ public: insufficient_memory(size_t size, const string& type); }; +class gf2n_not_supported : public runtime_error +{ +public: + gf2n_not_supported(int n); +}; + #endif diff --git a/Tools/avx_memcpy.h b/Tools/avx_memcpy.h index 231dc99c..a00a215e 100644 --- a/Tools/avx_memcpy.h +++ b/Tools/avx_memcpy.h @@ -20,6 +20,7 @@ template inline void avx_memcpy(void* dest, const void* source) { size_t length = L; +#ifdef __SSE2__ __m256i* d = (__m256i*)dest, *s = (__m256i*)source; #ifdef __AVX__ while (length >= 32) @@ -35,6 +36,10 @@ inline void avx_memcpy(void* dest, const void* source) _mm_storeu_si128(d2++, _mm_loadu_si128(s2++)); length -= 16; } +#else + void* d2 = dest; + const void* s2 = source; +#endif switch (length) { case 0: @@ -53,14 +58,16 @@ inline void avx_memcpy(void* dest, const void* source) inline void avx_memzero(void* dest, size_t length) { - __m256i* d = (__m256i*)dest; #ifdef __AVX__ + __m256i* d = (__m256i*)dest; __m256i s = _mm256_setzero_si256(); while (length >= 32) { _mm256_storeu_si256(d++, s); length -= 32; } +#else + void* d = dest; #endif switch (length) { diff --git a/Tools/benchmarking.cpp b/Tools/benchmarking.cpp index e956f15e..88eee709 100644 --- a/Tools/benchmarking.cpp +++ b/Tools/benchmarking.cpp @@ -5,6 +5,25 @@ #include "benchmarking.h" +void insecure(string message, bool warning) +{ +#ifdef INSECURE + if (warning) + cerr << "WARNING: insecure " << message << endl; +#else + (void)warning; + string msg = "You are trying to use insecure benchmarking functionality for " + + message + ".\nYou can activate this at compile time " + "by adding -DINSECURE to the compiler options.\n" + "Make sure to run 'make clean' as well before compiling."; + cerr << msg << endl; +#ifdef INSECURE_EXCEPTION + throw exception(); +#endif + exit(1); +#endif +} + void insecure_fake() { #if defined(INSECURE) or defined(INSECURE_FAKE) diff --git a/Tools/benchmarking.h b/Tools/benchmarking.h index 13fa9c36..e54990ca 100644 --- a/Tools/benchmarking.h +++ b/Tools/benchmarking.h @@ -12,20 +12,7 @@ using namespace std; // call before insecure benchmarking functionality -inline void insecure(string message, bool warning = true) -{ -#ifdef INSECURE - if (warning) - cerr << "WARNING: insecure " << message << endl; -#else - (void)warning; - string msg = "You are trying to use insecure benchmarking functionality for " - + message + ".\nYou can activate this at compile time " - "by adding -DINSECURE to the compiler options.\n" - "Make sure to run make clean as well."; - throw runtime_error(msg); -#endif -} +void insecure(string message, bool warning = true); void insecure_fake(); diff --git a/Tools/intrinsics.h b/Tools/intrinsics.h index 45664a72..e7cb87dc 100644 --- a/Tools/intrinsics.h +++ b/Tools/intrinsics.h @@ -10,6 +10,7 @@ #include #include #else +#ifdef __aarch64__ #define SIMDE_X86_AVX_ENABLE_NATIVE_ALIASES #define SIMDE_X86_AVX2_ENABLE_NATIVE_ALIASES #define SIMDE_X86_SSE2_ENABLE_NATIVE_ALIASES @@ -18,5 +19,6 @@ #include "simde/simde/x86/clmul.h" #include "aes-arm.h" #endif +#endif #endif /* TOOLS_INTRINSICS_H_ */ diff --git a/Tools/parse.h b/Tools/parse.h index af5e6de4..8ff0fee9 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -23,6 +23,14 @@ inline int get_int(istream& s) return be32toh(n); } +// Read an 8-byte integer +inline int64_t get_long(istream& s) +{ + int64_t n; + s.read((char*) &n, 8); + return be64toh(n); +} + // Read several integers inline void get_ints(int* res, istream& s, int count) { diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index 3a364414..203328f2 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -36,7 +36,8 @@ string PREP_DATA_PREFIX; template void check_mult_triples(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a,b,c,mac; + typename T::clear a,b,c; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -99,7 +100,8 @@ void check_tuple(const T& a, const T& b, int n, Dtype type) template void check_tuples(const typename T::mac_key_type& key,int N,vector*>& dataF, Dtype type) { - typename T::clear a,b,c,mac,res; + typename T::clear a,b,c,res; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -127,7 +129,8 @@ void check_tuples(const typename T::mac_key_type& key,int N,vector void check_bits(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a,b,c,mac,res; + typename T::clear a,b,c,res; + typename T::mac_type mac; vector Sa(N),Sb(N),Sc(N); int n = 0; @@ -157,7 +160,8 @@ void check_bits(const typename T::mac_key_type& key,int N,vector void check_inputs(const typename T::mac_key_type& key,int N,vector*>& dataF) { - typename T::clear a, mac, x; + typename T::clear a, x; + typename T::mac_type mac; vector Sa(N); for (int player = 0; player < N; player++) diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index f1158cfa..823c318b 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -154,16 +154,9 @@ void FakeParams::make_edabits(const typename T::mac_type& key, int N, int ntrip, int max_size = edabitvec::MAX_SIZE; for (int i = 0; i < ntrip / max_size; i++) { - vector as(max_size); - vector bs(length); - for (int j = 0; j < max_size; j++) - { - if (not zero) - G.get_bigint(value, length, true); - as[j] = value; - for (int k = 0; k < length; k++) - bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; - } + vector as; + vector bs; + plain_edabits(as, bs, length, G, zero); for (auto& a : as) files.template output_shares(a); for (auto& b : bs) @@ -737,9 +730,12 @@ int FakeParams::generate() if (nplayers == 3) { make_bits>({}, nplayers, nbitsp, zero); - make_basic>({}, nplayers, default_num, zero); - make_basic>({}, nplayers, default_num, zero); - make_with_mac_key>(nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, + zero); + make_basic>({}, nplayers, + default_num, zero); + make_with_mac_key>(nplayers, + default_num, zero); make_mult_triples({}, nplayers, ntrip2, zero, prep_data_prefix); make_bits({}, nplayers, nbits2, zero); @@ -748,17 +744,21 @@ int FakeParams::generate() make_basic>({}, nplayers, default_num, zero); make_basic>>({}, nplayers, default_num, zero); + make_basic>>({}, nplayers, default_num, zero); + make_minimal({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); make_bits({}, nplayers, default_num, zero); gf2n_short::reset(); - gf2n_short::init_field(40); + gf2n_short::init_field(); - Z2<41> keyt; - generate_mac_keys>(keyt, nplayers, prep_data_prefix); + Z2 keyt; + generate_mac_keys>(keyt, nplayers, + prep_data_prefix); - make_minimal>(keyt, nplayers, default_num / 64, zero); + make_minimal>(keyt, nplayers, + default_num / 64, zero); gf2n_short keytt; generate_mac_keys>(keytt, nplayers, prep_data_prefix); diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index 962b2775..d00acd2a 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -20,6 +20,7 @@ #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/SemiSecret.hpp" #include "Protocols/Atlas.hpp" #include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/Share.hpp" diff --git a/Utils/l2h-example.cpp b/Utils/l2h-example.cpp new file mode 100644 index 00000000..475bcb8a --- /dev/null +++ b/Utils/l2h-example.cpp @@ -0,0 +1,54 @@ +/* + * l2h-example.cpp + * + */ + +#include "Protocols/ProtocolSet.h" + +#include "Math/gfp.hpp" +#include "Machines/SPDZ.hpp" + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 2) + { + cerr << "Usage: " << argv[0] << " " << endl; + exit(1); + } + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + + // template parameters are share types for integer and GF(2^n) computation + Machine, Share> machine(N); + + // protocols to be used directly + ProtocolSet> set(machine.get_player(), machine.get_sint_mac_key()); + + // data to be used in steps + set.input.reset_all(machine.get_player()); + set.input.add_from_all(2 + my_number); + set.input.exchange(); + machine.Mp.MS.resize(n_parties); + for (int i = 0; i < n_parties; i++) + machine.Mp.MS[i] = set.input.finalize(i); + + machine.run_step("l2h_multiplication"); + machine.run_step("l2h_comparison"); + + // check results + // multiplication + assert(set.output.open(machine.Mp.MS[2], machine.get_player()) == 6); + // comparison + assert(set.output.open(machine.Mp.MS[3], machine.get_player()) == 1); + + set.check(); + + // print usage + auto res = machine.stop_threads(); + res.first.print_cost(); +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index a36949d6..6eda84e0 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -6,6 +6,7 @@ #include "Protocols/ProtocolSet.h" #include "Machines/SPDZ.hpp" +#include "Machines/SPDZ2k.hpp" #include "Machines/Semi2k.hpp" #include "Machines/Rep.hpp" #include "Machines/Rep4.hpp" diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 83571c21..e5346ade 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -9,6 +9,7 @@ #include "Math/gfp.hpp" #include "Machines/SPDZ.hpp" +#include "Machines/SPDZ2k.hpp" #include "Machines/MalRep.hpp" #include "Machines/ShamirMachine.hpp" #include "Machines/Semi2k.hpp" @@ -30,7 +31,9 @@ int main(int argc, char** argv) // need player number and number of players if (argc < 3) { - cerr << "Usage: " << argv[0] << " [protocol [threshold]]" << endl; + cerr << "Usage: " << argv[0] + << " [protocol [threshold]]" + << endl; exit(1); } diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 38cdc922..62896b3a 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -250,7 +250,7 @@ void YaoEvalWire::convcbit2s(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto& dest = processor.S[instruction.get_r(0) + i]; - dest.resize_regs(min(unsigned(unit), instruction.get_n() - i * unit)); + dest.resize_regs(min(size_t(unit), instruction.get_n() - i * unit)); for (auto& reg : dest.get_regs()) reg.set(0); } diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 05a8646d..e5062808 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -238,7 +238,7 @@ void YaoGarbleWire::convcbit2s(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto& dest = processor.S[instruction.get_r(0) + i]; - int n = min(unsigned(unit), instruction.get_n() - i * unit); + int n = min(size_t(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) dest.get_reg(j).public_input( diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index a947b1f5..b1e0e073 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -33,7 +33,7 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) "--threshold" // Flag token. ); auto& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, false_type()}; + online_opts = {opt, argc, argv, false}; NetworkOptionsWithNumber network_opts(opt, argc, argv, 2, false); online_opts.finalize(opt, argc, argv); diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp index aa04fe35..984db38f 100644 --- a/Yao/YaoWire.hpp +++ b/Yao/YaoWire.hpp @@ -55,7 +55,7 @@ void YaoWire::andm(GC::Processor& processor, for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) { auto &dest = processor.S[instruction.get_r(0) + i]; - int n = min(unsigned(unit), instruction.get_n() - i * unit); + int n = min(size_t(unit), instruction.get_n() - i * unit); dest.resize_regs(n); for (int j = 0; j < n; j++) if (processor.C[instruction.get_r(2) + i].get_bit(j)) diff --git a/doc/Compiler.rst b/doc/Compiler.rst index df3e13f5..db5c1e9c 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -28,12 +28,16 @@ Compiler.GC.types module .. automodule:: Compiler.GC.types :members: :no-undoc-members: - :no-inherited-members: - :show-inheritance: + :inherited-members: :exclude-members: PreOp, cbit, dynamic_array, conv_cint_vec, bitdec, bit_type, bitcom, clear_type, conv_regint, default_type, mov, dyn_sbits, int_type, mul, vec, load_mem, - DynamicArray, get_raw_input_from + DynamicArray, get_raw_input_from, bits, + input_tensor_from, input_tensor_from_client, + input_tensor_via, dot_product, Matrix, Tensor, + from_sint, read_from_file, receive_from_client, + reveal_to_clients, write_shares_to_socket, + write_to_file Compiler.library module ----------------------- diff --git a/doc/non-linear.rst b/doc/non-linear.rst index e5df4c20..969e6d6c 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -56,6 +56,10 @@ Power-of-two modulus mask-and-reveal approach above to the setting of computation modulo a power of two. +See also `this slide deck +`_ for an +introduction to non-linear computation in arithmetic MPC. + Mixed-Circuit Computation ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -70,6 +74,12 @@ more general methods such as `daBits `_ and `edaBits `_. +See also `this slide deck +`_ for an introduction to +mixed-circuit computation. + + +.. _protocol-pairs: Protocol Pairs ============== From 4e811ec1597edfbe2236d21dfdda65f4ce247413 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Wed, 20 Apr 2022 15:51:24 -0500 Subject: [PATCH 045/265] Fix tiny typo in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d44190e1..a3f6741f 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ there are a few things to consider: - Computation domain: Arithmetic protocols (modulo prime or power of two) are preferable for many applications because they offer integer addition and multiplication at low cost. However, binary circuits - might a better option if there is very little integer + might be a better option if there is very little integer computation. [See below](#finding-the-most-efficient-variant) to find the most efficient mixed-circuit variant. Furthermore, local computation modulo a power of two is cheaper, but MP-SPDZ does not From a5917de3cf5143baaf89e2909384caaf50ddd921 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 21 Apr 2022 12:36:54 +0200 Subject: [PATCH 046/265] Protocol setup with exact modulus. --- Protocols/ProtocolSetup.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h index b6d91b2b..5953b666 100644 --- a/Protocols/ProtocolSetup.h +++ b/Protocols/ProtocolSetup.h @@ -35,6 +35,22 @@ public: T::read_or_generate_mac_key(directory, P, mac_key); } + /** + * @param prime modulus for computation + * @param P communication instance (used for MAC generation if needed) + * @param directory location to read MAC if needed + */ + ProtocolSetup(bigint prime, Player& P, string directory = "") + { + static_assert(T::clear::prime_field, "must use computation modulo a prime"); + + T::clear::init_field(prime); + T::clear::next::init_field(prime, false); + + // must initialize MAC key for security of some protocols + T::read_or_generate_mac_key(directory, P, mac_key); + } + ~ProtocolSetup() { T::LivePrep::teardown(); From 2760659ad4cd740e659447a5854b933daaa139e9 Mon Sep 17 00:00:00 2001 From: Sylvain Bellemare Date: Thu, 21 Apr 2022 22:54:18 -0500 Subject: [PATCH 047/265] Fix comment/example in Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e01dd5c5..5d8c2888 100644 --- a/Dockerfile +++ b/Dockerfile @@ -132,7 +132,7 @@ RUN make clean && make ${machine} && cp ${machine} /usr/local/bin/ # --build-arg machine=replicated-ring-party.x \ # # --build-arg prep_dir=/opt/prep \ # # --build-arg ssl_dir=/opt/ssl \ # -# --build-arg nparties=3 \ # +# --build-arg cryptoplayers=3 \ # # --build-arg compile_options="--ring=64" . # # # # Test it: # From a858e5b440902ec25dbb97c555eef35e12fbf69c Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 21 Apr 2022 18:29:27 +0200 Subject: [PATCH 048/265] Security bug in homomorphic encryption parameter generation. --- FHE/NoiseBounds.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index a1fe3e03..e2df9583 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -105,7 +105,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants() { tp *= t; double lgtp = log(tp) / log(2.0); - if (C[i] < 0 && lgtp < FHE_epsilon) + if (C[i] < 0 && lgtp < -FHE_epsilon) { C[i] = pow(x, i); } From db6763513425d45b64396f57604c7729e59ba556 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 4 May 2022 14:11:13 +0200 Subject: [PATCH 049/265] Security bug in Temi matrix multiplication. --- Protocols/HemiMatrixPrep.hpp | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index b2dd92d2..3446733e 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -53,12 +53,14 @@ class MatrixRandMultJob : public ThreadJob public: MatrixRandMultJob(vector>& C, const vector>& A, - vector>& B) + vector>& B, + bool local_mul) { type = MATRX_RAND_MULT_JOB; output = &C; input = &A; supply = &B; + length = local_mul; } }; @@ -73,7 +75,8 @@ inline void matrix_rand_mult(ThreadJob job, true_type = {}) { A[i].randomize(G); B[i].randomize(G); - C[i] = A[i] * B[i]; + if (job.length) + C[i] = A[i] * B[i]; } } @@ -101,25 +104,22 @@ void HemiMatrixPrep::buffer_triples() B(n_matrices, {n_inner, n_cols}); SeededPRNG G; AddableVector> C(n_matrices); - MatrixRandMultJob job(C, A, B); + MatrixRandMultJob job(C, A, B, T::local_mul); - if (T::local_mul) + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) { - if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) - { - auto& queues = BaseMachine::s().queues; - int start = queues.distribute(job, n_matrices); - job.begin = start; - job.end = n_matrices; - matrix_rand_mult(job); - queues.wrap_up(job); - } - else - { - job.begin = 0; - job.end = n_matrices; - matrix_rand_mult(job); - } + auto& queues = BaseMachine::s().queues; + int start = queues.distribute(job, n_matrices); + job.begin = start; + job.end = n_matrices; + matrix_rand_mult(job); + queues.wrap_up(job); + } + else + { + job.begin = 0; + job.end = n_matrices; + matrix_rand_mult(job); } #ifdef VERBOSE_HE From 642d11f7dd99f6a5f5298090def695aeee906223 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 4 May 2022 14:09:15 +0200 Subject: [PATCH 050/265] Compile-time option for unencrypted client connections. --- ExternalIO/Client.h | 56 +++++++++++++++++++++++++++++++++-- ExternalIO/Client.hpp | 10 +++++-- Networking/sockets.h | 7 ----- Networking/ssl_sockets.h | 2 -- Processor/ExternalClients.cpp | 6 ++-- Processor/ExternalClients.h | 7 +++-- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index de9e9cad..4e1e4c4b 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,20 +8,72 @@ #include "Networking/ssl_sockets.h" +#ifdef NO_CLIENT_TLS +class client_ctx +{ +public: + client_ctx(string) + { + } +}; + +class client_socket +{ +public: + int socket; + + client_socket(boost::asio::io_service&, + client_ctx&, int plaintext_socket, string, + string, bool) : socket(plaintext_socket) + { + } + + ~client_socket() + { + close(socket); + } +}; + +inline void send(client_socket* socket, octet* data, size_t len) +{ + send(socket->socket, data, len); +} + +inline void receive(client_socket* socket, octet* data, size_t len) +{ + receive(socket->socket, data, len); +} + +#else + +typedef ssl_ctx client_ctx; + +class client_socket : public ssl_socket +{ +public: + client_socket(boost::asio::io_service& io_service, + boost::asio::ssl::context& ctx, int plaintext_socket, string other, + string me, bool client) : + ssl_socket(io_service, ctx, plaintext_socket, other, me, client) + { + } +}; +#endif + /** * Client-side interface */ class Client { vector plain_sockets; - ssl_ctx ctx; + client_ctx ctx; ssl_service io_service; public: /** * Sockets for cleartext communication */ - vector sockets; + vector sockets; /** * Specification of computation domain diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index 3af40f2f..ffc9705c 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -20,7 +20,7 @@ Client::Client(const vector& hostnames, int port_base, { set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i); octetStream(to_string(my_client_id)).Send(plain_sockets[i]); - sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i], + sockets[i] = new client_socket(io_service, ctx, plain_sockets[i], "P" + to_string(i), "C" + to_string(my_client_id), true); if (i == 0) specification.Receive(sockets[0]); @@ -50,11 +50,15 @@ void Client::send_private_inputs(const vector& values) // Receive num_inputs triples from SPDZ for (size_t j = 0; j < sockets.size(); j++) { +#ifdef VERBOSE_COMM + cerr << "receiving from " << j << endl << flush; +#endif + os.reset_write_head(); os.Receive(sockets[j]); #ifdef VERBOSE_COMM - cerr << "received " << os.get_length() << " from " << j << endl; + cerr << "received " << os.get_length() << " from " << j << endl << flush; #endif for (int j = 0; j < num_inputs; j++) @@ -101,7 +105,7 @@ vector Client::receive_outputs(int n) os.reset_write_head(); os.Receive(socket); #ifdef VERBOSE_COMM - cout << "received " << os.get_length() << endl; + cout << "received " << os.get_length() << endl << flush; #endif for (int j = 0; j < 3 * n; j++) { diff --git a/Networking/sockets.h b/Networking/sockets.h index 7f48aad1..b67a2076 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -35,11 +35,6 @@ void send(T& socket, size_t a, size_t len); template void receive(T& socket, size_t& a, size_t len); -template -void send(T socket, octet* msg, size_t len); -template -void receive(T socket, octet* msg, size_t len); - inline size_t send_non_blocking(int socket, octet* msg, size_t len) { @@ -54,7 +49,6 @@ inline size_t send_non_blocking(int socket, octet* msg, size_t len) return j; } -template<> inline void send(int socket,octet *msg,size_t len) { size_t i = 0; @@ -72,7 +66,6 @@ inline void send(T& socket, size_t a, size_t len) send(socket, blen, len); } -template<> inline void receive(int socket,octet *msg,size_t len) { size_t i=0; diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index fe9477a8..81613995 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -87,7 +87,6 @@ inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length) return socket->write_some(boost::asio::buffer(data, length)); } -template<> inline void send(ssl_socket* socket, octet* data, size_t length) { size_t sent = 0; @@ -103,7 +102,6 @@ inline void send(ssl_socket* socket, octet* data, size_t length) } } -template<> inline void receive(ssl_socket* socket, octet* data, size_t length) { size_t received = 0; diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index 65bb4598..48bb8bd1 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -51,8 +51,8 @@ int ExternalClients::get_client_connection(int portnum_base) client); client_id = stoi(client); if (ctx == 0) - ctx = new ssl_ctx("P" + to_string(get_party_num())); - external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket, + ctx = new client_ctx("P" + to_string(get_party_num())); + external_client_sockets[client_id] = new client_socket(io_service, *ctx, socket, "C" + to_string(client_id), "P" + to_string(get_party_num()), false); client_ports[client_id] = portnum_base; cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; @@ -75,7 +75,7 @@ int ExternalClients::get_party_num() return party_num; } -ssl_socket* ExternalClients::get_socket(int id) +client_socket* ExternalClients::get_socket(int id) { if (external_client_sockets.find(id) == external_client_sockets.end()) throw runtime_error("external connection not found for id " + to_string(id)); diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h index e437f316..5ea1b3fd 100644 --- a/Processor/ExternalClients.h +++ b/Processor/ExternalClients.h @@ -4,6 +4,7 @@ #include "Networking/sockets.h" #include "Networking/ssl_sockets.h" #include "Tools/Exceptions.h" +#include "ExternalIO/Client.h" #include #include #include @@ -25,11 +26,11 @@ class ExternalClients int party_num; // Maps holding per client values (indexed by unique 32-bit id) - std::map external_client_sockets; + std::map external_client_sockets; std::map client_ports; ssl_service io_service; - ssl_ctx* ctx; + client_ctx* ctx; public: @@ -43,7 +44,7 @@ class ExternalClients void close_connection(int client_id); // return the socket for a given client or server identifier - ssl_socket* get_socket(int socket_id); + client_socket* get_socket(int socket_id); int get_party_num(); }; From b3c39c4d37947fe7ce136ea0d0de5f8effcd564f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 11 May 2022 16:41:27 +0200 Subject: [PATCH 051/265] Missing vectorization. --- Compiler/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Compiler/types.py b/Compiler/types.py index 99ca6a8c..03fe7eb4 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1148,12 +1148,14 @@ class cint(_clear, _int): bit_length = bit_length or program.bit_length return floatingpoint.bits(self, bit_length) + @vectorize def legendre(self): """ Clear Legendre symbol computation. """ res = cint() legendrec(res, self) return res + @vectorize def digest(self, num_bytes): """ Clear hashing (libsodium default). """ res = cint() From 59fd44be22984218f774b251472cc583c30cffea Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 16 May 2022 15:25:33 +0100 Subject: [PATCH 052/265] Fix compilation with OpenSSL 3. --- ECDSA/P256Element.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 8437f39d..2c8c776d 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -29,7 +29,7 @@ P256Element::P256Element(const Scalar& other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, bigint(other).get_str().c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } @@ -38,7 +38,7 @@ P256Element::P256Element(word other) : { BIGNUM* exp = BN_new(); BN_dec2bn(&exp, to_string(other).c_str()); - assert(EC_POINTs_mul(curve, point, exp, 0, 0, 0, 0) != 0); + assert(EC_POINT_mul(curve, point, exp, 0, 0, 0) != 0); BN_free(exp); } @@ -56,7 +56,11 @@ void P256Element::check() P256Element::Scalar P256Element::x() const { BIGNUM* x = BN_new(); +#if OPENSSL_VERSION_MAJOR >= 3 + assert(EC_POINT_get_affine_coordinates(curve, point, x, 0, 0) != 0); +#else assert(EC_POINT_get_affine_coordinates_GFp(curve, point, x, 0, 0) != 0); +#endif char* xx = BN_bn2dec(x); Scalar res((bigint(xx))); OPENSSL_free(xx); From 8e4fd45c17412dbbccabfb4dc796ed689b2f42e5 Mon Sep 17 00:00:00 2001 From: Jakob Zierk <48928791+jakobzierk@users.noreply.github.com> Date: Tue, 17 May 2022 09:17:02 +0200 Subject: [PATCH 053/265] Windows/VirtualBox performance Added workaround. --- doc/troubleshooting.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 6a79ea19..26808480 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -140,6 +140,29 @@ This indicates an error in the internal accounting of preprocessing. Please file a bug report. +Windows/VirtualBox performance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Performance when using Windows/VirtualBox is by default abysmal, as +AVX/AVX2 instructions are deactivated (see e.g. +`here `_), +which causes a dramatic performance loss. Deactivate Hyper-V/Hypervisor +using:: + bcdedit /set hypervisorlaunchtype off + DISM /Online /Disable-Feature:Microsoft-Hyper-V + + +Performance can be further increased when compiling MP-SPDZ yourself: +:: + sudo apt-get update + sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm + git clone https://github.com/data61/MP-SPDZ.git + cd MP-SPDZ + make tldr + +See also `this issue `_ for a discussion. + + ``mac_fail`` ~~~~~~~~~~~~ From de12e08784636497dc315f287d08513640b6e50d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 23 May 2022 17:27:56 +0200 Subject: [PATCH 054/265] Fix bugs on macOS. --- Math/Z2k.h | 1 + Math/gf2n.h | 1 + Math/gfp.h | 1 + Tools/parse.h | 1 + 4 files changed, 4 insertions(+) diff --git a/Math/Z2k.h b/Math/Z2k.h index 586c78c0..cdde3f40 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -85,6 +85,7 @@ public: Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); } Z2(int x) : Z2(long(x)) { a[N_WORDS - 1] &= UPPER_MASK; } Z2(long x) : Z2(mp_limb_t(x)) { if (K > 64 and x < 0) memset(&a[1], -1, N_BYTES - 8); } + Z2(long long x) : Z2(long(x)) {} template Z2(const IntBase& x); /** diff --git a/Math/gf2n.h b/Math/gf2n.h index 485d8430..3ec8849a 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -118,6 +118,7 @@ protected: gf2n_(U a) : a(a & mask) {} gf2n_(long a) : gf2n_(U(a)) {} gf2n_(int a) : gf2n_(U(unsigned(a))) {} + gf2n_(long long a) : gf2n_(U(a)) {} template gf2n_(IntBase a) : a(a.get()) {} diff --git a/Math/gfp.h b/Math/gfp.h index 3bc23e19..9a50dc03 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -161,6 +161,7 @@ class gfp_ : public ValueInterface gfp_(const mpz_class& x) { to_modp(a, x, ZpD); } gfp_(int x) : gfp_(long(x)) {} gfp_(long x); + gfp_(long long x) : gfp_(long(x)) {} gfp_(word x) : gfp_(bigint::tmp = x) {} template gfp_(IntBase x) : gfp_(x.get()) {} diff --git a/Tools/parse.h b/Tools/parse.h index 8ff0fee9..c4b973dd 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -13,6 +13,7 @@ using namespace std; #ifdef __APPLE__ # include #define be32toh(x) OSSwapBigToHostInt32(x) +#define be64toh(x) OSSwapBigToHostInt64(x) #endif // Read a 4-byte integer From 1460c9b5748c7cb7779cb79b2d3de792106be1f2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 24 May 2022 15:54:56 +0200 Subject: [PATCH 055/265] Fix output issue. --- Compiler/types.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index 03fe7eb4..1d06f3f7 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5189,10 +5189,10 @@ class Array(_vectorizable): self.value_type.free(self.address) self.address = None - def get_address(self, index): + def get_address(self, index, size=None): if isinstance(index, (_secret, _single)): raise CompilerError('need cleartext index') - key = str(index) + key = str(index), size or 1 if self.length is not None: from .GC.types import cbits if isinstance(index, int): @@ -5211,6 +5211,8 @@ class Array(_vectorizable): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() + if size is not None and isinstance(base, _register): + base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ for i in range(n)]) @@ -5332,7 +5334,8 @@ class Array(_vectorizable): except: pass try: - self.value_type.conv(other).store_in_mem(self.get_address(base)) + other = self.value_type.conv(other) + other.store_in_mem(self.get_address(base, other.size)) if len(self) != None and util.is_constant(base): assert len(self) >= other.size + base except (AttributeError, CompilerError): @@ -5370,7 +5373,7 @@ class Array(_vectorizable): :param base: starting point (regint/cint/int) :param size: length (compile-time int) """ size = size or self.length - base - return self.value_type.load_mem(self.get_address(base), size=size) + return self.value_type.load_mem(self.get_address(base, size), size=size) get_part_vector = get_vector @@ -5581,6 +5584,9 @@ class Array(_vectorizable): # compatibility with registers return Array(size, self.value_type) + def output_if(self, cond): + library.print_str_if(cond, '%s', self.get_vector()) + def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), self.address) From 2dad77ba326e5266d6447c69824896d1b458c08f Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 24 May 2022 16:54:30 +0200 Subject: [PATCH 056/265] More flexible conversion. --- Compiler/GC/types.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 6c3abad0..b34e68c8 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -154,9 +154,15 @@ class bits(Tape.Register, _structure, _bit): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): - assert(other.size == math.ceil(self.n / self.unit)) - for i, (x, y) in enumerate(zip(self, other)): + assert self.unit == 64 + n_units = int(math.ceil(self.n / self.unit)) + n_convs = min(other.size, n_units) + for i in range(n_convs): + x = self[i] + y = other[i] self.conv_regint(min(self.unit, self.n - i * self.unit), x, y) + for i in range(n_convs, n_units): + inst.ldbits(self[i], min(self.unit, self.n - i * self.unit), 0) elif (isinstance(self, type(other)) or isinstance(other, type(self))) \ and self.n == other.n: for i in range(math.ceil(self.n / self.unit)): From 5ab8c702dde2f25ae7f2f2d0e4d47f5d716fa621 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 27 May 2022 14:19:33 +0200 Subject: [PATCH 057/265] Secure shuffling. --- BMR/Party.cpp | 1 - BMR/RealProgramParty.hpp | 5 +- CHANGELOG.md | 9 + Compiler/GC/types.py | 5 +- Compiler/instructions.py | 65 +++++ Compiler/instructions_base.py | 5 + Compiler/oram.py | 140 ++++++----- Compiler/path_oram.py | 31 +-- Compiler/permutation.py | 203 ++-------------- Compiler/sorting.py | 54 +++++ Compiler/types.py | 139 ++++++++++- ExternalIO/Client.h | 11 +- FHE/AddableVector.h | 11 +- FHE/Ciphertext.cpp | 31 ++- FHE/Ciphertext.h | 25 +- FHE/Diagonalizer.cpp | 3 + FHE/FFT_Data.cpp | 5 + FHE/FFT_Data.h | 2 +- FHE/FHE_Keys.cpp | 64 ++++- FHE/FHE_Keys.h | 55 ++++- FHE/FHE_Params.cpp | 37 +++ FHE/FHE_Params.h | 24 ++ FHE/NTL-Subs.cpp | 31 ++- FHE/NoiseBounds.cpp | 6 +- FHE/NoiseBounds.h | 2 + FHE/P2Data.cpp | 8 +- FHE/Plaintext.cpp | 100 ++++---- FHE/Plaintext.h | 45 +++- FHE/Ring.cpp | 2 +- FHE/Ring_Element.cpp | 3 +- FHE/Rq_Element.cpp | 8 +- FHE/Rq_Element.h | 8 +- FHEOffline/Multiplier.cpp | 18 +- FHEOffline/PairwiseSetup.cpp | 26 +- GC/NoShare.h | 12 +- Machines/dealer-ring-party.cpp | 2 + Machines/mama-party.cpp | 2 +- Makefile | 1 + Math/FixedVec.h | 7 +- Math/Setup.cpp | 4 +- Math/Z2k.h | 6 + Math/Z2k.hpp | 5 +- Math/Zp_Data.cpp | 3 +- Math/gf2n.cpp | 15 +- Math/gf2n.h | 4 + Math/gfpvar.h | 6 + Networking/AllButLastPlayer.h | 17 +- Networking/CryptoPlayer.cpp | 4 +- Networking/Player.cpp | 8 - Networking/Player.h | 1 - Processor/Data_Files.hpp | 2 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 10 + Processor/Instruction.hpp | 72 ++++-- Processor/Machine.hpp | 7 +- Processor/OnlineMachine.hpp | 4 +- Processor/Processor.h | 21 ++ Processor/Processor.hpp | 56 +++++ Processor/RingMachine.hpp | 5 +- Programs/Source/dijkstra_example.mpc | 50 ++++ Programs/Source/dijkstra_tutorial.mpc | 9 - Protocols/Dealer.h | 36 +++ Protocols/DealerInput.h | 1 + Protocols/DealerInput.hpp | 13 +- Protocols/DealerMC.h | 1 + Protocols/DealerMC.hpp | 7 + Protocols/DealerMatrixPrep.h | 32 +++ Protocols/DealerMatrixPrep.hpp | 87 +++++++ Protocols/DealerPrep.h | 13 + Protocols/DealerPrep.hpp | 51 ++++ Protocols/DealerShare.h | 14 +- Protocols/FakeShare.h | 1 + Protocols/Hemi.h | 10 +- Protocols/Hemi.hpp | 41 +++- Protocols/HemiShare.h | 4 + Protocols/MAC_Check.h | 3 +- Protocols/MAC_Check.hpp | 11 +- Protocols/MAC_Check_Base.h | 4 + Protocols/MAC_Check_Base.hpp | 7 + Protocols/MaliciousRep3Share.h | 4 + Protocols/MaliciousRepMC.hpp | 2 +- Protocols/MaliciousShamirShare.h | 2 + Protocols/Rep3Share.h | 1 + Protocols/Rep4Share.h | 2 + Protocols/Replicated.h | 4 +- Protocols/Replicated.hpp | 12 +- Protocols/ReplicatedInput.h | 2 +- Protocols/ReplicatedMC.hpp | 2 +- Protocols/SecureShuffle.h | 53 +++++ Protocols/SecureShuffle.hpp | 328 ++++++++++++++++++++++++++ Protocols/SemiShare.h | 1 + Protocols/ShamirShare.h | 1 + Protocols/Share.h | 1 + Protocols/ShareInterface.h | 9 +- Protocols/ShareMatrix.h | 175 +++++++++++++- Protocols/TemiShare.h | 3 + Protocols/fake-stuff.hpp | 29 ++- README.md | 5 +- Tools/Exceptions.cpp | 6 +- Tools/Exceptions.h | 2 +- Tools/PointerVector.h | 9 + Tools/Waksman.cpp | 91 +++++++ Tools/Waksman.h | 39 +++ Utils/he-example.cpp | 97 ++++++++ doc/Doxyfile | 2 +- doc/homomorphic-encryption.rst | 31 +++ doc/index.rst | 1 + doc/troubleshooting.rst | 2 + 108 files changed, 2227 insertions(+), 542 deletions(-) create mode 100644 Compiler/sorting.py create mode 100644 Programs/Source/dijkstra_example.mpc delete mode 100644 Programs/Source/dijkstra_tutorial.mpc create mode 100644 Protocols/Dealer.h create mode 100644 Protocols/DealerMatrixPrep.h create mode 100644 Protocols/DealerMatrixPrep.hpp create mode 100644 Protocols/SecureShuffle.h create mode 100644 Protocols/SecureShuffle.hpp create mode 100644 Tools/Waksman.cpp create mode 100644 Tools/Waksman.h create mode 100644 Utils/he-example.cpp create mode 100644 doc/homomorphic-encryption.rst diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 84ba909b..beddd64c 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,6 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 8e16c307..ae69cb7f 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -28,7 +28,7 @@ RealProgramParty* RealProgramParty::singleton = 0; template RealProgramParty::RealProgramParty(int argc, const char** argv) : - garble_processor(garble_machine), dummy_proc({{}, 0}) + garble_processor(garble_machine), dummy_proc({}, 0) { assert(singleton == 0); singleton = this; @@ -157,6 +157,9 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : MC->Check(*P); data_sent = P->total_comm().sent; + if (online_opts.verbose) + P->total_comm().print(); + this->machine.write_memory(this->N.my_num()); } diff --git a/CHANGELOG.md b/CHANGELOG.md index 744d0ff1..18cc92ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.2 (Mai 27, 2022) + +- Secure shuffling +- O(n log n) radix sorting +- Documented BGV encryption interface +- Optimized matrix multiplication in dealer protocol +- Fixed security bug in homomorphic encryption parameter generation +- Fixed Security bug in Temi matrix multiplication + ## 0.3.1 (Apr 19, 2022) - Protocol in dealer model diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index b34e68c8..fdd98722 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -382,7 +382,6 @@ class sbits(bits): reg_type = 'sb' is_clear = False clear_type = cbits - default_type = cbits load_inst = (inst.ldmsbi, inst.ldmsb) store_inst = (inst.stmsbi, inst.stmsb) bitdec = inst.bitdecs @@ -404,6 +403,9 @@ class sbits(bits): else: return sbits.get_type(n)(value) @staticmethod + def _new(value): + return value + @staticmethod def get_random_bit(): res = sbit() inst.bitb(res) @@ -909,6 +911,7 @@ class cbit(bit, cbits): sbits.bit_type = sbit cbits.bit_type = cbit sbit.clear_type = cbit +sbits.default_type = sbits class bitsBlock(oram.Block): value_type = sbits diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 8a10ee58..5f5b82db 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -17,6 +17,7 @@ right order. import itertools import operator +import math from . import tools from random import randint from functools import reduce @@ -2406,6 +2407,70 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +@base.gf2n +class secshuffle(base.VectorInstruction, base.DataInstruction): + """ Secure shuffling. + + :param: destination (sint) + :param: source (sint) + """ + __slots__ = [] + code = base.opcodes['SECSHUFFLE'] + arg_format = ['sw','s','int'] + + def __init__(self, *args, **kwargs): + super(secshuffle_class, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class gensecshuffle(base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (regint) + :param: size (int) + + """ + __slots__ = [] + code = base.opcodes['GENSECSHUFFLE'] + arg_format = ['ciw','int'] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class applyshuffle(base.VectorInstruction, base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (sint) + :param: source (sint) + :param: number of elements to be treated as one (int) + :param: handle (regint) + :param: reverse (0/1) + + """ + __slots__ = [] + code = base.opcodes['APPLYSHUFFLE'] + arg_format = ['sw','s','int','ci','int'] + + def __init__(self, *args, **kwargs): + super(applyshuffle, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'triple', 0), float('inf')) + +class delshuffle(base.Instruction): + """ Delete secure shuffle. + + :param: handle (regint) + + """ + code = base.opcodes['DELSHUFFLE'] + arg_format = ['ci'] + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 8ae0b86f..d598d8a7 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -106,6 +106,11 @@ opcodes = dict( CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + # Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/oram.py b/Compiler/oram.py index 543fc4aa..d4b43438 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -348,7 +348,7 @@ class Entry(object): def __len__(self): return 2 + len(self.x) def __repr__(self): - return '{empty=%s}' % self.is_empty if self.is_empty \ + return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \ else '{%s: %s}' % (self.v, self.x) def __add__(self, other): try: @@ -466,12 +466,14 @@ class AbstractORAM(object): def get_array(size, t, *args, **kwargs): return t.dynamic_array(size, t, *args, **kwargs) def read(self, index): - return self._read(self.value_type.hard_conv(index)) + res = self._read(self.index_type.hard_conv(index)) + res = [self.value_type._new(x) for x in res] + return res def write(self, index, value): + value = util.tuplify(value) + value = [self.value_type.conv(x) for x in value] new_value = [self.value_type.get_type(length).hard_conv(v) \ - for length,v in zip(self.entry_size, value \ - if isinstance(value, (tuple, list)) \ - else (value,))] + for length,v in zip(self.entry_size, value)] return self._write(self.index_type.hard_conv(index), *new_value) def access(self, index, new_value, write, new_empty=False): return self._access(self.index_type.hard_conv(index), @@ -795,7 +797,8 @@ class RefTrivialORAM(EndRecursiveEviction): for i,value in enumerate(values): index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.ram[i] = Entry(index, new_value, value_type=self.value_type) @@ -986,7 +989,8 @@ class List(EndRecursiveEviction): for i,value in enumerate(values): index = self.value_type.hard_conv(i) new_value = [self.value_type.hard_conv(v) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.__setitem__(index, new_value) def __repr__(self): @@ -1062,11 +1066,12 @@ class TreeORAM(AbstractORAM): stop_timer(1) start_timer() self.root = RefBucket(1, self) - self.index = self.index_structure(size, self.D, value_type, init_rounds, True) + self.index = self.index_structure(size, self.D, self.index_type, + init_rounds, True) - self.read_value = Array(self.value_length, value_type) + self.read_value = Array(self.value_length, value_type.default_type) self.read_non_empty = MemValue(self.value_type.bit_type(0)) - self.state = MemValue(self.value_type(0)) + self.state = MemValue(self.value_type.default_type(0)) @method_block def add_to_root(self, state, is_empty, v, *x): if len(x) != self.value_length: @@ -1106,10 +1111,10 @@ class TreeORAM(AbstractORAM): self.evict_bucket(RefBucket(p_bucket2, self), d) @method_block def read_and_renew_index(self, u): - l_star = random_block(self.D, self.value_type) + l_star = random_block(self.D, self.index_type) if use_insecure_randomness: new_path = regint.get_random(self.D) - l_star = self.value_type(new_path) + l_star = self.index_type(new_path) self.state.write(l_star) return self.index.update(u, l_star, evict=False).reveal() @method_block @@ -1120,7 +1125,7 @@ class TreeORAM(AbstractORAM): parallel = get_parallel(self.index_size, *self.internal_value_type()) @map_sum(get_n_threads_for_tree(self.size), parallel, levels, \ self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type] * self.value_length) + [self.value_type.default_type] * self.value_length) def process(level): b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level)) bucket = RefBucket(b_index, self) @@ -1142,9 +1147,9 @@ class TreeORAM(AbstractORAM): Program.prog.curr_tape.start_new_basicblock() crash() def internal_value_type(self): - return self.value_type, self.value_length + 1 + return self.value_type.default_type, self.value_length + 1 def internal_entry_size(self): - return self.value_type, [self.D] + list(self.entry_size) + return self.value_type.default_type, [self.D] + list(self.entry_size) def n_buckets(self): return 2**(self.D+1) @method_block @@ -1176,8 +1181,9 @@ class TreeORAM(AbstractORAM): #print 'pre-add', self maybe_start_timer(4) self.add_to_root(state, entry.empty(), \ - self.value_type(entry.v.read()), \ - *(self.value_type(i.read()) for i in entry.x)) + self.index_type(entry.v.read()), \ + *(self.value_type.default_type(i.read()) + for i in entry.x)) maybe_stop_timer(4) #print 'pre-evict', self if evict: @@ -1228,21 +1234,27 @@ class TreeORAM(AbstractORAM): raise CompilerError('Batch initialization only possible with sint.') depth = log2(m) - leaves = [0] * m - entries = [0] * m - indexed_values = [0] * m + leaves = self.value_type.Array(m) + indexed_values = \ + self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1) # assign indices 0, ..., m-1 - for i,value in enumerate(values): + @for_range(m) + def _(i): + value = values[i] index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ for v in (value if isinstance(value, (tuple, list)) \ else (value,))] indexed_values[i] = [index] + new_value - + entries = sint.Matrix(self.bucket_size * 2 ** self.D, + len(Entry(0, list(indexed_values[0]), False))) + # assign leaves - for i,index_value in enumerate(indexed_values): + @for_range(len(indexed_values)) + def _(i): + index_value = list(indexed_values[i]) leaves[i] = random_block(self.D, self.value_type) index = index_value[0] @@ -1252,18 +1264,20 @@ class TreeORAM(AbstractORAM): # save unsorted leaves for position map unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves] - permutation.sort(leaves, comp=permutation.normal_comparator) + leaves.sort() bucket_sz = 0 # B[i] = (pos, leaf, "last in bucket" flag) for i-th entry - B = [[0]*3 for i in range(m)] + B = sint.Matrix(m, 3) B[0] = [0, leaves[0], 0] B[-1] = [None, None, sint(1)] - s = 0 + s = MemValue(sint(0)) - for i in range(1, m): + @for_range_opt(m - 1) + def _(j): + i = j + 1 eq = leaves[i].equal(leaves[i-1]) - s = (s + eq) * eq + s.write((s + eq) * eq) B[i][0] = s B[i][1] = leaves[i] B[i-1][2] = 1 - eq @@ -1271,7 +1285,7 @@ class TreeORAM(AbstractORAM): #last_in_bucket[i-1] = 1 - eq # shuffle - permutation.shuffle(B, value_type=sint) + B.secure_shuffle() #cint(0).print_reg('shuf') sz = MemValue(0) #cint(0) @@ -1279,7 +1293,8 @@ class TreeORAM(AbstractORAM): empty_positions = Array(nleaves, self.value_type) empty_leaves = Array(nleaves, self.value_type) - for i in range(m): + @for_range(m) + def _(i): if_then(reveal(B[i][2])) #if B[i][2] == 1: #cint(i).print_reg('last') @@ -1291,12 +1306,13 @@ class TreeORAM(AbstractORAM): empty_positions[szval] = B[i][0] #pos[i][0] #empty_positions[szval].reveal().print_reg('ps0') empty_leaves[szval] = B[i][1] #pos[i][1] - sz += 1 + sz.iadd(1) end_if() - pos_bits = [] + pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2) - for i in range(nleaves): + @for_range_opt(nleaves) + def _(i): leaf = empty_leaves[i] # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: @@ -1315,46 +1331,39 @@ class TreeORAM(AbstractORAM): bucket_bits = [b for sl in zip(bits2,bits) for b in sl] else: bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0] - pos_bits += [[b, leaf] for b in bucket_bits] + assert len(bucket_bits) == self.bucket_size + for j, b in enumerate(bucket_bits): + pos_bits[i * self.bucket_size + j] = [b, leaf] # sort to get empty positions first - permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator) + pos_bits.sort(n_bits=1) # now assign positions to empty entries - empty_entries = [0] * (self.bucket_size*2**self.D - m) - - for i in range(self.bucket_size*2**self.D - m): + @for_range(len(entries) - m) + def _(i): vtype, vlength = self.internal_value_type() leaf = vtype(pos_bits[i][1]) # set leaf in empty entry for assigning after shuffle - value = tuple([leaf] + [vtype(0) for j in range(vlength)]) + value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)]) entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype) - empty_entries[i] = entry + entries[m + i] = entry # now shuffle, reveal positions and place entries - entries = entries + empty_entries - while len(entries) & (len(entries)-1) != 0: - entries.append(None) - permutation.shuffle(entries, value_type=sint) - entries = [entry for entry in entries if entry is not None] - clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries] + entries.secure_shuffle() + clear_leaves = Array.create_from( + Entry(entries.get_columns()).x[0].reveal()) Program.prog.curr_tape.start_new_basicblock() bucket_sizes = Array(2**self.D, regint) for i in range(2**self.D): bucket_sizes[i] = 0 - k = 0 - for entry,leaf in zip(entries, clear_leaves): - leaf = leaf.read() - k += 1 - # for some reason leaf_buckets is in bit-reversed order - bits = bit_decompose(leaf, self.D) - rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1])) - bucket = RefBucket(rev_leaf + (1 << self.D), self) - # hack: 1*entry ensures MemValues are converted to sints - bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry + @for_range_opt(len(entries)) + def _(k): + leaf = clear_leaves[k] + bucket = RefBucket(leaf + (1 << self.D), self) + bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k]) bucket_sizes[leaf] += 1 self.index.batch_init([leaf.read() for leaf in unsorted_leaves]) @@ -1599,16 +1608,20 @@ class PackedIndexStructure(object): def batch_init(self, values): """ Initialize m values with indices 0, ..., m-1 """ m = len(values) - n_entries = max(1, m/self.entries_per_block) - new_values = [0] * n_entries + n_entries = max(1, m//self.entries_per_block) + new_values = sint.Matrix(n_entries, self.elements_per_block) + values = Array.create_from(values) - for i in range(n_entries): + @for_range(n_entries) + def _(i): block = [0] * self.elements_per_block for j in range(self.elements_per_block): base = i * self.entries_per_block + j * self.entries_per_element for k in range(self.entries_per_element): - if base + k < m: - block[j] += values[base + k] << (k * self.entry_size) + @if_(base + k < m) + def _(): + block[j] += \ + values[base + k] << (k * sum(self.entry_size)) new_values[i] = block @@ -1667,7 +1680,8 @@ def OptimalORAM(size,*args,**kwargs): experiments. :param size: number of elements - :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` + :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / + :py:class:`sfix` """ if optimal_threshold is None: if n_threads == 1: @@ -1784,7 +1798,7 @@ def test_batch_init(oram_type, N): oram = oram_type(N, value_type) print('initialized') print_reg(cint(0), 'init') - oram.batch_init([value_type(i) for i in range(N)]) + oram.batch_init(Array.create_from(sint(regint.inc(N)))) print_reg(cint(0), 'done') @for_range(N) def f(i): diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index fb1601c3..b9e3952b 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -111,24 +111,6 @@ def bucket_size_sorter(x, y): return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z]) -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - - - Returns the network switching config so it may be re-used later. """ - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = permutation.configure_waksman(permutation.random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type(b) for b in c] - permutation.waksman(x, config, reverse=reverse) - permutation.waksman(x, config, reverse=reverse) - - return config - - def LT(a, b): a_bits = bit_decompose(a) b_bits = bit_decompose(b) @@ -472,10 +454,15 @@ class PathORAM(TreeORAM): print_ln() # shuffle entries and levels - while len(merged_entries) & (len(merged_entries)-1) != 0: - merged_entries.append(None) #self.root.bucket.empty_entry(False)) - permutation.rec_shuffle(merged_entries, value_type=self.value_type) - merged_entries = [e for e in merged_entries if e is not None] + flat = [] + for x in merged_entries: + flat += list(x[0]) + [x[1]] + flat = self.value_type(flat) + assert len(flat) % len(merged_entries) == 0 + l = len(flat) // len(merged_entries) + shuffled = flat.secure_shuffle(l) + merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]] + for i in range(len(shuffled) // l)] # need to copy entries/levels to memory for re-positioning entries_ram = RAM(self.temp_size, self.entry_type, self.get_array) diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 6e1273ec..07d3a3e7 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -10,16 +10,6 @@ if '_Array' not in dir(): from Compiler.program import Program _Array = Array -SORT_BITS = [] -insecure_random = Random(0) - -def predefined_comparator(x, y): - """ Assumes SORT_BITS is populated with the required sorting network bits """ - if predefined_comparator.sort_bits_iter is None: - predefined_comparator.sort_bits_iter = iter(SORT_BITS) - return next(predefined_comparator.sort_bits_iter) -predefined_comparator.sort_bits_iter = None - def list_comparator(x, y): """ Uses the first element in the list for comparison """ return x[0] < y[0] @@ -37,10 +27,6 @@ def bitwise_comparator(x, y): def cond_swap_bit(x,y, b): """ swap if b == 1 """ - if x is None: - return y, None - elif y is None: - return x, None if isinstance(x, list): t = [(xi - yi) * b for xi,yi in zip(x, y)] return [xi - ti for xi,ti in zip(x, t)], \ @@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): else: raise CompilerError('Length of list must be power of two') -def merge(a, b, comp): - """ General length merge (pads to power of 2) """ - while len(a) & (len(a)-1) != 0: - a.append(None) - while len(b) & (len(b)-1) != 0: - b.append(None) - if len(a) < len(b): - a += [None] * (len(b) - len(a)) - elif len(b) < len(a): - b += [None] * (len(b) - len(b)) - t = a + b - odd_even_merge(t, comp) - for i,v in enumerate(t[::]): - if v is None: - t.remove(None) - return t - def sort(a, comp): """ Pads to power of 2, sorts, removes padding """ length = len(a) @@ -112,47 +81,12 @@ def sort(a, comp): odd_even_merge_sort(a, comp) del a[length:] -def recursive_merge(a, comp): - """ Recursively merge a list of sorted lists (initially sorted by size) """ - if len(a) == 1: - return - # merge smallest two lists, place result in correct position, recurse - t = merge(a[0], a[1], comp) - del a[0] - del a[0] - added = False - for i,c in enumerate(a): - if len(c) >= len(t): - a.insert(i, t) - added = True - break - if not added: - a.append(t) - recursive_merge(a, comp) +# The following functionality for shuffling isn't used any more as it +# has been moved to the virtual machine. The code has been kept for +# reference. -def random_perm(n): - """ Generate a random permutation of length n - - WARNING: randomness fixed at compile-time, this is NOT secure - """ - if not Program.prog.options.insecure: - raise CompilerError('no secure implementation of Waksman permution, ' - 'use --insecure to activate') - a = list(range(n)) - for i in range(n-1, 0, -1): - j = insecure_random.randint(0, i) - t = a[i] - a[i] = a[j] - a[j] = t - return a - -def inverse(perm): - inv = [None] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - return inv - -def configure_waksman(perm): +def configure_waksman(perm, n_iter=[0]): + top = n_iter == [0] n = len(perm) if n == 2: return [(perm[0], perm[0])] @@ -175,6 +109,7 @@ def configure_waksman(perm): via = 0 j0 = j while True: + n_iter[0] += 1 #print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2) i = inv_perm[j] @@ -209,8 +144,11 @@ def configure_waksman(perm): assert sorted(p0) == list(range(n//2)) assert sorted(p1) == list(range(n//2)) - p0_config = configure_waksman(p0) - p1_config = configure_waksman(p1) + p0_config = configure_waksman(p0, n_iter) + p1_config = configure_waksman(p1, n_iter) + if top: + print(n_iter[0], 'iterations for Waksman') + assert O[0] == 0, 'not a Waksman network' return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] def waksman(a, config, depth=0, start=0, reverse=False): @@ -358,23 +296,10 @@ def iter_waksman(a, config, reverse=False): # nblocks /= 2 # depth -= 1 -def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False): - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = configure_waksman(random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type.bit_type(b) for b in c] - waksman(x, config, reverse=reverse) - waksman(x, config, reverse=reverse) - -def config_shuffle(n, value_type): - """ Compute config for oblivious shuffling. - - Take mod 2 for active sec. """ - perm = random_perm(n) +def config_from_perm(perm, value_type): + n = len(perm) + assert(list(sorted(perm))) == list(range(n)) if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) @@ -394,103 +319,3 @@ def config_shuffle(n, value_type): for j,b in enumerate(c): config[i * len(perm) + j] = b return config - -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - WARNING: This is not a properly secure implementation but has roughly the right complexity. - - Returns the network switching config so it may be re-used later. """ - n = len(x) - m = 2**int(math.ceil(math.log(n, 2))) - assert n == m, 'only working for powers of two' - if config is None: - config = config_shuffle(n, value_type) - - if isinstance(x, list): - if isinstance(x[0], list): - length = len(x[0]) - assert len(x) == length - for i in range(length): - xi = Array(m, value_type.reg_type) - for j in range(n): - xi[j] = x[j][i] - for j in range(n, m): - xi[j] = value_type(0) - iter_waksman(xi, config, reverse=reverse) - iter_waksman(xi, config, reverse=reverse) - for j, y in enumerate(xi): - x[j][i] = y - else: - xa = Array(m, value_type.reg_type) - for i in range(n): - xa[i] = x[i] - for i in range(n, m): - xa[i] = value_type(0) - iter_waksman(xa, config, reverse=reverse) - iter_waksman(xa, config, reverse=reverse) - x[:] = xa - elif isinstance(x, Array): - if len(x) != m and config is None: - raise CompilerError('Non-power of 2 Array input not yet supported') - iter_waksman(x, config, reverse=reverse) - iter_waksman(x, config, reverse=reverse) - else: - raise CompilerError('Invalid type for shuffle:', type(x)) - - return config - -def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None): - """ Shuffle a list of ORAM entries. - - Randomly permutes the first "perm_size" entries, leaving the rest (empty - entry padding) in the same position. """ - n = len(x) - l = len(x[0]) - if n & (n-1) != 0: - raise CompilerError('Entries must be padded to power of two length.') - if perm_size is None: - perm_size = n - - xarrays = [Array(n, value_type.reg_type) for i in range(l)] - for i in range(n): - for j,value in enumerate(x[i]): - if isinstance(value, MemValue): - xarrays[j][i] = value.read() - else: - xarrays[j][i] = value - - if config is None: - config = config_shuffle(perm_size, value_type) - for xi in xarrays: - shuffle(xi, config, value_type, reverse) - for i in range(n): - x[i] = entry_cls(xarrays[j][i] for j in range(l)) - return config - - -def sort_zeroes(bits, x, n_ones, value_type): - """ Return Array of values in "x" where the corresponding bit in "bits" is - a 0. - - The total number of zeroes in "bits" must be known. - "bits" and "x" must be Arrays. """ - config = config_shuffle(len(x), value_type) - shuffle(bits, config=config, value_type=value_type) - shuffle(x, config=config, value_type=value_type) - result = Array(n_ones, value_type.reg_type) - - sz = MemValue(0) - last_x = MemValue(value_type(0)) - #for i,b in enumerate(bits): - #if_then(b.reveal() == 0) - #result[sz.read()] = x[i] - #sz += 1 - #end_if() - @for_range(len(bits)) - def f(i): - found = (bits[i].reveal() == 0) - szval = sz.read() - result[szval] = last_x + (x[i] - last_x) * found - sz.write(sz + found) - last_x.write(result[szval]) - return result diff --git a/Compiler/sorting.py b/Compiler/sorting.py new file mode 100644 index 00000000..248b3ea0 --- /dev/null +++ b/Compiler/sorting.py @@ -0,0 +1,54 @@ +import itertools +from Compiler import types, library, instructions + +def dest_comp(B): + Bt = B.transpose() + Bt_flat = Bt.get_vector() + St_flat = Bt.value_type.Array(len(Bt_flat)) + St_flat.assign(Bt_flat) + @library.for_range(len(St_flat) - 1) + def _(i): + St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + Tt_flat = Bt.get_vector() * St_flat.get_vector() + Tt = types.Matrix(*Bt.sizes, B.value_type) + Tt.assign_vector(Tt_flat) + return sum(Tt) - 1 + +def reveal_sort(k, D, reverse=False): + assert len(k) == len(D) + library.break_point() + shuffle = types.sint.get_secure_shuffle(len(k)) + k_prime = k.get_vector().secure_permute(shuffle).reveal() + idx = types.Array.create_from(k_prime) + if reverse: + D.assign_vector(D.get_slice_vector(idx)) + library.break_point() + D.secure_permute(shuffle, reverse=True) + else: + D.secure_permute(shuffle) + library.break_point() + v = D.get_vector() + D.assign_slice_vector(idx, v) + library.break_point() + instructions.delshuffle(shuffle) + +def radix_sort(k, D, n_bits=None, signed=True): + assert len(k) == len(D) + bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) + if signed and len(bs) > 1: + bs[-1][:] = bs[-1][:].bit_not() + B = types.sint.Matrix(len(k), 2) + h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + @library.for_range(len(bs)) + def _(i): + b = bs[i] + B.set_column(0, 1 - b.get_vector()) + B.set_column(1, b.get_vector()) + c = types.Array.create_from(dest_comp(B)) + reveal_sort(c, h, reverse=False) + @library.if_e(i < len(bs) - 1) + def _(): + reveal_sort(h, bs[i + 1], reverse=True) + @library.else_ + def _(): + reveal_sort(h, D, reverse=True) diff --git a/Compiler/types.py b/Compiler/types.py index 1d06f3f7..098f493f 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1937,6 +1937,11 @@ class _secret(_register, _secret_structure): matmuls(res, A, B, n_rows, n, n_cols) return res + @staticmethod + def _new(self): + # mirror sfix + return self + @no_doc def __init__(self, reg_type, val=None, size=None): if isinstance(val, self.clear_type): @@ -2093,6 +2098,12 @@ class _secret(_register, _secret_structure): else: return self * self + @set_instruction_type + def secure_shuffle(self, unit_size=1): + res = type(self)(size=self.size) + secshuffle(res, self, unit_size) + return res + @set_instruction_type @vectorize def reveal(self): @@ -2741,6 +2752,17 @@ class sint(_secret, _int): return w + @staticmethod + def get_secure_shuffle(n): + res = regint() + gensecshuffle(res, n) + return res + + def secure_permute(self, shuffle, unit_size=1, reverse=False): + res = sint(size=self.size) + applyshuffle(res, self, unit_size, shuffle, reverse) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -4291,6 +4313,10 @@ class _fix(_single): k = self.k return revealed_fix._new(val) + def bit_decompose(self, n_bits=None): + """ Bit decomposition. """ + return self.v.bit_decompose(n_bits or self.k) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4312,6 +4338,8 @@ class sfix(_fix): int_type = sint bit_type = sintbit clear_type = cfix + get_type = staticmethod(lambda n: sint) + default_type = sint @vectorized_classmethod def get_input_from(cls, player): @@ -4385,6 +4413,10 @@ class sfix(_fix): def coerce(self, other): return parse_type(other, k=self.k, f=self.f) + def hard_conv_me(self, cls): + assert cls == sint + return self.v + def mul_no_reduce(self, other, res_params=None): assert self.f == other.f assert self.k == other.k @@ -4409,6 +4441,14 @@ class sfix(_fix): return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) + def secure_shuffle(self, *args, **kwargs): + return self._new(self.v.secure_shuffle(*args, **kwargs), + k=self.k, f=self.f) + + def secure_permute(self, *args, **kwargs): + return self._new(self.v.secure_permute(*args, **kwargs), + k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -5395,13 +5435,21 @@ class Array(_vectorizable): regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) - def get_slice_vector(self, slice): + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 assert len(slice) <= self.total_size() base = regint.inc(len(slice), slice.address, 1, 1) - inc = regint.inc(len(slice), 0, 1, 1, 1) + inc = regint.inc(len(slice), self.address, 1, 1, 1) addresses = slice.value_type.load_mem(base) + inc - return self.value_type.load_mem(self.address + addresses) + return addresses + + def get_slice_vector(self, slice): + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(addresses) def expand_to_vector(self, index, size): """ Create vector from single entry. @@ -5514,6 +5562,14 @@ class Array(_vectorizable): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def secure_shuffle(self): + """ Secure shuffle in place according to the security model. """ + self.assign_vector(self.get_vector().secure_shuffle()) + + def secure_permute(self, *args, **kwargs): + """ Secure permutate in place according to the security model. """ + self.assign_vector(self.get_vector().secure_permute(*args, **kwargs)) + def randomize(self, *args): """ Randomize according to data type. """ self.assign_vector(self.value_type.get_random(*args, size=len(self))) @@ -5570,15 +5626,26 @@ class Array(_vectorizable): """ return personal(player, self.create_from(self[:].reveal_to(player)._v)) - def sort(self, n_threads=None): + def sort(self, n_threads=None, batcher=False, n_bits=None): """ - Sort in place using Batchers' odd-even merge mergesort - with complexity :math:`O(n (\log n)^2)`. + Sort in place using radix sort with complexity :math:`O(n \log + n)` for :py:class:`sint` and :py:class:`sfix`, and Batcher's + odd-even mergesort with :math:`O(n (\log n)^2)` for + :py:class:`sfloat`. :param n_threads: number of threads to use (single thread by - default) + default), need to use Batcher's algorithm for several threads + :param batcher: use Batcher's odd-even mergesort in any case + :param n_bits: number of bits in keys (default: global bit length) """ - library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + if batcher or self.value_type.n_elements() > 1: + library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + else: + if n_threads or 1 > 1: + raise CompilerError('multi-threaded sorting only implemented ' + 'with Batcher\'s odd-even mergesort') + import sorting + sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): # compatibility with registers @@ -5619,6 +5686,8 @@ class SubMultiArray(_vectorizable): :return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise""" if isinstance(index, slice) and index == slice(None): return self.get_vector() + if isinstance(index, int) and index < 0: + index += self.sizes[0] key = program.curr_block, str(index) if key not in self.sub_cache: if util.is_constant(index) and \ @@ -5673,6 +5742,10 @@ class SubMultiArray(_vectorizable): def total_size(self): return reduce(operator.mul, self.sizes) * self.value_type.n_elements() + def part_size(self): + return reduce(operator.mul, self.sizes[1:]) * \ + self.value_type.n_elements() + def get_vector(self, base=0, size=None): """ Return vector with content. Not implemented for floating-point. @@ -5731,13 +5804,21 @@ class SubMultiArray(_vectorizable): :param slice: regint array """ + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(self.address + addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(self.address + addresses) + + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert len(slice) * part_size <= self.total_size() base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) addresses = slice.value_type.load_mem(base) * part_size + inc - return self.value_type.load_mem(self.address + addresses) + return addresses def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 @@ -6218,6 +6299,31 @@ class SubMultiArray(_vectorizable): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def secure_shuffle(self): + """ Securely shuffle rows (first index). """ + self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + + def secure_permute(self, permutation, reverse=False): + """ Securely permute rows (first index). """ + self.assign_vector(self.get_vector().secure_permute( + permutation, self.part_size(), reverse)) + + def sort(self, key_indices=None, n_bits=None): + """ Sort sub-arrays (different first index) in place. + + :param key_indices: indices to sorting keys, for example + ``(1, 2)`` to sort three-dimensional array ``a`` by keys + ``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length. + :param n_bits: number of bits in keys (default: global bit length) + + """ + if key_indices is None: + key_indices = (0,) * (len(self.sizes) - 1) + key_indices = (None,) + util.tuplify(key_indices) + import sorting + keys = self.get_vector_by_indices(*key_indices) + sorting.radix_sort(keys, self, n_bits=n_bits) + def randomize(self, *args): """ Randomize according to data type. """ if self.total_size() < program.options.budget: @@ -6334,6 +6440,18 @@ class Matrix(MultiArray): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + @staticmethod + def create_from(rows): + rows = list(rows) + if isinstance(rows[0], (list, tuple)): + t = type(rows[0][0]) + else: + t = type(rows[0]) + res = Matrix(len(rows), len(rows[0]), t) + for i in range(len(rows)): + res[i].assign(rows[i]) + return res + def get_column(self, index): """ Get column as vector. @@ -6344,6 +6462,9 @@ class Matrix(MultiArray): self.sizes[1]) return self.value_type.load_mem(addresses) + def get_columns(self): + return (self.get_column(i) for i in range(self.sizes[1])) + def get_column_by_row_indices(self, rows, column): assert self.value_type.n_elements() == 1 addresses = rows * self.sizes[1] + \ diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 4e1e4c4b..fc5571b1 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -47,17 +47,8 @@ inline void receive(client_socket* socket, octet* data, size_t len) #else typedef ssl_ctx client_ctx; +typedef ssl_socket client_socket; -class client_socket : public ssl_socket -{ -public: - client_socket(boost::asio::io_service& io_service, - boost::asio::ssl::context& ctx, int plaintext_socket, string other, - string me, bool client) : - ssl_socket(io_service, ctx, plaintext_socket, other, me, client) - { - } -}; #endif /** diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 1efe1e22..b0a28744 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -58,7 +58,8 @@ public: { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - for (unsigned int i = 0; i < this->size(); i++) + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) (*this)[i] += y[i]; return *this; } @@ -67,9 +68,11 @@ public: { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - AddableVector res(y.size()); - for (unsigned int i = 0; i < this->size(); i++) - res[i] = (*this)[i] - y[i]; + AddableVector res; + res.reserve(y.size()); + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) + res.push_back((*this)[i] - y[i]); return res; } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 9afef83c..00e05131 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -31,6 +31,12 @@ word check_pk_id(word a, word b) } +void Ciphertext::Scale() +{ + Scale(params->get_plaintext_modulus()); +} + + void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1) { if (c0.params!=c1.params) { throw params_mismatch(); } @@ -115,9 +121,28 @@ void Ciphertext::add(octetStream& os) *this += tmp; } +void Ciphertext::rerandomize(const FHE_PK& pk) +{ + Rq_Element tmp(*params); + SeededPRNG G; + vector r(params->FFTD()[0].m()); + bigint p = pk.p(); + assert(p != 0); + for (auto& x : r) + { + G.get(x, params->p0().numBits() - p.numBits() - 1); + x *= p; + } + tmp.from(r, 0); + Scale(); + cc0 += tmp; + auto zero = pk.encrypt(*params); + zero.Scale(pk.p()); + *this += zero; +} + template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); -template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); - - +template void mul(Ciphertext& ans, const Plaintext& a, + const Ciphertext& c); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index d455f126..11a23e2a 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -15,6 +15,12 @@ template void mul(Ciphertext& ans,const Ciphertext& c, void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1); void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk); +/** + * BGV ciphertext. + * The class allows adding two ciphertexts as well as adding a plaintext and + * a ciphertext via operator overloading. The multiplication of two ciphertexts + * requires the public key and thus needs a separate function. + */ class Ciphertext { Rq_Element cc0,cc1; @@ -54,6 +60,7 @@ class Ciphertext // Scale down an element from level 1 to level 0, if at level 0 do nothing void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); } + void Scale(); // Throws error if ans,c0,c1 etc have different params settings // - Thus programmer needs to ensure this rather than this being done @@ -90,6 +97,12 @@ class Ciphertext template Ciphertext& operator*=(const Plaintext_& other) { ::mul(*this, *this, other); return *this; } + /** + * Ciphertext multiplication. + * @param pk public key + * @param x second ciphertext + * @returns product ciphertext + */ Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const { Ciphertext res(*params); ::mul(res, *this, x, pk); return res; } @@ -98,14 +111,18 @@ class Ciphertext return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this}; } + /// Re-randomize for circuit privacy. + void rerandomize(const FHE_PK& pk); + int level() const { return cc0.level(); } - // pack/unpack (like IO) also assume params are known and already set - // correctly + /// Append to buffer void pack(octetStream& o) const { cc0.pack(o); cc1.pack(o); o.store(pk_id); } - void unpack(octetStream& o) - { cc0.unpack(o); cc1.unpack(o); o.get(pk_id); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& o) + { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } void output(ostream& s) const { cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); } diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp index 9cc1a084..958cd28c 100644 --- a/FHE/Diagonalizer.cpp +++ b/FHE/Diagonalizer.cpp @@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag( { auto& c = products.at(i); for (int j = 0; j < n_matrices; j++) + { + res.at(j).entries.init(); for (size_t k = 0; k < n_rows; k++) res.at(j)[{k, i}] = c.element(j * n_rows + k); + } } return res; } diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index c71a4c5d..d3b67b50 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -7,6 +7,11 @@ +FFT_Data::FFT_Data() : + twop(-1) +{ +} + void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { R=Rg; diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index c5d6b206..4fb37ed4 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -50,7 +50,7 @@ class FFT_Data void pack(octetStream& o) const; void unpack(octetStream& o); - FFT_Data() { ; } + FFT_Data(); FFT_Data(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 20dfb1bb..742c8545 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -12,6 +12,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) { } +FHE_SK::FHE_SK(const FHE_Params& pms) : + FHE_SK(pms, pms.get_plaintext_modulus()) +{ +} + FHE_SK& FHE_SK::operator+=(const FHE_SK& c) { @@ -38,6 +43,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G) } +FHE_PK::FHE_PK(const FHE_Params& pms) : + FHE_PK(pms, pms.get_plaintext_modulus()) +{ +} + Rq_Element FHE_PK::sample_secret_key(PRNG& G) { Rq_Element sk = FHE_SK(*this).s(); @@ -179,32 +189,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext& template Ciphertext FHE_PK::encrypt( const Plaintext& mess) const +{ + return encrypt(Rq_Element(*params, mess)); +} + +Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const { Random_Coins rc(*params); PRNG G; G.ReSeed(); rc.generate(G); - return encrypt(mess, rc); + Ciphertext res(*params); + quasi_encrypt(res, mess, rc); + return res; } template void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const { - if (&c.get_params()!=params) { throw params_mismatch(); } if (T::characteristic_two ^ (pr == 2)) throw pr_mismatch(); + Rq_Element ans = quasi_decrypt(c); + mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); +} + +Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const +{ + if (&c.get_params()!=params) { throw params_mismatch(); } + Rq_Element ans; mul(ans,c.c1(),sk); sub(ans,c.c0(),ans); ans.change_rep(polynomial); - mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); + return ans; } +Plaintext_ FHE_SK::decrypt(const Ciphertext& c) +{ + return decrypt(c, params->get_plaintext_field_data()); +} + template Plaintext FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD) { @@ -299,12 +328,12 @@ void FHE_PK::unpack(octetStream& o) o.consume((octet*) tag, 8); if (memcmp(tag, "PKPKPKPK", 8)) throw runtime_error("invalid serialization of public key"); - a0.unpack(o); - b0.unpack(o); + a0.unpack(o, *params); + b0.unpack(o, *params); if (params->n_mults() > 0) { - Sw_a.unpack(o); - Sw_b.unpack(o); + Sw_a.unpack(o, *params); + Sw_b.unpack(o, *params); } pr.unpack(o); } @@ -322,7 +351,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const return false; } - void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const { @@ -345,8 +373,6 @@ void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) throw runtime_error("incorrect key pair"); } - - void FHE_PK::check(const FHE_Params& params, const bigint& pr) const { if (this->pr != pr) @@ -361,6 +387,24 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const } } +bigint FHE_SK::get_noise(const Ciphertext& c) +{ + sk.lower_level(); + Ciphertext cc = c; + if (cc.level()) + cc.Scale(); + Rq_Element tmp = quasi_decrypt(cc); + bigint res; + bigint q = tmp.get_modulus(); + bigint half_q = q / 2; + for (auto& x : tmp.to_vec_bigint()) + { +// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " "; + res = max(res, x > half_q ? x - q : x); + } + return res; +} + template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 30ecc292..f342e203 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -12,6 +12,10 @@ class FHE_PK; class Ciphertext; +/** + * BGV secret key. + * The class allows addition. + */ class FHE_SK { Rq_Element sk; @@ -29,6 +33,8 @@ class FHE_SK // secret key always on lower level void assign(const Rq_Element& s) { sk=s; sk.lower_level(); } + FHE_SK(const FHE_Params& pms); + FHE_SK(const FHE_Params& pms, const bigint& p) : sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } @@ -38,8 +44,11 @@ class FHE_SK const Rq_Element& s() const { return sk; } + /// Append to buffer void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } - void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } // Assumes Ring and prime of mess have already been set correctly // Ciphertext c must be at level 0 or an error occurs @@ -50,9 +59,14 @@ class FHE_SK template Plaintext decrypt(const Ciphertext& c, const FD& FieldD); + /// Decryption for cleartexts modulo prime + Plaintext_ decrypt(const Ciphertext& c); + template void decrypt_any(Plaintext_& mess, const Ciphertext& c); + Rq_Element quasi_decrypt(const Ciphertext& c) const; + // Three stage procedure for Distributed Decryption // - First stage produces my shares // - Second stage adds in another players shares, do this once for each other player @@ -62,7 +76,6 @@ class FHE_SK void dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_number,int num_players) const; void dist_decrypt_2(vector& vv,const vector& vv1) const; - friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); /* Add secret keys @@ -82,10 +95,15 @@ class FHE_SK template void check(const FHE_PK& pk, const FD& FieldD); + bigint get_noise(const Ciphertext& c); + friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; } }; +/** + * BGV public key. + */ class FHE_PK { Rq_Element a0,b0; @@ -104,8 +122,10 @@ class FHE_PK ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } - - FHE_PK(const FHE_Params& pms, const bigint& p = 0) + + FHE_PK(const FHE_Params& pms); + + FHE_PK(const FHE_Params& pms, const bigint& p) : a0(pms.FFTD(),evaluation,evaluation), b0(pms.FFTD(),evaluation,evaluation), Sw_a(pms.FFTD(),evaluation,evaluation), @@ -143,8 +163,11 @@ class FHE_PK template Ciphertext encrypt(const Plaintext& mess, const Random_Coins& rc) const; + + /// Encryption template Ciphertext encrypt(const Plaintext& mess) const; + Ciphertext encrypt(const Rq_Element& mess) const; friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); @@ -156,8 +179,10 @@ class FHE_PK void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; - // params setting is done out of these IO/pack/unpack functions + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); bool operator!=(const FHE_PK& x) const; @@ -170,21 +195,39 @@ class FHE_PK void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); +/** + * BGV key pair + */ class FHE_KeyPair { public: + /// Public key FHE_PK pk; + /// Secret key FHE_SK sk; - FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) : + FHE_KeyPair(const FHE_Params& params, const bigint& pr) : pk(params, pr), sk(params, pr) { } + /// Initialization + FHE_KeyPair(const FHE_Params& params) : + pk(params), sk(params) + { + } + void generate(PRNG& G) { KeyGen(pk, sk, G); } + + /// Generate fresh keys + void generate() + { + SeededPRNG G; + generate(G); + } }; template diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 5fb07f23..5a0f3991 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -1,5 +1,6 @@ #include "FHE_Params.h" +#include "NTL-Subs.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" #include "Protocols/HemiOptions.h" @@ -67,6 +68,7 @@ void FHE_Params::pack(octetStream& o) const Bval.pack(o); o.store(sec_p); o.store(matrix_dim); + fd.pack(o); } void FHE_Params::unpack(octetStream& o) @@ -80,6 +82,7 @@ void FHE_Params::unpack(octetStream& o) Bval.unpack(o); o.get(sec_p); o.get(matrix_dim); + fd.unpack(o); } bool FHE_Params::operator!=(const FHE_Params& other) const @@ -92,3 +95,37 @@ bool FHE_Params::operator!=(const FHE_Params& other) const else return false; } + +void FHE_Params::basic_generation_mod_prime(int plaintext_length) +{ + if (n_mults() == 0) + generate_semi_setup(plaintext_length, 0, *this, fd, false); + else + { + Parameters parameters(1, plaintext_length, 0); + parameters.generate_setup(*this, fd); + } +} + +template<> +const FFT_Data& FHE_Params::get_plaintext_field_data() const +{ + return fd; +} + +template<> +const P2Data& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +template<> +const PPData& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +bigint FHE_Params::get_plaintext_modulus() const +{ + return fd.get_prime(); +} diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8821e2e2..4733245c 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -15,6 +15,9 @@ #include "Tools/random.h" #include "Protocols/config.h" +/** + * Cryptosystem parameters + */ class FHE_Params { protected: @@ -29,8 +32,15 @@ class FHE_Params bigint Bval; int matrix_dim; + FFT_Data fd; + public: + /** + * Initialization. + * @param n_mults number of ciphertext multiplications (0/1) + * @param drown_sec parameter for function privacy (default 40) + */ FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } @@ -59,10 +69,24 @@ class FHE_Params int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer void unpack(octetStream& o); bool operator!=(const FHE_Params& other) const; + + /** + * Generate parameter for computation modulo a prime + * @param plaintext_length bit length of prime + */ + void basic_generation_mod_prime(int plaintext_length); + + template + const FD& get_plaintext_field_data() const; + + bigint get_plaintext_modulus() const; }; #endif diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 22705bed..794e7431 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -107,10 +107,12 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up) { +#ifdef VERBOSE cout << "Need ciphertext modulus of length " << lgp0; if (params.n_mults() > 0) cout << "+" << lgp1; cout << " and " << phi_N(m) << " slots" << endl; +#endif int extra_slack = 0; if (round_up) @@ -125,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, } extra_slack = i - 1; lgp0 += extra_slack; +#ifdef VERBOSE cout << "Rounding up to " << lgp0 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } Ring R; @@ -148,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool round_up, FHE_Params& params) { + (void) lg2pi, (void) n; + +#ifdef VERBOSE if (n >= 2 and n <= 10) cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl; +#endif int extra_slack = 0; if (round_up) @@ -171,11 +179,15 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, extra_slack = 2 * i; lg2p0 += i; lg2p1 += i; +#ifdef VERBOSE cout << "Rounding up to " << lg2p0 << "+" << lg2p1 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } +#ifdef VERBOSE cout << "Total length: " << lg2p0 + lg2p1 << endl; +#endif return extra_slack; } @@ -215,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, { double phi_m_bound = NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1); + +#ifdef VERBOSE cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl; +#endif + if (phi_N(m) < phi_m_bound) { int old_m = m; + (void) old_m; m = 2 << int(ceil(log2(phi_m_bound))); + +#ifdef VERBOSE cout << "m = " << old_m << " too small, increasing it to " << m << endl; +#endif + generate_prime(p, numBits(p), m); } else @@ -244,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p, void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, const string& i, const bigint& pr0) { + (void) i; + if (lg2pr==0) { throw invalid_params(); } bigint step=m; @@ -260,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, assert(numBits(pr) == lg2pr); } +#ifdef VERBOSE cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl; + cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; +#endif assert(pr % m == 1); assert(pr % p == 1); assert(numBits(pr) == lg2pr); - - cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; } /* @@ -626,6 +650,9 @@ void char_2_dimension(int& m, int& lg2) case 16: m = 4369; break; + case 15: + m = 4681; + break; case 12: m = 4095; break; diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index e2df9583..f4502317 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -167,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1) bigint NoiseBounds::min_p1() { - return drown * B_KS + 1; + return max(bigint(drown * B_KS), bigint((phi_m * p) << 10)); } bigint NoiseBounds::opt_p1() @@ -181,8 +181,10 @@ bigint NoiseBounds::opt_p1() // solve mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a); bigint res = ceil(s); +#ifdef VERBOSE cout << "Optimal p1 vs minimal: " << numBits(res) << "/" << numBits(min_p1()) << endl; +#endif return res; } @@ -194,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1) { min_p0 *= 2; min_p1 *= 2; +#ifdef VERBOSE cout << "increasing lengths: " << numBits(min_p0) << "/" << numBits(min_p1) << endl; +#endif } lg2p1 = numBits(min_p1); lg2p0 = numBits(min_p0); diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index ccd50808..565c663e 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -42,6 +42,8 @@ public: bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } static double min_phi_m(int log_q, double sigma); static double min_phi_m(int log_q, const FHE_Params& params); + + bigint get_B_clean() { return B_clean; } }; // as per ePrint 2012:642 for slack = 0 diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 7d9a8ca4..ac4ae6f1 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -55,13 +55,13 @@ void P2Data::check_dimensions() const // cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl; if (A.size() != Ai.size()) throw runtime_error("forward and backward mapping dimensions mismatch"); - if (A.size() != A[0].size()) + if (A.size() != A.at(0).size()) throw runtime_error("forward mapping not square"); - if (Ai.size() != Ai[0].size()) + if (Ai.size() != Ai.at(0).size()) throw runtime_error("backward mapping not square"); - if ((int)A[0].size() != slots * gf2n_short::degree()) + if ((int)A.at(0).size() != slots * gf2n_short::degree()) throw runtime_error( - "mapping dimension incorrect: " + to_string(A[0].size()) + "mapping dimension incorrect: " + to_string(A.at(0).size()) + " != " + to_string(slots) + " * " + to_string(gf2n_short::degree())); } diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 84cbb9d1..4eba6e8f 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -11,10 +11,43 @@ +template +Plaintext::Plaintext(const FHE_Params& params) : + Plaintext(params.get_plaintext_field_data(), Both) +{ +} + + +template +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).phi_m(); +} + +template +int Plaintext::degree() const +{ + return (*Field_Data).phi_m(); +} + + +template<> +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).num_slots(); +} + +template<> +int Plaintext::degree() const +{ + return (*Field_Data).degree(); +} + + template<> void Plaintext::from(const Generator& source) const { - b.resize(degree); + b.resize(degree()); for (auto& x : b) { source.get(bigint::tmp); @@ -31,7 +64,7 @@ void Plaintext::from_poly() const Ring_Element e(*Field_Data,polynomial); e.from(b); e.change_rep(evaluation); - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::from_poly() const for (unsigned int i=0; iget_prD()}; type=Both; @@ -90,7 +123,7 @@ template<> void Plaintext::from_poly() const { if (type!=Polynomial) { return; } - a.resize(n_slots); + a.resize(num_slots()); (*Field_Data).backward(a,b); type=Both; } @@ -106,34 +139,13 @@ void Plaintext::to_poly() const -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).num_slots(); - degree = (*Field_Data).degree(); -} - - template void Plaintext::allocate(PT_Type type) const { if (type != Evaluation) - b.resize(degree); + b.resize(degree()); if (type != Polynomial) - a.resize(n_slots); + a.resize(num_slots()); this->type = type; } @@ -141,7 +153,7 @@ void Plaintext::allocate(PT_Type type) const template void Plaintext::allocate_slots(const bigint& value) { - b.resize(degree); + b.resize(degree()); for (auto& x : b) x.allocate_slots(value); } @@ -236,7 +248,7 @@ void Plaintext::randomize(PRNG& G,condition cond) type=Polynomial; break; case Diagonal: - a.resize(n_slots); + a.resize(num_slots()); a[0].randomize(G); for (unsigned int i=1; i::randomize(PRNG& G,condition cond) break; default: // Gen a plaintext with 0/1 in each slot - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t) b[0].generateUniform(G, n_bits, false); } else - for (int i = 0; i < n_slots; i++) + for (size_t i = 0; i < num_slots(); i++) b[i].generateUniform(G, n_bits, false); break; default: @@ -288,7 +300,7 @@ void Plaintext::assign_zero(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::assign_one(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::equals(const Plaintext& x) const if (type!=Polynomial and x.type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::unpack(octetStream& o) unsigned int size; o.get(size); allocate(); - if (size != b.size()) + if (size != b.size() and size != 0) throw length_error("unexpected length received"); - for (unsigned int i = 0; i < b.size(); i++) + for (unsigned int i = 0; i < size; i++) b[i] = o.get(); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index 52ff8b6d..c8fb93c7 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -18,6 +18,7 @@ */ #include "FHE/Generator.h" +#include "FHE/FFT_Data.h" #include "Math/fixint.h" #include @@ -25,6 +26,8 @@ using namespace std; class FHE_PK; class Rq_Element; +class FHE_Params; +class FFT_Data; template class AddableVector; // Forward declaration as apparently this is needed for friends in templates @@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits }; enum PT_Type { Polynomial, Evaluation, Both }; +/** + * BGV plaintext. + * Use ``Plaintext_mod_prime`` instead of filling in the templates. + * The plaintext is held in one of the two representations or both, + * polynomial and evaluation. The latter is the one allowing element-wise + * multiplication over a vector. + * Plaintexts can be added, subtracted, and multiplied via operator overloading. + */ template class Plaintext { typedef typename FD::poly_type S; - int n_slots; - int degree; mutable vector a; // The thing in evaluation/FFT form mutable vector b; // Now in polynomial form @@ -58,33 +67,47 @@ class Plaintext const FD *Field_Data; - void set_sizes(); + int degree() const; public: const FD& get_field() const { return *Field_Data; } - unsigned int num_slots() const { return n_slots; } + + /// Number of slots + unsigned int num_slots() const; Plaintext(const FD& FieldD, PT_Type type = Polynomial) - { Field_Data=&FieldD; set_sizes(); allocate(type); } + { Field_Data=&FieldD; allocate(type); } Plaintext(const FD& FieldD, const Rq_Element& other); + /// Initialization + Plaintext(const FHE_Params& params); + void allocate(PT_Type type) const; void allocate() const { allocate(type); } void allocate_slots(const bigint& value); int get_min_alloc(); - // Access evaluation representation + /** + * Read slot. + * @param i slot number + * @returns slot content + */ T element(int i) const { if (type==Polynomial) { from_poly(); } return a[i]; } + /** + * Write to slot + * @param i slot number + * @param e new slot content + */ void set_element(int i,const T& e) { if (type==Polynomial) { throw not_implemented(); } - a.resize(n_slots); + a.resize(num_slots()); a[i]=e; type=Evaluation; } @@ -171,10 +194,10 @@ class Plaintext bool is_diagonal() const; - /* Pack and unpack into an octetStream - * For unpack we assume the FFTD has been assigned correctly already - */ + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); size_t report_size(ReportType type); @@ -185,4 +208,6 @@ class Plaintext template using Plaintext_ = Plaintext; +typedef Plaintext_ Plaintext_mod_prime; + #endif diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index c1c318b8..3b63f306 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o) o.get(pi_inv); o.get(poly); } - else + else if (mm != 0) init(*this, mm); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 554d4dc1..39690fa6 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -87,7 +87,6 @@ void Ring_Element::negate() void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { - if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) { @@ -100,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) return; } + if (a.rep!=b.rep) { throw rep_mismatch(); } + if (&ans == &a) { ans += b; diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index 531df90f..d6a14aab 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -5,7 +5,7 @@ #include "Math/modp.hpp" Rq_Element::Rq_Element(const FHE_PK& pk) : - Rq_Element(pk.get_params().FFTD()) + Rq_Element(pk.get_params().FFTD(), evaluation, evaluation) { } @@ -347,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const return sz; } +void Rq_Element::unpack(octetStream& o, const FHE_Params& params) +{ + set_data(params.FFTD()); + unpack(o); +} + void Rq_Element::print_first_non_zero() const { vector v = to_vec_bigint(); diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index a58cb7de..4e0cdf97 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -69,8 +69,9 @@ protected: a({b0}), lev(n_mults()) {} template - Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : - Rq_Element(params) + Rq_Element(const FHE_Params& params, const Plaintext& plaintext, + RepType r0 = polynomial, RepType r1 = polynomial) : + Rq_Element(params, r0, r1) { from(plaintext.get_iterator()); } @@ -159,6 +160,9 @@ protected: void pack(octetStream& o) const; void unpack(octetStream& o); + // without prior initialization + void unpack(octetStream& o, const FHE_Params& params); + void output(ostream& s) const; void input(istream& s); diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 43ad7e84..3df98c85 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -57,7 +57,7 @@ void Multiplier::multiply_and_add(Plaintext_& res, template void Multiplier::add(Plaintext_& res, const Ciphertext& c, - OT_ROLE role, int n_summands) + OT_ROLE role, int) { o.reset_write_head(); @@ -67,20 +67,10 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, G.ReSeed(); timers["Mask randomization"].start(); product_share.randomize(G); - bigint B = 6 * machine.setup().params.get_R(); - B *= machine.setup().FieldD.get_prime(); - B <<= machine.setup().params.secp(); - // slack - B *= NonInteractiveProof::slack(machine.sec, - machine.setup().params.phi_m()); - B <<= machine.extra_slack; - B *= n_summands; - rc.generateUniform(G, 0, B, B); + mask = c; + mask.rerandomize(other_pk); timers["Mask randomization"].stop(); - timers["Encryption"].start(); - other_pk.encrypt(mask, product_share, rc); - timers["Encryption"].stop(); - mask += c; + mask += product_share; mask.pack(o); res -= product_share; } diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 59223ad0..01971182 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -75,6 +75,8 @@ void secure_init(T& setup, Player& P, U& machine, + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); + string reason; + try { ifstream file(filename); @@ -82,12 +84,30 @@ void secure_init(T& setup, Player& P, U& machine, os.input(file); os.get(machine.extra_slack); setup.unpack(os); + } + catch (exception& e) + { + reason = e.what(); + } + + try + { setup.check(P, machine); } - catch (...) + catch (exception& e) { - cout << "Finding parameters for security " << sec << " and field size ~2^" - << plaintext_length << endl; + reason = e.what(); + } + + if (not reason.empty()) + { + if (OnlineOptions::singleton.verbose) + cerr << "Generating parameters for security " << sec + << " and field size ~2^" << plaintext_length + << " because no suitable material " + "from a previous run was found (" << reason << ")" + << endl; + setup = {}; setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; diff --git a/GC/NoShare.h b/GC/NoShare.h index 49f93ac4..917e71c5 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -50,11 +50,6 @@ public: return "no"; } - static string type_short() - { - return "no"; - } - static DataFieldType field_type() { throw not_implemented(); @@ -66,7 +61,7 @@ public: static void fail() { - throw runtime_error("VM does not support binary circuits"); + throw runtime_error("functionality not available"); } NoValue() {} @@ -143,6 +138,11 @@ public: return 0; } + static int length() + { + return 0; + } + static void fail() { NoValue::fail(); diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp index 4bc8fab1..890a24ab 100644 --- a/Machines/dealer-ring-party.cpp +++ b/Machines/dealer-ring-party.cpp @@ -5,6 +5,7 @@ #include "Protocols/DealerShare.h" #include "Protocols/DealerInput.h" +#include "Protocols/Dealer.h" #include "Processor/RingMachine.hpp" #include "Processor/Machine.hpp" @@ -12,6 +13,7 @@ #include "Protocols/DealerPrep.hpp" #include "Protocols/DealerInput.hpp" #include "Protocols/DealerMC.hpp" +#include "Protocols/DealerMatrixPrep.hpp" #include "Protocols/Beaver.hpp" #include "Semi.hpp" #include "GC/DealerPrep.h" diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp index f270b87c..87bf15ea 100644 --- a/Machines/mama-party.cpp +++ b/Machines/mama-party.cpp @@ -21,5 +21,5 @@ using MamaShare_ = MamaShare; int main(int argc, const char** argv) { ez::ezOptionParser opt; - DishonestMajorityFieldMachine(argc, argv, opt); + DishonestMajorityFieldMachine(argc, argv, opt); } diff --git a/Makefile b/Makefile index 3c2be009..03366f89 100644 --- a/Makefile +++ b/Makefile @@ -244,6 +244,7 @@ paper-example.x: $(VM) $(OT) $(FHEOFFLINE) binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o l2h-example.x: $(VM) $(OT) Machines/Tinier.o +he-example.x: $(FHEOFFLINE) mascot-offline.x: $(VM) $(TINIER) cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index c0b2373e..489ec5ae 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -24,7 +24,12 @@ public: typedef T value_type; typedef FixedVec Scalar; - static const int length = L; + static const int vector_length = L; + + static int length() + { + return L * T::length(); + } static int size() { diff --git a/Math/Setup.cpp b/Math/Setup.cpp index dc76e47d..715d480d 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -136,7 +136,7 @@ void write_online_setup(string dirname, const bigint& p) if (mkdir_p(ss.str().c_str()) == -1) { cerr << "mkdir_p(" << ss.str() << ") failed\n"; - throw file_error(ss.str()); + throw file_error("cannot create " + dirname); } // Output the data @@ -167,6 +167,6 @@ string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, res += "-" + to_string(log2mod); res += "/"; if (mkdir_p(res.c_str()) < 0) - throw file_error(res); + throw file_error("cannot create " + res); return res; } diff --git a/Math/Z2k.h b/Math/Z2k.h index cdde3f40..e8d2ba53 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -439,6 +439,12 @@ void Z2::randomize(PRNG& G, int n) template void Z2::randomize_part(PRNG& G, int n) { + if (n >= N_BITS) + { + randomize(G); + return; + } + *this = {}; G.get_octets((octet*)a, DIV_CEIL(n, 8)); a[DIV_CEIL(n, 64) - 1] &= mp_limb_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 876aef93..ef2f84c9 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -67,7 +67,10 @@ Z2::Z2(const IntBase& x) : template bool Z2::get_bit(int i) const { - return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + if (i < N_BITS) + return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + else + return false; } template diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 17fcdf24..9dd0b7f0 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -174,7 +174,8 @@ void Zp_Data::unpack(octetStream& o) int m; o.get(m); montgomery = m; - init(pr, m); + if (pr != 0) + init(pr, m); } bool Zp_Data::operator!=(const Zp_Data& other) const diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 44e42479..d39a8593 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -44,6 +44,19 @@ int fields_2[num_2_fields][4] = { 128, 7, 2, 1 }, }; +template +string gf2n_::options() +{ + string res = to_string(fields_2[0][0]); + for (int i = 1; i < num_2_fields; i++) + { + int n = fields_2[i][0]; + if (n <= MAX_N_BITS) + res += ", " + to_string(n); + } + return res; +} + template void gf2n_::init_tables() { @@ -113,7 +126,7 @@ void gf2n_::init_field(int nn) if (j==-1) { - throw gf2n_not_supported(nn); + throw gf2n_not_supported(nn, options()); } n=nn; diff --git a/Math/gf2n.h b/Math/gf2n.h index 3ec8849a..56377072 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -86,6 +86,8 @@ protected: static bool allows(Dtype type) { (void) type; return true; } + static string options(); + static const true_type invertible; static const true_type characteristic_two; @@ -154,6 +156,8 @@ protected: gf2n_ operator*(int x) const { return *this * gf2n_(x); } gf2n_ invert() const; + + gf2n_ operator-() const { return *this; } void negate() { return; } /* Bitwise Ops */ diff --git a/Math/gfpvar.h b/Math/gfpvar.h index a3b475f8..7d332fdd 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -107,6 +107,12 @@ public: a = other.get(); } + template + gfpvar_(const Z2& other) : + gfpvar_(bigint(other)) + { + } + void assign(const void* buffer); void assign_zero(); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h index 22482c48..3d6d1834 100644 --- a/Networking/AllButLastPlayer.h +++ b/Networking/AllButLastPlayer.h @@ -50,17 +50,12 @@ public: void Broadcast_Receive_no_stats(vector& os) const { - vector to_send(P.num_players(), os[P.my_num()]); - vector> channels(P.num_players(), - vector(P.num_players(), true)); - for (auto& x: channels) - x.back() = false; - channels.back() = vector(P.num_players(), false); - vector to_receive; - P.send_receive_all(channels, to_send, to_receive); - for (int i = 0; i < P.num_players() - 1; i++) - if (i != P.my_num()) - os[i] = to_receive[i]; + vector senders(P.num_players(), true), receivers(P.num_players(), + true); + senders.back() = false; + receivers.back() = false; + P.partial_broadcast(senders, receivers, os); + os.resize(num_players()); } }; diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 43b2ada5..faf8fda6 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -212,8 +212,8 @@ void CryptoPlayer::partial_broadcast(const vector& my_senders, for (int offset = 1; offset < num_players(); offset++) { int other = get_player(offset); - bool receive = my_senders[other]; - if (my_receivers[other]) + bool receive = my_senders.at(other); + if (my_receivers.at(other)) { this->senders[other]->request(os[my_num()]); sent += os[my_num()].get_length(); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index a7935f30..3a894214 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -811,14 +811,6 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const return res; } -size_t NamedCommStats::total_data() -{ - size_t res = 0; - for (auto& x : *this) - res += x.second.data; - return res; -} - void NamedCommStats::print(bool newline) { for (auto it = begin(); it != end(); it++) diff --git a/Networking/Player.h b/Networking/Player.h index a547d479..cf8579c0 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -157,7 +157,6 @@ public: NamedCommStats& operator+=(const NamedCommStats& other); NamedCommStats operator+(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const; - size_t total_data(); void print(bool newline = false); void reset(); #ifdef VERBOSE_COMM diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index dd746605..3d40e2ca 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -230,7 +230,7 @@ void Sub_Data_Files::prune() my_input_buffers.prune(); for (int j = 0; j < num_players; j++) input_buffers[j].prune(); - for (auto it : extended) + for (auto& it : extended) it.second.prune(); dabit_buffer.prune(); if (part != 0) diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 09c6e056..2eb8a63a 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -293,7 +293,7 @@ int InputBase::get_player(SubProcessor& Proc, int arg, bool player_from_re if (player_from_reg) { assert(Proc.Proc); - auto res = Proc.Proc->read_Ci(arg); + auto res = Proc.Proc->sync_Ci(arg); if (res >= Proc.P.num_players()) throw runtime_error("player id too large: " + to_string(res)); return res; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 5279b258..fd91e35d 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -13,6 +13,7 @@ using namespace std; template class Machine; template class Processor; +template class SubProcessor; class ArithmeticProcessor; class SwitchableOutput; @@ -107,6 +108,11 @@ enum CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + // Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -250,6 +256,7 @@ enum GMULS = 0x1A6, GMULRS = 0x1A7, GDOTPRODS = 0x1A8, + GSECSHUFFLE = 0x1FA, // Data access GTRIPLE = 0x150, GBIT = 0x151, @@ -388,6 +395,9 @@ public: template void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0, T* nan = 0) const; + + template + typename T::clear sanitize(SubProcessor& proc, int reg) const; }; #endif diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 2a5dce70..5bed3703 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -157,6 +157,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case LISTEN: case CLOSECLIENTCONNECTION: case CRASH: + case DELSHUFFLE: r[0]=get_int(s); break; // instructions with 2 registers + 1 integer operand @@ -203,6 +204,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DIGESTC: case INPUTMASK: case GINPUTMASK: + case SECSHUFFLE: + case GSECSHUFFLE: get_ints(r, s, 2); n = get_int(s); break; @@ -230,6 +233,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CONDPRINTSTR: case CONDPRINTSTRB: case RANDOMS: + case GENSECSHUFFLE: r[0]=get_int(s); n = get_int(s); break; @@ -269,6 +273,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 5 register operands case PRINTFLOATPLAIN: case PRINTFLOATPLAINB: + case APPLYSHUFFLE: get_vector(5, start, s); break; case INCINT: @@ -558,6 +563,7 @@ int BaseInstruction::get_reg_type() const case CONVCBITVEC: case INTOUTPUT: case ACCEPTCLIENTCONNECTION: + case GENSECSHUFFLE: return INT; case PREP: case GPREP: @@ -835,11 +841,13 @@ inline void Instruction::execute(Processor& Proc) const { for (int i = 0; i < size; i++) Proc.write_Ci(r[0] + i, - Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get()); + Proc.sync( + Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get())); } else if (n <= 64) for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, Integer(Proc.read_Cp(r[1] + i), n).get()); + Proc.write_Ci(r[0] + i, + Proc.sync(Integer(Proc.read_Cp(r[1] + i), n).get())); else throw Processor_Error(to_string(n) + "-bit conversion impossible; " "integer registers only have 64 bits"); @@ -856,40 +864,32 @@ inline void Instruction::execute(Processor& Proc) const n++; break; case LDMCI: - Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1]))); + Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.sync_Ci(r[1]))); break; case STMC: Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0])); n++; break; case STMCI: - Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0])); + Proc.machine.Mp.write_C(Proc.sync_Ci(r[1]), Proc.read_Cp(r[0])); break; case MOVC: Proc.write_Cp(r[0],Proc.read_Cp(r[1])); break; case DIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / Proc.read_Cp(r[2])); + Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2])); break; case GDIVC: - if (Proc.read_C2(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_C2(r[0], Proc.read_C2(r[1]) / Proc.read_C2(r[2])); + Proc.write_C2(r[0], Proc.read_C2(r[1]) / sanitize(Proc.Proc2, r[2])); break; case FLOORDIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); Proc.temp.aa.from_signed(Proc.read_Cp(r[1])); - Proc.temp.aa2.from_signed(Proc.read_Cp(r[2])); + Proc.temp.aa2.from_signed(sanitize(Proc.Procp, r[2])); Proc.write_Cp(r[0], bigint(Proc.temp.aa / Proc.temp.aa2)); break; case MODC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Modulo by zero from register"); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); - to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); + to_bigint(Proc.temp.aa2, sanitize(Proc.Procp, r[2])); mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); Proc.temp.ansp.convert_destroy(Proc.temp.aa); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -948,7 +948,7 @@ inline void Instruction::execute(Processor& Proc) const Procp.protocol.randoms_inst(Procp.get_S(), *this); return; case INPUTMASKREG: - Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2])); + Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.sync_Ci(r[2])); Proc.write_Cp(r[1], Proc.temp.rrp); break; case INPUTMASK: @@ -1034,7 +1034,7 @@ inline void Instruction::execute(Processor& Proc) const return; case MATMULSM: Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this, - Proc.read_Ci(r[1]), Proc.read_Ci(r[2])); + Proc.sync_Ci(r[1]), Proc.sync_Ci(r[2])); return; case CONV2DS: Proc.Procp.protocol.conv2ds(Proc.Procp, *this); @@ -1042,6 +1042,21 @@ inline void Instruction::execute(Processor& Proc) const case TRUNC_PR: Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); return; + case SECSHUFFLE: + Proc.Procp.secure_shuffle(*this); + return; + case GSECSHUFFLE: + Proc.Proc2.secure_shuffle(*this); + return; + case GENSECSHUFFLE: + Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this)); + return; + case APPLYSHUFFLE: + Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3))); + return; + case DELSHUFFLE: + Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); + return; case CHECK: { CheckJob job; @@ -1056,14 +1071,14 @@ inline void Instruction::execute(Processor& Proc) const Proc.PC += (signed int) n; break; case JMPI: - Proc.PC += (signed int) Proc.read_Ci(r[0]); + Proc.PC += (signed int) Proc.sync_Ci(r[0]); break; case JMPNZ: - if (Proc.read_Ci(r[0]) != 0) + if (Proc.sync_Ci(r[0]) != 0) { Proc.PC += (signed int) n; } break; case JMPEQZ: - if (Proc.read_Ci(r[0]) == 0) + if (Proc.sync_Ci(r[0]) == 0) { Proc.PC += (signed int) n; } break; case PRINTREG: @@ -1123,7 +1138,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.join_tape(r[0]); break; case CRASH: - if (Proc.read_Ci(r[0])) + if (Proc.sync_Ci(r[0])) throw crash_requested(); break; case STARTGRIND: @@ -1146,7 +1161,7 @@ inline void Instruction::execute(Processor& Proc) const // *** case LISTEN: // listen for connections at port number n - Proc.external_clients.start_listening(Proc.read_Ci(r[0])); + Proc.external_clients.start_listening(Proc.sync_Ci(r[0])); break; case ACCEPTCLIENTCONNECTION: { @@ -1335,4 +1350,15 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c out << "]"; } +template +typename T::clear Instruction::sanitize(SubProcessor& proc, int reg) const +{ + if (not T::real_shares(proc.P)) + return 1; + auto& res = proc.get_C_ref(reg); + if (res.is_zero()) + throw Processor_Error("Division by zero from register"); + return res; +} + #endif diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e0299c2f..ce90e1b2 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -30,7 +30,7 @@ void Machine::init_binary_domains(int security_parameter, int lg2) if (not is_same()) { - if (sgf2n::clear::degree() < security_parameter) + if (sgf2n::mac_key_type::length() < security_parameter) { cerr << "Security parameter needs to be at most n in GF(2^n)." << endl; @@ -469,7 +469,10 @@ void Machine::run(const string& progname) for (auto& x : comm_stats) rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds - << " rounds (party " << my_number << ")" << endl; + << " rounds (party " << my_number; + if (threads.size() > 1) + cerr << "; rounds counted double due to multi-threading"; + cerr << ")" << endl; auto& P = *this->P; Bundle bundle(P); diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index d4c66e9a..85ee25d0 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -36,7 +36,9 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - ("Bit length of GF(2^n) field (default: " + to_string(V::default_degree()) + ")").c_str(), // Help description. + ("Bit length of GF(2^n) field (default: " + + to_string(V::default_degree()) + "; options are " + + V::options() + ")").c_str(), // Help description. "-lg2", // Flag token. "--lg2" // Flag token. ); diff --git a/Processor/Processor.h b/Processor/Processor.h index 38ea7f25..927e9327 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -20,6 +20,7 @@ #include "Tools/CheckVector.h" #include "GC/Processor.h" #include "GC/ShareThread.h" +#include "Protocols/SecureShuffle.h" class Program; @@ -31,6 +32,8 @@ class SubProcessor DataPositions bit_usage; + SecureShuffle shuffler; + void resize(size_t size) { C.resize(size); S.resize(size); } template friend class Processor; @@ -70,6 +73,11 @@ public: size_t b); void conv2ds(const Instruction& instruction); + void secure_shuffle(const Instruction& instruction); + size_t generate_secure_shuffle(const Instruction& instruction); + void apply_shuffle(const Instruction& instruction, int handle); + void delete_shuffle(int handle); + void input_personal(const vector& args); void send_personal(const vector& args); void private_output(const vector& args); @@ -127,6 +135,10 @@ public: ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), sent(0), rounds(0), opts(opts) {} + virtual ~ArithmeticProcessor() + { + } + bool use_stdin() { return thread_num == 0 and opts.interactive; @@ -146,6 +158,11 @@ public: CheckVector& get_Ci() { return Ci; } + virtual long sync_Ci(size_t) const + { + throw not_implemented(); + } + void shuffle(const Instruction& instruction); void bitdecint(const Instruction& instruction); }; @@ -241,6 +258,10 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); + // synchronize in asymmetric protocols + long sync_Ci(size_t i) const; + long sync(long x) const; + private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d74594b3..861e8cfe 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -9,6 +9,7 @@ #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" +#include "Protocols/SecureShuffle.hpp" #include #include @@ -23,6 +24,7 @@ SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& template SubProcessor::SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, ArithmeticProcessor* Proc) : + shuffler(*this), Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC), bit_prep(bit_usage) { @@ -340,6 +342,9 @@ void Processor::read_socket_private(int client_id, // Tolerent to no file if no shares yet persisted. template void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename; filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; @@ -370,6 +375,9 @@ template void Processor::write_shares_to_file(long start_pos, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -633,6 +641,33 @@ void SubProcessor::conv2ds(const Instruction& instruction) } } +template +void SubProcessor::secure_shuffle(const Instruction& instruction) +{ + SecureShuffle(S, instruction.get_size(), instruction.get_n(), + instruction.get_r(0), instruction.get_r(1), *this); +} + +template +size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction) +{ + return shuffler.generate(instruction.get_n()); +} + +template +void SubProcessor::apply_shuffle(const Instruction& instruction, int handle) +{ + shuffler.apply(S, instruction.get_size(), instruction.get_start()[2], + instruction.get_start()[0], instruction.get_start()[1], handle, + instruction.get_start()[4]); +} + +template +void SubProcessor::delete_shuffle(int handle) +{ + shuffler.del(handle); +} + template void SubProcessor::input_personal(const vector& args) { @@ -690,4 +725,25 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } +template +long Processor::sync_Ci(size_t i) const +{ + return sync(read_Ci(i)); +} + +template +long Processor::sync(long x) const +{ + if (not sint::symmetric) + { + // send number to dealer + if (P.my_num() == 0) + P.send_long(P.num_players() - 1, x); + if (not sint::real_shares(P)) + return P.receive_long(0); + } + + return x; +} + #endif diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index 62694221..8527f98f 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -50,7 +50,10 @@ RingMachine::RingMachine(int argc, const char** argv, case L: \ machine.template run, V>(); \ break; - X(64) X(72) X(128) X(192) + X(64) +#ifndef FEWER_RINGS + X(72) X(128) X(192) +#endif #ifdef RING_SIZE X(RING_SIZE) #endif diff --git a/Programs/Source/dijkstra_example.mpc b/Programs/Source/dijkstra_example.mpc new file mode 100644 index 00000000..950fe331 --- /dev/null +++ b/Programs/Source/dijkstra_example.mpc @@ -0,0 +1,50 @@ +# example code for graph with vertices 0,1,2 and with following weights +# 0 -> 1: 5 +# 0 -> 2: 20 +# 1 -> 2: 10 + +# output should be the following +# from 0 to 0 at cost 0 via vertex 0 +# from 0 to 1 at cost 5 via vertex 0 +# from 0 to 2 at cost 15 via vertex 1 + +from oram import OptimalORAM +from dijkstra import dijkstra + +# structure for edges +# contains tuples of form (neighbor, cost, last neighbor bit) +edges = OptimalORAM(4, # number of edges + entry_size=(2, # enough bits for vertices + 5, # enough bits for costs + 1) # always one +) + +# first edge from vertex 0 +edges[0] = (1, 5, 0) +# second and last edge from vertex 0 +edges[1] = (2, 20, 1) +# edge from vertex 1 +edges[2] = (2, 10, 1) +# dummy edge from vertex 2 to itself +edges[3] = (2, 0, 1) + +# structure assigning edge list indices to vertices +e_index = OptimalORAM(3, # number vertices + entry_size=2) # enough bits for edge indices + +# edges from 0 start at 0 +e_index[0] = 0 +# edges from 1 start at 2 +e_index[1] = 2 +# edges from 2 start at 3 +e_index[2] = 3 + +source = sint(0) + +res = dijkstra(source, edges, e_index, OptimalORAM) + +@for_range(res.size) +def _(i): + import util + print_ln('from %s to %s at cost %s via vertex %s', source.reveal(), i, + res[i][0].reveal(), res[i][1].reveal()) diff --git a/Programs/Source/dijkstra_tutorial.mpc b/Programs/Source/dijkstra_tutorial.mpc deleted file mode 100644 index 7ab22023..00000000 --- a/Programs/Source/dijkstra_tutorial.mpc +++ /dev/null @@ -1,9 +0,0 @@ -import dijkstra -from path_oram import OptimalORAM - -n = 1000 - -dist = dijkstra.test_dijkstra_on_cycle(n, OptimalORAM) - -for i in range(n): - print_ln('%s: %s', i, dist[i][0].reveal()) diff --git a/Protocols/Dealer.h b/Protocols/Dealer.h new file mode 100644 index 00000000..cc2c45ba --- /dev/null +++ b/Protocols/Dealer.h @@ -0,0 +1,36 @@ +/* + * Dealer.h + * + */ + +#ifndef PROTOCOLS_DEALER_H_ +#define PROTOCOLS_DEALER_H_ + +#include "Beaver.h" + +template +class Dealer : public Beaver +{ + SeededPRNG G; + +public: + Dealer(Player& P) : + Beaver(P) + { + } + + T get_random() + { + if (T::real_shares(this->P)) + return G.get(); + else + return {}; + } + + vector get_relevant_players() + { + return vector(1, this->P.num_players() - 1); + } +}; + +#endif /* PROTOCOLS_DEALER_H_ */ diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h index 7d0699da..7f0a26dd 100644 --- a/Protocols/DealerInput.h +++ b/Protocols/DealerInput.h @@ -24,6 +24,7 @@ public: DealerInput(SubProcessor& proc, typename T::MAC_Check&); DealerInput(typename T::MAC_Check&, Preprocessing&, Player& P); DealerInput(Player& P); + DealerInput(SubProcessor*, Player& P); ~DealerInput(); bool is_dealer(int player = -1); diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp index 26bfb9a1..8b1ea855 100644 --- a/Protocols/DealerInput.hpp +++ b/Protocols/DealerInput.hpp @@ -10,7 +10,7 @@ template DealerInput::DealerInput(SubProcessor& proc, typename T::MAC_Check&) : - DealerInput(proc.P) + DealerInput(&proc, proc.P) { } @@ -23,6 +23,13 @@ DealerInput::DealerInput(typename T::MAC_Check&, Preprocessing&, template DealerInput::DealerInput(Player& P) : + DealerInput(0, P) +{ +} + +template +DealerInput::DealerInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P), to_send(P), shares(P.num_players()), from_dealer(false), sub_player(P) { @@ -68,8 +75,8 @@ void DealerInput::add_mine(const typename T::open_type& input, if (is_dealer()) { make_share(shares.data(), input, P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(to_send[i]); + for (int i = 0; i < P.num_players() - 1; i++) + shares.at(i).pack(to_send[i]); from_dealer = true; } else diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index 5311f813..4e668136 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -25,6 +25,7 @@ public: void prepare_open(const T& secret); void exchange(const Player& P); typename T::open_type finalize_raw(); + array finalize_several(int n); DealerMC& get_part_MC() { diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index a9ddc035..0f63b93d 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -73,4 +73,11 @@ typename T::open_type DealerMC::finalize_raw() return {}; } +template +array DealerMC::finalize_several(int n) +{ + assert(sub_player); + return internal.finalize_several(n); +} + #endif /* PROTOCOLS_DEALERMC_HPP_ */ diff --git a/Protocols/DealerMatrixPrep.h b/Protocols/DealerMatrixPrep.h new file mode 100644 index 00000000..78739725 --- /dev/null +++ b/Protocols/DealerMatrixPrep.h @@ -0,0 +1,32 @@ +/* + * DealerMatrixPrep.h + * + */ + +#ifndef PROTOCOLS_DEALERMATRIXPREP_H_ +#define PROTOCOLS_DEALERMATRIXPREP_H_ + +#include "ShareMatrix.h" + +template +class DealerMatrixPrep : public BufferPrep> +{ + typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; + + int n_rows, n_inner, n_cols; + + LivePrep* prep; + +public: + DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep&, DataPositions& usage); + + void set_protocol(typename ShareMatrix::Protocol&) + { + } + + void buffer_triples(); +}; + +#endif /* PROTOCOLS_DEALERMATRIXPREP_H_ */ diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp new file mode 100644 index 00000000..faf98ec7 --- /dev/null +++ b/Protocols/DealerMatrixPrep.hpp @@ -0,0 +1,87 @@ +/* + * DealerMatrixPrep.hpp + * + */ + +#include "DealerMatrixPrep.h" + +template +DealerMatrixPrep::DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep& prep, DataPositions& usage) : + super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), + prep(&prep) +{ +} + +template +void append_shares(vector& os, + ValueMatrix& M, PRNG& G) +{ + size_t n = os.size(); + for (auto& value : M.entries) + { + T sum; + for (size_t i = 0; i < n - 2; i++) + { + auto share = G.get(); + sum += share; + share.pack(os[i]); + } + (value - sum).pack(os[n - 2]); + } +} + +template +ShareMatrix receive_shares(octetStream& o, int n, int m) +{ + ShareMatrix res(n, m); + for (size_t i = 0; i < res.entries.size(); i++) + res.entries.v.push_back(o.get()); + return res; +} + +template +void DealerMatrixPrep::buffer_triples() +{ + assert(this->prep); + assert(this->prep->proc); + auto& P = this->prep->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + int batch_size = 100; + if (not T::real_shares(P)) + { + SeededPRNG G; + ValueMatrix A(n_rows, n_inner), B(n_inner, n_cols), + C(n_rows, n_cols); + for (int i = 0; i < P.num_players() - 1; i++) + os[i].reserve( + batch_size * T::size() + * (A.entries.size() + B.entries.size() + + C.entries.size())); + for (int i = 0; i < batch_size; i++) + { + A.randomize(G); + B.randomize(G); + C = A * B; + append_shares(os, A, G); + append_shares(os, B, G); + append_shares(os, C, G); + this->triples.push_back({{{n_rows, n_inner}, {n_inner, n_cols}, + {n_rows, n_cols}}}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < batch_size; i++) + { + auto& o = to_receive.back(); + this->triples.push_back({{receive_shares(o, n_rows, n_inner), + receive_shares(o, n_inner, n_cols), + receive_shares(o, n_rows, n_cols)}}); + } + } +} diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h index ae28ec69..417fdbac 100644 --- a/Protocols/DealerPrep.h +++ b/Protocols/DealerPrep.h @@ -11,6 +11,13 @@ template class DealerPrep : virtual public BitPrep { + friend class DealerMatrixPrep; + + template + void buffer_inverses(true_type); + template + void buffer_inverses(false_type); + template void buffer_edabits(int n_bits, true_type); template @@ -23,8 +30,14 @@ public: } void buffer_triples(); + void buffer_inverses(); void buffer_bits(); + void buffer_inputs(int player) + { + this->buffer_inputs_as_usual(player, this->proc); + } + void buffer_dabits(ThreadQueues* = 0); void buffer_edabits(int n_bits, ThreadQueues*); void buffer_sedabits(int n_bits, ThreadQueues*); diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index d4a0a91d..cc010dd7 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -45,6 +45,57 @@ void DealerPrep::buffer_triples() } } +template +void DealerPrep::buffer_inverses() +{ + buffer_inverses(T::invertible); +} + +template +template +void DealerPrep::buffer_inverses(false_type) +{ + throw not_implemented(); +} + +template +template +void DealerPrep::buffer_inverses(true_type) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T tuple[2]; + while (tuple[0] == 0) + tuple[0] = G.get(); + tuple[1] = tuple[0].invert(); + for (auto& value : tuple) + { + make_share(shares.data(), typename T::clear(value), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + this->inverses.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->inverses.push_back(to_receive.back().get>().get()); + } +} + template void DealerPrep::buffer_bits() { diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h index 38900ff3..e59e1949 100644 --- a/Protocols/DealerShare.h +++ b/Protocols/DealerShare.h @@ -13,12 +13,16 @@ template class DealerPrep; template class DealerInput; template class DealerMC; template class DirectDealerMC; +template class DealerMatrixPrep; +template class Hemi; namespace GC { class DealerSecret; } +template class Dealer; + template class DealerShare : public SemiShare { @@ -30,22 +34,26 @@ public: typedef DealerMC MAC_Check; typedef DirectDealerMC Direct_MC; - typedef Beaver Protocol; + typedef Hemi Protocol; typedef DealerInput Input; typedef DealerPrep LivePrep; typedef ::PrivateOutput PrivateOutput; + typedef DealerMatrixPrep MatrixPrep; + typedef Dealer BasicProtocol; + static false_type dishonest_majority; const static bool needs_ot = false; + const static bool symmetric = false; static string type_short() { return "DD" + string(1, T::type_char()); } - static int threshold(int) + static bool real_shares(const Player& P) { - throw runtime_error("undefined threshold"); + return P.my_num() != P.num_players() - 1; } static This constant(const T& other, int my_num, diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index c0a269d1..e5bb9e9e 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -33,6 +33,7 @@ public: static const bool has_trunc_pr = true; static const bool dishonest_majority = false; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index f43260ea..0aa61bcb 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -13,22 +13,24 @@ * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public Semi +class Hemi : public T::BasicProtocol { - map, HemiMatrixPrep*> matrix_preps; + map, typename T::MatrixPrep*> matrix_preps; DataPositions matrix_usage; + MatrixMC mc; + ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); public: Hemi(Player& P) : - Semi(P) + T::BasicProtocol(P) { } ~Hemi(); - HemiMatrixPrep& get_matrix_prep(const array& dimensions, + typename T::MatrixPrep& get_matrix_prep(const array& dimensions, SubProcessor& processor); void matmulsm(SubProcessor& processor, CheckVector& source, diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 1b3d8f5b..1549e2cf 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -21,12 +21,12 @@ Hemi::~Hemi() } template -HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, +typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, SubProcessor& processor) { if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, - new HemiMatrixPrep(dims[0], dims[1], dims[2], + new typename T::MatrixPrep(dims[0], dims[1], dims[2], dynamic_cast(processor.DataF), matrix_usage)}); return *matrix_preps.at(dims); @@ -52,22 +52,27 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int k = 0; k < dim[1]; k++) + if (not T::real_shares(processor.P)) { - for (int i = 0; i < dim[0]; i++) + matrix_multiply(A, B, processor); + return; + } + + for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { auto kk = Proc->get_Ci().at(dim[4] + k); auto ii = Proc->get_Ci().at(dim[3] + i); - A[{i, k}] = source.at(a + ii * dim[7] + kk); + A.entries.v.push_back(source.at(a + ii * dim[7] + kk)); } + for (int k = 0; k < dim[1]; k++) for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); auto ll = Proc->get_Ci().at(dim[5] + k); - B[{k, j}] = source.at(b + ll * dim[8] + jj); + B.entries.v.push_back(source.at(b + ll * dim[8] + jj)); } - } auto res = matrix_multiply(A, B, processor); @@ -94,13 +99,16 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[1] = min(max_inner, A.n_cols - i); subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); - MatrixMC mc; beaver.init(prep, mc); beaver.init_mul(); - beaver.prepare_mul(A.from(0, i, subdim.data()), - B.from(i, j, subdim.data() + 1)); - beaver.exchange(); - C.add_from_col(j, beaver.finalize_mul()); + bool for_real = T::real_shares(processor.P); + beaver.prepare_mul(A.from(0, i, subdim.data(), for_real), + B.from(i, j, subdim.data() + 1, for_real)); + if (for_real) + { + beaver.exchange(); + C.add_from_col(j, beaver.finalize_mul()); + } } } @@ -150,6 +158,15 @@ void Hemi::conv2ds(SubProcessor& processor, array dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); + if (not T::real_shares(processor.P)) + { + matrix_multiply(A, B, processor); + return; + } + + A.entries.init(); + B.entries.init(); + for (int i_batch = 0; i_batch < batch_size; i_batch ++) { size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index 4a85cbe3..ddf7e186 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -10,6 +10,7 @@ template class HemiPrep; template class Hemi; +template class HemiMatrixPrep; template class HemiShare : public SemiShare @@ -26,6 +27,9 @@ public: typedef typename conditional, Beaver>::type Protocol; typedef HemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = true; static true_type triple_matmul; diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 19d5e72d..fccd2ef5 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -298,7 +298,8 @@ void TreeSum::start(vector& values, const Player& P) { // send from the root player os.reset_write_head(); - for (unsigned int i=0; i::~Direct_MAC_Check() { template void direct_add_openings(vector& values, const PlayerBase& P, vector& os) { - for (unsigned int i=0; i(); } template diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index e855214f..1f745251 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -13,6 +13,7 @@ using namespace std; #include "Tools/PointerVector.h" template class Preprocessing; +template class MatrixMC; /** * Abstract base class for opening protocols @@ -20,6 +21,8 @@ template class Preprocessing; template class MAC_Check_Base { + friend class MatrixMC; + protected: /* MAC Share */ typename T::mac_key_type::Scalar alphai; @@ -59,6 +62,7 @@ public: /// Get next opened value virtual typename T::clear finalize_open(); virtual typename T::open_type finalize_raw(); + array finalize_several(size_t n); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 59c6c5de..47528e00 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -70,6 +70,13 @@ typename T::open_type MAC_Check_Base::finalize_raw() return values.next(); } +template +array MAC_Check_Base::finalize_several(size_t n) +{ + assert(values.left() >= n); + return {{values.skip(0), values.skip(n)}}; +} + template void MAC_Check_Base::CheckFor(const typename T::open_type& value, const vector& shares, const Player& P) diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index e6f3a8a6..7c94b5d8 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,8 +42,12 @@ public: typedef GC::MaliciousRepSecret bit_type; + // indicate security relevance of field size + typedef T mac_key_type; + const static bool expensive = true; static const bool has_trunc_pr = false; + static const bool malicious = true; static string type_short() { diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 17eec6f1..631ef766 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -160,7 +160,7 @@ template void CommMaliciousRepMC::POpen_Begin(vector& values, const vector& S, const Player& P) { - assert(T::length == 2); + assert(T::vector_length == 2); (void)values; os.resize(2); for (auto& o : os) diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index ceedc915..332996dd 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -45,6 +45,8 @@ public: typedef GC::MaliciousCcdSecret bit_type; #endif + static const bool malicious = true; + static string type_short() { return "M" + super::type_short(); diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index afb45662..78627697 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -122,6 +122,7 @@ public: const static bool expensive = false; const static bool variable_players = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index 5e197804..7befb7f4 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -37,6 +37,8 @@ public: typedef GC::Rep4Secret bit_type; + static const bool malicious = true; + static string type_short() { return "R4" + string(1, T::type_char()); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index ba5b85c8..48b01440 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -121,6 +121,8 @@ public: virtual void cisc(SubProcessor&, const Instruction&) { throw runtime_error("CISC instructions not implemented"); } + + virtual vector get_relevant_players(); }; /** @@ -146,7 +148,7 @@ public: static void assign(T& share, const typename T::clear& value, int my_num) { - assert(T::length == 2); + assert(T::vector_length == 2); share.assign_zero(); if (my_num < 2) share[my_num] = value; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 2d9eba57..f398da7f 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -28,7 +28,7 @@ ProtocolBase::ProtocolBase() : template Replicated::Replicated(Player& P) : ReplicatedBase(P) { - assert(T::length == 2); + assert(T::vector_length == 2); } template @@ -152,6 +152,16 @@ T ProtocolBase::get_random() return res; } +template +vector ProtocolBase::get_relevant_players() +{ + vector res; + int n = dynamic_cast(*this).P.num_players(); + for (int i = 0; i < T::threshold(n) + 1; i++) + res.push_back(i); + return res; +} + template void Replicated::init_mul() { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 9bb3c30a..9e1498df 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -71,7 +71,7 @@ public: ReplicatedInput(SubProcessor* proc, Player& P) : PrepLessInput(proc), proc(proc), P(P), protocol(P) { - assert(T::length == 2); + assert(T::vector_length == 2); expect.resize(P.num_players()); this->reset_all(P); } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index e72c0d83..4d875a3b 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -28,7 +28,7 @@ void ReplicatedMC::POpen_Begin(vector&, template void ReplicatedMC::prepare(const vector& S) { - assert(T::length == 2); + assert(T::vector_length == 2); o.reset_write_head(); to_send.reset_write_head(); to_send.reserve(S.size() * T::value_type::size()); diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h new file mode 100644 index 00000000..a90c6e64 --- /dev/null +++ b/Protocols/SecureShuffle.h @@ -0,0 +1,53 @@ +/* + * SecureShuffle.h + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_H_ +#define PROTOCOLS_SECURESHUFFLE_H_ + +#include +using namespace std; + +template class SubProcessor; + +template +class SecureShuffle +{ + SubProcessor& proc; + vector to_shuffle; + vector> config; + vector tmp; + int unit_size; + + vector>>> shuffles; + size_t n_shuffle; + bool exact; + + void player_round(int config_player); + void generate(int config_player, int n_shuffle); + + void waksman(vector& a, int depth, int start); + void cond_swap(T& x, T& y, const T& b); + + void iter_waksman(bool reverse = false); + void waksman_round(int size, bool inwards, bool reverse); + + void pre(vector& a, size_t n, size_t input_base); + void post(vector& a, size_t n, size_t input_base); + +public: + SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc); + + SecureShuffle(SubProcessor& proc); + + int generate(int n_shuffle); + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse); + + void del(int handle); +}; + +#endif /* PROTOCOLS_SECURESHUFFLE_H_ */ diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp new file mode 100644 index 00000000..d2b0676a --- /dev/null +++ b/Protocols/SecureShuffle.hpp @@ -0,0 +1,328 @@ +/* + * SecureShuffle.hpp + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_HPP_ +#define PROTOCOLS_SECURESHUFFLE_HPP_ + +#include "SecureShuffle.h" +#include "Tools/Waksman.h" + +#include +#include + +template +SecureShuffle::SecureShuffle(SubProcessor& proc) : + proc(proc), unit_size(0), n_shuffle(0), exact(false) +{ +} + +template +SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc) : + proc(proc), unit_size(unit_size) +{ + pre(a, n, input_base); + + for (auto i : proc.protocol.get_relevant_players()) + player_round(i); + + post(a, n, output_base); +} + +template +void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse) +{ + this->unit_size = unit_size; + + pre(a, n, input_base); + + auto& shuffle = shuffles.at(handle); + assert(shuffle.size() == proc.protocol.get_relevant_players().size()); + + if (reverse) + for (auto it = shuffle.end(); it > shuffle.begin(); it--) + { + this->config = *(it - 1); + iter_waksman(reverse); + } + else + for (auto& config : shuffle) + { + this->config = config; + iter_waksman(reverse); + } + + post(a, n, output_base); +} + +template +void SecureShuffle::del(int handle) +{ + shuffles.at(handle).clear(); +} + +template +void SecureShuffle::pre(vector& a, size_t n, size_t input_base) +{ + n_shuffle = n / unit_size; + assert(unit_size * n_shuffle == n); + size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; + to_shuffle.clear(); + + if (exact) + { + to_shuffle.resize(n_shuffle_pow2 * unit_size); + for (size_t i = 0; i < n; i++) + to_shuffle[i] = a[input_base + i]; + } + else + { + // sorting power of two elements together with indicator bits + to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle)))); + for (size_t i = 0; i < n_shuffle; i++) + { + for (int j = 0; j < unit_size; j++) + to_shuffle[i * (unit_size + 1) + j] = a[input_base + + i * unit_size + j]; + to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1, + proc.P.my_num(), proc.MC.get_alphai()); + } + this->unit_size++; + } +} + +template +void SecureShuffle::post(vector& a, size_t n, size_t output_base) +{ + if (exact) + for (size_t i = 0; i < n; i++) + a[output_base + i] = to_shuffle[i]; + else + { + auto& MC = proc.MC; + MC.init_open(proc.P); + int shuffle_unit_size = this->unit_size; + int unit_size = shuffle_unit_size - 1; + for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++) + MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1)); + MC.exchange(proc.P); + size_t i_shuffle = 0; + for (size_t i = 0; i < n_shuffle; i++) + { + auto bit = MC.finalize_open(); + if (bit == 1) + { + // only output real elements + for (int j = 0; j < unit_size; j++) + a.at(output_base + i_shuffle * unit_size + j) = + to_shuffle.at(i * shuffle_unit_size + j); + i_shuffle++; + } + } + if (i_shuffle != n_shuffle) + throw runtime_error("incorrect shuffle"); + } +} + +template +void SecureShuffle::player_round(int config_player) +{ + generate(config_player, n_shuffle); + iter_waksman(); +} + +template +int SecureShuffle::generate(int n_shuffle) +{ + int res = shuffles.size(); + shuffles.push_back({}); + auto& shuffle = shuffles.back(); + + for (auto i : proc.protocol.get_relevant_players()) + { + generate(i, n_shuffle); + shuffle.push_back(config); + } + + return res; +} + +template +void SecureShuffle::generate(int config_player, int n) +{ + auto& P = proc.P; + auto& input = proc.input; + input.reset_all(P); + int n_pow2 = 1 << int(ceil(log2(n))); + Waksman waksman(n_pow2); + + if (P.my_num() == config_player) + { + vector perm; + int shuffle_size = n; + for (int j = 0; j < n_pow2; j++) + perm.push_back(j); + SeededPRNG G; + for (int i = 0; i < shuffle_size; i++) + { + int j = G.get_uint(shuffle_size - i); + swap(perm[i], perm[i + j]); + } + + auto config_bits = waksman.configure(perm); + for (size_t i = 0; i < config_bits.size(); i++) + { + auto& x = config_bits[i]; + for (size_t j = 0; j < x.size(); j++) + if (waksman.matters(i, j)) + input.add_mine(int(x[j])); + else + assert(x[j] == 0); + } + } + else + for (size_t i = 0; i < waksman.n_bits(); i++) + input.add_other(config_player); + + input.exchange(); + config.clear(); + typename T::Protocol checker(P); + checker.init(proc.DataF, proc.MC); + checker.init_dotprod(); + auto one = T::constant(1, P.my_num(), proc.MC.get_alphai()); + for (size_t i = 0; i < waksman.n_rounds(); i++) + { + config.push_back({}); + for (int j = 0; j < n_pow2; j++) + { + if (waksman.matters(i, j)) + { + config.back().push_back(input.finalize(config_player)); + if (T::malicious) + checker.prepare_dotprod(config.back().back(), + one - config.back().back()); + } + else + config.back().push_back({}); + } + } + + if (T::malicious) + { + checker.next_dotprod(); + checker.exchange(); + assert( + typename T::clear( + proc.MC.open(checker.finalize_dotprod(waksman.n_bits()), + P)) == 0); + checker.check(); + } +} + +template +void SecureShuffle::waksman(vector& a, int depth, int start) +{ + int n = a.size(); + + if (n == 2) + { + cond_swap(a[0], a[1], config.at(depth).at(start)); + return; + } + + vector a0(n / 2), a1(n / 2); + for (int i = 0; i < n / 2; i++) + { + a0.at(i) = a.at(2 * i); + a1.at(i) = a.at(2 * i + 1); + + cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2)); + } + + waksman(a0, depth + 1, start); + waksman(a1, depth + 1, start + n / 2); + + for (int i = 0; i < n / 2; i++) + { + a.at(2 * i) = a0.at(i); + a.at(2 * i + 1) = a1.at(i); + cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start)); + } +} + +template +void SecureShuffle::cond_swap(T& x, T& y, const T& b) +{ + auto diff = proc.protocol.mul(x - y, b); + x -= diff; + y += diff; +} + +template +void SecureShuffle::iter_waksman(bool reverse) +{ + int n = to_shuffle.size() / unit_size; + + for (int depth = 0; depth < log2(n); depth++) + waksman_round(depth, true, reverse); + + for (int depth = log2(n) - 2; depth >= 0; depth--) + waksman_round(depth, false, reverse); +} + +template +void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) +{ + int n = to_shuffle.size() / unit_size; + assert((int) config.at(depth).size() == n); + int nblocks = 1 << depth; + int size = n / (2 * nblocks); + bool outwards = !inwards; + proc.protocol.init_mul(); + vector> indices; + indices.reserve(n / 2); + Waksman waksman(n); + for (int k = 0; k < n / 2; k++) + { + int j = k % size; + int i = k / size; + int base = 2 * i * size; + int in1 = base + j + j * inwards; + int in2 = in1 + inwards + size * outwards; + int out1 = base + j + j * outwards; + int out2 = out1 + outwards + size * inwards; + int i_bit = base + j + size * (outwards ^ reverse); + bool run = waksman.matters(depth, i_bit); + if (run) + { + for (int l = 0; l < unit_size; l++) + proc.protocol.prepare_mul(config.at(depth).at(i_bit), + to_shuffle.at(in1 * unit_size + l) + - to_shuffle.at(in2 * unit_size + l)); + } + indices.push_back({{in1, in2, out1, out2, run}}); + } + proc.protocol.exchange(); + tmp.resize(to_shuffle.size()); + for (int k = 0; k < n / 2; k++) + { + auto idx = indices.at(k); + for (int l = 0; l < unit_size; l++) + { + T diff; + if (idx[4]) + diff = proc.protocol.finalize_mul(); + tmp.at(idx[2] * unit_size + l) = to_shuffle.at( + idx[0] * unit_size + l) - diff; + tmp.at(idx[3] * unit_size + l) = to_shuffle.at( + idx[1] * unit_size + l) + diff; + } + } + swap(tmp, to_shuffle); +} + +#endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index b306d5c3..432b599b 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -78,6 +78,7 @@ public: const static bool variable_players = true; const static bool expensive = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index aea0bb97..bf40cb28 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -49,6 +49,7 @@ public: const static bool dishonest_majority = false; const static bool variable_players = true; const static bool expensive = false; + const static bool malicious = true; static string type_short() { diff --git a/Protocols/Share.h b/Protocols/Share.h index e2a9f0bb..9ca86cea 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -56,6 +56,7 @@ class Share_ : public ShareInterface const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; const static bool has_mac = true; + static const bool malicious = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index e5af8ddd..a8ef7a22 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -40,12 +40,17 @@ public: static const bool has_trunc_pr = false; static const bool has_split = false; static const bool has_mac = false; + static const bool malicious = false; static const false_type triple_matmul; + const static bool symmetric = true; + static const int default_length = 1; - static string type_short() { return "undef"; } + static string type_short() { throw runtime_error("don't call this"); } + + static bool real_shares(const Player&) { return true; } template static void split(vector, vector, int, T*, int, @@ -63,6 +68,8 @@ public: template static void generate_mac_key(T&, U&) {} + + static int threshold(int) { throw runtime_error("undefined threshold"); } }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index 7f84213e..b31aa708 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -14,6 +14,124 @@ using namespace std; template class MatrixMC; +template +class NonInitVector +{ + template friend class NonInitVector; + + size_t size_; +public: + AddableVector v; + + NonInitVector(size_t size) : + size_(size) + { + v.reserve(size); + } + + template + NonInitVector(const NonInitVector& other) : + size_(other.size()), v(other.v) + { + } + + size_t size() const + { + return size_; + } + + void init() + { + v.resize(size_); + } + + void check() const + { +#ifdef DEBUG_MATRIX + assert(not v.empty()); +#endif + } + + typename vector::iterator begin() + { + check(); + return v.begin(); + } + + typename vector::iterator end() + { + check(); + return v.end(); + } + + T& at(size_t index) + { + check(); + return v.at(index); + } + + const T& at(size_t index) const + { +#ifdef DEBUG_MATRIX + assert(index < size()); +#endif + return (*this)[index]; + } + + T& operator[](size_t index) + { + check(); + return v[index]; + } + + const T& operator[](size_t index) const + { + check(); + return v[index]; + } + + NonInitVector operator-(const NonInitVector& other) const + { + assert(size() == other.size()); + NonInitVector res(size()); + if (other.v.empty()) + return *this; + else if (v.empty()) + { + res.init(); + res.v = res.v - other.v; + } + else + res.v = v - other.v; + return res; + } + + NonInitVector& operator+=(const NonInitVector& other) + { + assert(size() == other.size()); + if (not other.v.empty()) + { + if (v.empty()) + *this = other; + else + v += other.v; + } + return *this; + } + + bool operator!=(const NonInitVector& other) const + { + return v != other.v; + } + + void randomize(PRNG& G) + { + v.clear(); + for (size_t i = 0; i < size(); i++) + v.push_back(G.get()); + } +}; + template class ValueMatrix : public ValueInterface { @@ -21,7 +139,7 @@ class ValueMatrix : public ValueInterface public: int n_rows, n_cols; - AddableVector entries; + NonInitVector entries; static DataFieldType field_type() { @@ -48,15 +166,19 @@ public: T& operator[](const pair& indices) { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } const T& operator[](const pair& indices) const { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } @@ -80,6 +202,9 @@ public: { assert(n_cols == other.n_rows); This res(n_rows, other.n_cols); + if (entries.v.empty() or other.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < n_rows; i++) for (int j = 0; j < other.n_cols; j++) for (int k = 0; k < n_cols; k++) @@ -103,9 +228,9 @@ public: ValueMatrix transpose() const { ValueMatrix res(this->n_cols, this->n_rows); - for (int i = 0; i < this->n_rows; i++) - for (int j = 0; j < this->n_cols; j++) - res[{j, i}] = (*this)[{i, j}]; + for (int j = 0; j < this->n_cols; j++) + for (int i = 0; i < this->n_rows; i++) + res.entries.v.push_back((*this)[{i, j}]); return res; } @@ -139,7 +264,7 @@ public: { This res(other.n_rows, other.n_cols); for (size_t i = 0; i < other.entries.size(); i++) - res.entries[i] = T::constant(other.entries[i], my_num, key); + res.entries.v.push_back(T::constant(other.entries[i], my_num, key)); res.check(); return res; } @@ -167,24 +292,29 @@ public: ShareMatrix from_col(int start, int size) const { ShareMatrix res(this->n_rows, min(size, this->n_cols - start)); + res.entries.clear(); for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{i, start + j}]; + res.entries.v.push_back((*this)[{i, start + j}]); return res; } - ShareMatrix from(int start_row, int start_col, int* sizes) const + ShareMatrix from(int start_row, int start_col, int* sizes, bool for_real = + true) const { ShareMatrix res(min(sizes[0], this->n_rows - start_row), min(sizes[1], this->n_cols - start_col)); + if (not for_real) + return res; for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{start_row + i, start_col + j}]; + res.entries.v.push_back((*this)[{start_row + i, start_col + j}]); return res; } void add_from_col(int start, const ShareMatrix& other) { + this->entries.init(); for (int i = 0; i < this->n_rows; i++) for (int j = 0; j < other.n_cols; j++) (*this)[{i, start + j}] += other[{i, j}]; @@ -197,6 +327,9 @@ ShareMatrix operator*(const ValueMatrix& a, { assert(a.n_cols == b.n_rows); ShareMatrix res(a.n_rows, b.n_cols); + if (a.entries.v.empty() or b.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < a.n_rows; i++) for (int j = 0; j < b.n_cols; j++) for (int k = 0; k < a.n_cols; k++) @@ -208,9 +341,22 @@ ShareMatrix operator*(const ValueMatrix& a, template class MatrixMC : public MAC_Check_Base> { - typename T::MAC_Check inner; + typename T::MAC_Check& inner; public: + MatrixMC() : + inner( + *(OnlineOptions::singleton.direct ? + new typename T::Direct_MC : + new typename T::MAC_Check)) + { + } + + ~MatrixMC() + { + delete &inner; + } + void exchange(const Player& P) { inner.init_open(P); @@ -224,8 +370,15 @@ public: for (auto& share : this->secrets) { this->values.push_back({share.n_rows, share.n_cols}); - for (auto& entry : this->values.back().entries) - entry = inner.finalize_open(); + if (share.entries.v.empty()) + for (size_t i = 0; i < share.entries.size(); i++) + inner.finalize_open(); + else + { + auto range = inner.finalize_several(share.entries.size()); + auto& v = this->values.back().entries.v; + v.insert(v.begin(), range[0], range[1]); + } } } }; diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h index f4f37dcd..049881ff 100644 --- a/Protocols/TemiShare.h +++ b/Protocols/TemiShare.h @@ -25,6 +25,9 @@ public: typedef typename conditional, Beaver>::type Protocol; typedef TemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = false; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index bae415c4..63b058e0 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -130,7 +130,6 @@ template void make_share(DealerShare* Sa, const T& a, int N, const U&, PRNG& G) { make_share((SemiShare*) Sa, a, N - 1, U(), G); - Sa[N - 1] = {}; } template @@ -273,6 +272,11 @@ inline string mac_filename(string directory, int playerno) + to_string(playerno); } +template <> +inline void write_mac_key(const string&, int, int, GC::NoValue) +{ +} + template void write_mac_key(const string& directory, int i, int nplayers, U key) { @@ -301,6 +305,11 @@ void read_mac_key(const string& directory, const Names& N, T& key) read_mac_key(directory, N.my_num(), N.num_players(), key); } +template <> +inline void read_mac_key(const string&, int, int, GC::NoValue&) +{ +} + template void read_mac_key(const string& directory, int player_num, int nplayers, U& key) { @@ -367,7 +376,7 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, } template -void read_global_mac_key(const string& directory, int nparties, U& key, false_type) +void read_global_mac_key(const string& directory, int nparties, U& key) { U pp; key.assign_zero(); @@ -383,17 +392,11 @@ void read_global_mac_key(const string& directory, int nparties, U& key, false_ty cout << "Final Keys : " << key << endl; } -template -void read_global_mac_key(const string&, int, U&, true_type) +template <> +inline void read_global_mac_key(const string&, int, GC::NoValue&) { } -template -void read_global_mac_key(const string& directory, int nparties, U& key) -{ - read_global_mac_key(directory, nparties, key, is_same()); -} - template T reconstruct(vector& shares) { @@ -579,14 +582,14 @@ void plain_edabits(vector& as, as.resize(max_size); bs.clear(); bs.resize(length); - bigint value; + Z2 value; for (int j = 0; j < max_size; j++) { if (not zero) - G.get_bigint(value, length, true); + value.randomize_part(G, length); as[j] = value; for (int k = 0; k < length; k++) - bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + bs[k] ^= BitVec(value.get_bit(k)) << j; } } diff --git a/README.md b/README.md index a3f6741f..cd0e9781 100644 --- a/README.md +++ b/README.md @@ -101,8 +101,9 @@ The following table lists all protocols that are fully supported. | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep3 / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Malicious, honest supermajority | [Rep4](#honest-majority) | [Rep4](#honest-majority) | [Rep4](#honest-majority) | N/A | | Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | Modulo prime and modulo 2^k are the two settings that allow @@ -280,6 +281,8 @@ compute the preprocessing time for a particular computation. - Python 3.5 or later - NTL library for homomorphic encryption (optional; tested with NTL 10.5) - If using macOS, Sierra or later + - Windows/VirtualBox: see [this + issue](https://github.com/data61/MP-SPDZ/issues/557) for a discussion #### Compilation diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index c7f8c371..ec39b772 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -84,7 +84,9 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil { } -gf2n_not_supported::gf2n_not_supported(int n) : - runtime_error("GF(2^" + to_string(n) + ") not supported") +gf2n_not_supported::gf2n_not_supported(int n, string options) : + runtime_error( + "GF(2^" + to_string(n) + ") not supported" + + (options.empty() ? "" : ", options are " + options)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index bb347c6a..a3ca3a5d 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -281,7 +281,7 @@ public: class gf2n_not_supported : public runtime_error { public: - gf2n_not_supported(int n); + gf2n_not_supported(int n, string options = ""); }; #endif diff --git a/Tools/PointerVector.h b/Tools/PointerVector.h index 32d1b46e..404c4ee9 100644 --- a/Tools/PointerVector.h +++ b/Tools/PointerVector.h @@ -30,6 +30,15 @@ public: { return (*this)[i++]; } + T* skip(size_t n) + { + i += n; + return &(*this)[i]; + } + size_t left() + { + return this->size() - i; + } }; #endif /* TOOLS_POINTERVECTOR_H_ */ diff --git a/Tools/Waksman.cpp b/Tools/Waksman.cpp new file mode 100644 index 00000000..a54b7766 --- /dev/null +++ b/Tools/Waksman.cpp @@ -0,0 +1,91 @@ +/* + * Waksman.cpp + * + */ + +#include "Waksman.h" + +#include +#include +#include + +template +void append(vector& x, const vector& y) +{ + x.insert(x.end(), y.begin(), y.end()); +} + +vector > Waksman::configure(const vector& perm) +{ + int n = perm.size(); + assert(n > 1); + + if (n == 2) + return {{perm[0] == 1, perm[0] == 1}}; + + vector I(n / 2); + vector O(n / 2, -1); + vector p0(n / 2, -1), p1(n / 2, -1), inv_perm(n); + + for (int i = 0; i < n; i++) + inv_perm[perm[i]] = i; + + while (true) + { + auto it = find(O.begin(), O.end(), -1); + if (it == O.end()) + break; + int j = 2 * (it - O.begin()); + O.at(j / 2) = 0; + int j0 = j; + + while (true) + { + int i = inv_perm.at(j); + p0.at(i / 2) = j / 2; + I.at(i / 2) = i % 2; + O.at(j / 2) = j % 2; + if (i % 2 == 1) + i--; + else + i++; + j = perm.at(i); + if (j % 2 == 1) + j--; + else + j++; + p1.at(i / 2) = perm.at(i) / 2; + if (j == j0) + break; + } + + if ((find(p1.begin(), p1.end(), -1) == p1.end()) + and (find(p0.begin(), p0.end(), -1) == p0.end())) + break; + } + + auto p0_config = configure(p0); + auto p1_config = configure(p1); + + vector> res; + res.push_back(I); + for (auto& x : O) + res.back().push_back(x); + + assert(p0_config.size() == p1_config.size()); + + for (size_t i = 0; i < p0_config.size(); i++) + { + res.push_back(p0_config.at(i)); + append(res.back(), p1_config.at(i)); + } + + assert(res.size() == Waksman(perm.size()).n_rounds()); + return res; +} + +Waksman::Waksman(int n_elements) : + n_elements(n_elements), nr(log2(n_elements)) +{ + assert(n_elements == (1 << nr)); +} diff --git a/Tools/Waksman.h b/Tools/Waksman.h new file mode 100644 index 00000000..521e990f --- /dev/null +++ b/Tools/Waksman.h @@ -0,0 +1,39 @@ +/* + * Waksman.h + * + */ + +#ifndef TOOLS_WAKSMAN_H_ +#define TOOLS_WAKSMAN_H_ + +#include +using namespace std; + +class Waksman +{ + int n_elements; + int nr; + +public: + static vector> configure(const vector& perm); + + Waksman(int n_elements); + + size_t n_rounds() const + { + return nr; + } + + bool matters(int i, int j) const + { + int block = n_elements >> i; + return block == 2 or j % block != block / 2; + } + + size_t n_bits() const + { + return nr * n_elements - (1 << (nr - 1)) + 1; + } +}; + +#endif /* TOOLS_WAKSMAN_H_ */ diff --git a/Utils/he-example.cpp b/Utils/he-example.cpp new file mode 100644 index 00000000..179028a5 --- /dev/null +++ b/Utils/he-example.cpp @@ -0,0 +1,97 @@ +/* + * he-example.cpp + * + */ + +#include "FHE/FHE_Params.h" +#include "FHE/NTL-Subs.h" +#include "FHE/FHE_Keys.h" +#include "FHE/Plaintext.h" + +void first_phase(string filename, int n_mults, int circuit_sec); +void second_phase(string filename); + +int main() +{ + for (int n_mults = 0; n_mults < 2; n_mults++) + for (int sec = 0; sec <= 120; sec += 40) + { + string filename = "mp-spdz-he"; + first_phase(filename, n_mults, sec); + second_phase(filename); + } +} + +void first_phase(string filename, int n_mults, int circuit_sec) +{ + // specify number of multiplications (at most one) and function privacy parameter + // increase the latter to accommodate more operations + FHE_Params params(n_mults, circuit_sec); + + // generate parameters for computation modulo a 32-bit prime + params.basic_generation_mod_prime(32); + + // find computation modulus (depends on parameter generation) + cout << "computation modulo " << params.get_plaintext_modulus() << endl; + + // generate key pair + FHE_KeyPair pair(params); + pair.generate(); + + Plaintext_mod_prime plaintext(params); + + // set first two plaintext slots + plaintext.set_element(0, 4); + plaintext.set_element(1, -1); + + // encrypt + Ciphertext ciphertext = pair.pk.encrypt(plaintext); + + // store for second phase + octetStream os; + params.pack(os); + pair.pk.pack(os); + ciphertext.pack(os); + plaintext.pack(os); + pair.sk.pack(os); + ofstream out(filename); + os.output(out); +} + +void second_phase(string filename) +{ + // read from file + ifstream in(filename); + octetStream os; + os.input(in); + FHE_Params params; + FHE_PK pk(params); + FHE_SK sk(params); + Plaintext_mod_prime plaintext(params); + Ciphertext ciphertext(params); + + // parameter must be set correctly first + params.unpack(os); + pk.unpack(os); + ciphertext.unpack(os); + plaintext.unpack(os); + + if (params.n_mults() == 0) + // public-private multiplication is always available + ciphertext *= plaintext; + else + // private-private multiplication only with matching parameters + ciphertext = ciphertext.mul(pk, ciphertext); + + // re-randomize for circuit privacy + ciphertext.rerandomize(pk); + + // read secret key and decrypt + sk.unpack(os); + plaintext = sk.decrypt(ciphertext); + + cout << "should be 16: " << plaintext.element(0) << endl; + cout << "should be 1: " << plaintext.element(1) << endl; + assert(plaintext.element(0) == 16); + assert(plaintext.element(1) == 1); +} diff --git a/doc/Doxyfile b/doc/Doxyfile index 9820ba50..5f1143e3 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h ../FHE/Ciphertext.h ../FHE/FHE_Keys.h ../FHE/FHE_Params.h ../FHE/Plaintext.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/homomorphic-encryption.rst b/doc/homomorphic-encryption.rst new file mode 100644 index 00000000..95c922fd --- /dev/null +++ b/doc/homomorphic-encryption.rst @@ -0,0 +1,31 @@ +Homomorphic Encryption +---------------------- + +MP-SPDZ uses BGV encryption for triple generation in a number of +protocols. This involves zero-knowledge proofs in some protocols and +considerations about function privacy in all of them. The interface +described below allows directly accessing the basic cryptographic +operations in contexts where these considerations are not relevant. +See ``Utils/he-example.cpp`` for some example code. + + +Reference +~~~~~~~~~ + +.. doxygenclass:: FHE_Params + :members: + +.. doxygenclass:: FHE_KeyPair + :members: + +.. doxygenclass:: FHE_SK + :members: + +.. doxygenclass:: FHE_PK + :members: + +.. doxygenclass:: Plaintext + :members: + +.. doxygenclass:: Ciphertext + :members: diff --git a/doc/index.rst b/doc/index.rst index d2a2c4dc..59caa58d 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -175,6 +175,7 @@ Reference non-linear preprocessing add-protocol + homomorphic-encryption troubleshooting diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 26808480..8b02fa3e 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -148,12 +148,14 @@ AVX/AVX2 instructions are deactivated (see e.g. `here `_), which causes a dramatic performance loss. Deactivate Hyper-V/Hypervisor using:: + bcdedit /set hypervisorlaunchtype off DISM /Online /Disable-Feature:Microsoft-Hyper-V Performance can be further increased when compiling MP-SPDZ yourself: :: + sudo apt-get update sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm git clone https://github.com/data61/MP-SPDZ.git From 88534961b3492b7804f2de0d8425f5ee0b401bdb Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 2 Jun 2022 17:12:11 +0200 Subject: [PATCH 058/265] Fix biases in PRNG. --- Tools/random.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Tools/random.cpp b/Tools/random.cpp index 7cf1924f..94f97cf6 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -179,10 +179,11 @@ unsigned int PRNG::get_uint(int upper) } // not power of 2 unsigned int r, reduced; + bool use_char = upper <= 128; do { - r = (upper < 255) ? get_uchar() : get_uint(); + r = use_char ? get_uchar() : get_uint(); reduced = r % upper; - } while (int(r - reduced + (upper - 1)) < 0); + } while (int(r - reduced + (upper - 1)) > (use_char ? 256 : 0)); return reduced; } @@ -260,7 +261,9 @@ void PRNG::get_bigint(bigint& res, int n_bits, bool positive) octet* bytes = (octet*) words; words[n_words - 1] = 0; get_octets(bytes, n_bytes); - octet mask = (1 << (n_bits % 8)) - 1; + octet mask = -1; + if (n_bits % 8 > 0) + mask = (1 << (n_bits % 8)) - 1; bytes[n_bytes - 1] &= mask; mpz_import(res.get_mpz_t(), n_words, -1, sizeof(word), -1, 0, bytes); if (not positive and (get_bit())) From 6755a8fa5105be5d196f2a8b71e2609fad564611 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 13 Jun 2022 11:20:55 +0200 Subject: [PATCH 059/265] Python client example. --- ExternalIO/README.md | 10 ++- ExternalIO/bankers-bonus-client.py | 35 +++++++++ ExternalIO/client.py | 113 +++++++++++++++++++++++++++++ ExternalIO/domains.py | 67 +++++++++++++++++ 4 files changed, 223 insertions(+), 2 deletions(-) create mode 100755 ExternalIO/bankers-bonus-client.py create mode 100644 ExternalIO/client.py create mode 100644 ExternalIO/domains.py diff --git a/ExternalIO/README.md b/ExternalIO/README.md index d4f99288..f5f418ed 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -2,12 +2,13 @@ The ExternalIO directory contains an example of managing I/O between external cl ## Working Examples -[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a +[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) and +[bankers-bonus-client.py](./bankers-bonus-client.py) act as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output as described by [Damgård et al.](https://eprint.iacr.org/2015/1006) The computation allows up to eight clients to input a number and computes the client -with the largest input. You can run it as follows from the main +with the largest input. You can run the C++ code as follows from the main directory: ``` make bankers-bonus-client.x @@ -30,6 +31,11 @@ protocol script. The setup scripts generate the necessary SSL certificates and keys. Therefore, if you run the computation on different hosts, you will have to distribute the `*.pem` files. +For the Python client, make sure to install +[gmpy2](https://pypi.org/project/gmpy2), and run +`ExternalIO/bankers-bonus-client.py` instead of +`bankers-bonus-client.x`. + ## I/O MPC Instructions ### Connection Setup diff --git a/ExternalIO/bankers-bonus-client.py b/ExternalIO/bankers-bonus-client.py new file mode 100755 index 00000000..d0f8d285 --- /dev/null +++ b/ExternalIO/bankers-bonus-client.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 + +import sys + +sys.path.append('.') + +from client import * +from domains import * + +client_id = int(sys.argv[1]) +n_parties = int(sys.argv[2]) +bonus = float(sys.argv[3]) +finish = int(sys.argv[4]) + +client = Client(['localhost'] * n_parties, 14000, client_id) + +type = client.specification.get_int(4) + +if type == ord('R'): + domain = Z2(client.specification.get_int(4)) +elif type == ord('p'): + domain = Fp(client.specification.get_bigint()) +else: + raise Exception('invalid type') + +for socket in client.sockets: + os = octetStream() + os.store(finish) + os.Send(socket) + +for x in bonus, bonus * 2 ** 16: + client.send_private_inputs([domain(x)]) + + print('Winning client id is :', + client.receive_outputs(domain, 1)[0].v % 2 ** 64) diff --git a/ExternalIO/client.py b/ExternalIO/client.py new file mode 100644 index 00000000..819647ff --- /dev/null +++ b/ExternalIO/client.py @@ -0,0 +1,113 @@ +import socket, ssl +import struct +import time + +class Client: + def __init__(self, hostnames, port_base, my_client_id): + ctx = ssl.SSLContext() + name = 'C%d' % my_client_id + prefix = 'Player-Data/%s' % name + ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key') + ctx.load_verify_locations(capath='Player-Data') + + self.sockets = [] + for i, hostname in enumerate(hostnames): + for j in range(10000): + try: + plain_socket = socket.create_connection( + (hostname, port_base + i)) + break + except ConnectionRefusedError: + if j < 60: + time.sleep(1) + else: + raise + octetStream(b'%d' % my_client_id).Send(plain_socket) + self.sockets.append(ctx.wrap_socket(plain_socket, + server_hostname='P%d' % i)) + + self.specification = octetStream() + self.specification.Receive(self.sockets[0]) + + def receive_triples(self, T, n): + triples = [[0, 0, 0] * n] + os = octetStream() + for socket in self.sockets: + os.Receive(socket) + for triple in triples: + for i in range(3): + t = T() + t.unpack(os) + triple[i] += t + res = [] + for triple in triples: + prod = triple[0] * triple[1] + if prod != triple[2]: + raise Exception( + 'invalid triple, diff %s' % hex(prod.v - triple[2].v)) + return triples + + def send_private_inputs(self, values): + T = type(values[0]) + triples = self.receive_triples(T, len(values)) + os = octetStream() + for value, triple in zip(values, triples): + (value + triple[0]).pack(os) + for socket in self.sockets: + os.Send(socket) + + def receive_outputs(self, T, n): + triples = self.receive_triples(T, n) + return [triple[0] for triple in triples] + +class octetStream: + def __init__(self, value=None): + self.buf = b'' + self.ptr = 0 + if value is not None: + self.buf += value + + def reset_write_head(self): + self.buf = b'' + self.ptr = 0 + + def Send(self, socket): + socket.send(struct.pack('= 0) + + def __add__(self, other): + try: + res = self.v + other.v + except: + res = self.v + other + return type(self)(res) + + def __mul__(self, other): + try: + res = self.v * other.v + except: + res = self.v * other + return type(self)(res) + + __radd__ = __add__ + + def __eq__(self, other): + return self.v == other.v + + def __neq__(self, other): + return self.v != other.v + + def unpack(self, os): + self.v = 0 + buf = os.consume(self.n_bytes) + for i, b in enumerate(buf): + self.v += b << (i * 8) + + def pack(self, os): + v = self.v + for i in range(self.n_bytes): + os.buf += struct.pack('>= 8 + +def Z2(k): + class Z(Domain): + modulus = 2 ** k + n_words = (k + 63) // 64 + n_bytes = (k + 7) // 8 + + return Z + +def Fp(mod): + import gmpy2 + + class Fp(Domain): + modulus = mod + n_words = (modulus.bit_length() + 63) // 64 + n_bytes = 8 * n_words + R = 2 ** (64 * n_words) % modulus + R_inv = gmpy2.invert(R, modulus) + + def unpack(self, os): + Domain.unpack(self, os) + self.v = self.v * self.R_inv % self.modulus + + def pack(self, os): + Domain.pack(type(self)(self.v * self.R), os) + + return Fp From 4c8e616b58710ddc118d5dee1de6ac7be41c908c Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 14 Jun 2022 16:14:37 +0200 Subject: [PATCH 060/265] Improved binary circuit functionality. --- Compiler/GC/types.py | 35 +++++++++++++++++++++++++++++++---- Compiler/types.py | 3 ++- Compiler/util.py | 5 +++++ Processor/Instruction.hpp | 1 + 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index fdd98722..4287844a 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -661,6 +661,9 @@ class sbitvec(_vec): return sbit.malloc(size * n, creator_tape=creator_tape) @staticmethod def n_elements(): + return 1 + @staticmethod + def mem_size(): return n @classmethod def get_input_from(cls, player): @@ -692,22 +695,33 @@ class sbitvec(_vec): self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n @classmethod - def load_mem(cls, address): + def load_mem(cls, address, size=None): + if size not in (None, 1): + assert isinstance(address, int) or len(address) == 1 + sb = sbits.get_type(size) + return cls.from_vec(sb.bit_compose( + sbit.load_mem(address + i + j * n) for j in range(size)) + for i in range(n)) if not isinstance(address, int) and len(address) == n: return cls.from_vec(sbit.load_mem(x) for x in address) else: return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): + size = 1 for x in self.v: - assert util.is_constant(x) or x.n == 1 - v = [sbit.conv(x) for x in self.v] + if not util.is_constant(x): + size = max(size, x.n) + v = [sbits.get_type(size).conv(x) for x in self.v] if not isinstance(address, int) and len(address) == n: + assert max_n == 1 for x, y in zip(v, address): x.store_in_mem(y) else: + assert isinstance(address, int) or len(address) == 1 for i in range(n): - v[i].store_in_mem(address + i) + for j, x in enumerate(v[i].bit_decompose()): + x.store_in_mem(address + i + j * n) def reveal(self): if len(self) > cbits.unit: return self.elements()[0].reveal() @@ -861,6 +875,19 @@ class sbitvec(_vec): return self ^ other def right_shift(self, m, k, security=None, signed=True): return self.from_vec(self.v[m:]) + def tree_reduce(self, function): + elements = self.elements() + while len(elements) > 1: + size = len(elements) + half = size // 2 + left = elements[:half] + right = elements[half:2*half] + odd = elements[2*half:] + sides = [self.from_vec(sbitvec(x).v) for x in (left, right)] + red = function(*sides) + elements = red.elements() + elements += odd + return self.from_vec(sbitvec(elements).v) class bit(object): n = 1 diff --git a/Compiler/types.py b/Compiler/types.py index 098f493f..735fddea 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5701,7 +5701,8 @@ class SubMultiArray(_vectorizable): self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ self.address + index * self.sizes[1] * - self.value_type.n_elements(), \ + self.value_type.n_elements() * \ + self.value_type.mem_size(), \ debug=self.debug) else: self.sub_cache[key] = \ diff --git a/Compiler/util.py b/Compiler/util.py index aa491e42..9d84df22 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -116,6 +116,11 @@ def round_to_int(x): return x.round_to_int() def tree_reduce(function, sequence): + try: + return sequence.tree_reduce(function) + except AttributeError: + pass + sequence = list(sequence) assert len(sequence) > 0 n = len(sequence) diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 5bed3703..a0f7a490 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -730,6 +730,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case ANDM: case NOTS: case NOTCB: + case TRANS: size = DIV_CEIL(n, 64); break; case CONVCBIT2S: From ec1d302b03bb8e747ef3ce51d7a04e50c2c8f796 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 23 Jun 2022 14:42:54 +0200 Subject: [PATCH 061/265] Local right shift for GF(2^n). --- Compiler/instructions.py | 1 + Compiler/instructions_base.py | 4 ++-- Compiler/types.py | 24 +++++++++++++++--------- Processor/Instruction.h | 5 +++-- Processor/Instruction.hpp | 4 ++++ Protocols/Rep3Share.h | 2 +- Protocols/Semi2kShare.h | 11 ----------- Protocols/SemiShare.h | 25 +++++++++++++++++++++++++ 8 files changed, 51 insertions(+), 25 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 5f5b82db..aac0c34c 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1051,6 +1051,7 @@ class shrci(base.ClearShiftInstruction): code = base.opcodes['SHRCI'] op = '__rshift__' +@base.gf2n @base.vectorize class shrsi(base.ClearShiftInstruction): """ Bitwise right shift of secret register (vector) by (constant) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d598d8a7..3a56e604 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -207,8 +207,8 @@ opcodes = dict( CONDPRINTPLAIN = 0xE1, INTOUTPUT = 0xE6, FLOATOUTPUT = 0xE7, - GBITDEC = 0x184, - GBITCOM = 0x185, + GBITDEC = 0x18A, + GBITCOM = 0x18B, # Secure socket INITSECURESOCKET = 0x1BA, RESPSECURESOCKET = 0x1BB diff --git a/Compiler/types.py b/Compiler/types.py index 735fddea..93991df1 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2126,6 +2126,21 @@ class _secret(_register, _secret_structure): res = personal(player, masked.reveal() - mask[1]) return res + @set_instruction_type + @vectorize + def raw_right_shift(self, length): + """ Local right shift in supported protocols. + In integer-like protocols, the output is potentially off by one. + + :param length: number of bits + """ + res = type(self)() + shrsi(res, self, length) + return res + + def raw_mod2m(self, m): + return self - (self.raw_right_shift(m) << m) + class sint(_secret, _int): """ @@ -2668,15 +2683,6 @@ class sint(_secret, _int): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize - def raw_right_shift(self, length): - res = sint() - shrsi(res, self, length) - return res - - def raw_mod2m(self, m): - return self - (self.raw_right_shift(m) << m) - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. diff --git a/Processor/Instruction.h b/Processor/Instruction.h index fd91e35d..f3caf565 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -284,8 +284,9 @@ enum // Bitwise shifts GSHLCI = 0x182, GSHRCI = 0x183, - GBITDEC = 0x184, - GBITCOM = 0x185, + GSHRSI = 0x184, + GBITDEC = 0x18A, + GBITCOM = 0x18B, // Conversion GCONVINT = 0x1C0, GCONVGF2N = 0x1C1, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index a0f7a490..5b0589b6 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -198,6 +198,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case GORCI: case GSHLCI: case GSHRCI: + case GSHRSI: case USE: case USE_INP: case USE_EDABIT: @@ -1006,6 +1007,9 @@ inline void Instruction::execute(Processor& Proc) const case SHRSI: sint::shrsi(Procp, *this); return; + case GSHRSI: + sgf2n::shrsi(Proc2, *this); + return; case OPEN: Proc.Procp.POpen(start, Proc.P, size); return; diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 78627697..fb02d26f 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -71,7 +71,7 @@ public: template static void shrsi(SubProcessor& proc, const Instruction& inst) { - shrsi(proc, inst, T::invertible); + shrsi(proc, inst, T::prime_field); } template diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index cc41d023..3d98cf1b 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -85,17 +85,6 @@ public: } } } - - template - static void shrsi(SubProcessor& proc, const Instruction& inst) - { - for (int i = 0; i < inst.get_size(); i++) - { - auto& dest = proc.get_S_ref(inst.get_r(0) + i); - auto& source = proc.get_S_ref(inst.get_r(1) + i); - dest = source >> inst.get_n(); - } - } }; #endif /* PROTOCOLS_SEMI2KSHARE_H_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 432b599b..8d9b1146 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -130,6 +130,31 @@ public: { super::unpack(os, n_bits); } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + shrsi(proc, inst, T::prime_field); + } + + template + static void shrsi(SubProcessor&, const Instruction&, + true_type) + { + throw runtime_error("shrsi not implemented"); + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst, + false_type) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; #endif /* PROTOCOLS_SEMISHARE_H_ */ From af5af2df251a84626dd451753039333e03ee51b7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 23 Jun 2022 19:17:08 +0200 Subject: [PATCH 062/265] Fix bug in logistic regression benchmark. --- Programs/Source/logreg.mpc | 1 + 1 file changed, 1 insertion(+) diff --git a/Programs/Source/logreg.mpc b/Programs/Source/logreg.mpc index 492e46e0..036a6a23 100644 --- a/Programs/Source/logreg.mpc +++ b/Programs/Source/logreg.mpc @@ -9,6 +9,7 @@ cfix.set_precision(16, 31) dim = int(program.args[1]) batch = int(program.args[2]) +ml.Layer.back_batch_size = batch try: n_iterations = int(program.args[3]) From 12a0f0c6c887658729a96425d5bb47c8f856817d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 10:00:33 +0200 Subject: [PATCH 063/265] Fix bug when using specific port numbers. --- Networking/Server.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Networking/Server.cpp b/Networking/Server.cpp index f9ff3e89..f8b545b9 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -176,9 +176,11 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, { pthread_create(&thread, 0, Server::start_in_thread, server = new Server(nplayers, portnum)); - N.init(my_num, portnum, my_port, hostname.c_str(), false); + bool default_port = my_port == Names::DEFAULT_PORT or my_port == portnum; + N.init(my_num, portnum, my_port, hostname.c_str(), not default_port); pthread_join(thread, 0); - N.set_server(server->get_socket()); + if (default_port) + N.set_server(server->get_socket()); delete server; } else From 31f32f5e667bdc5334a085f77ec5bd171b5c022e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 16:52:35 +0200 Subject: [PATCH 064/265] Fix bug in example code for adding protocols. --- Protocols/fake-stuff.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 63b058e0..564c79f9 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -375,6 +375,13 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, return res; } +template<> +inline GC::NoValue read_generate_write_mac_key(Player&, + string) +{ + return {}; +} + template void read_global_mac_key(const string& directory, int nparties, U& key) { From a3b7d49cb9061e7165f338e1b88381bdba4cda57 Mon Sep 17 00:00:00 2001 From: Richard Hernandez <3848345+RHG101997@users.noreply.github.com> Date: Fri, 24 Jun 2022 12:57:04 -0400 Subject: [PATCH 065/265] Small error --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18cc92ae..ac643580 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.3.2 (Mai 27, 2022) +## 0.3.2 (May 27, 2022) - Secure shuffling - O(n log n) radix sorting From 8707864c30b15d651e767de17f6bdb09bcfd8bc8 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 24 Jun 2022 11:50:20 +0200 Subject: [PATCH 066/265] Improved error message for unclosed if blocks. --- Compiler/library.py | 1 + Compiler/program.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index ef2fe1ab..cd32b84b 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1519,6 +1519,7 @@ def if_then(condition): state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False + state.caller = [frame[1:] for frame in inspect.stack()[1:]] instructions.program.curr_tape.if_states.append(state) def else_then(): diff --git a/Compiler/program.py b/Compiler/program.py index 78b802e1..e06418f3 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -776,7 +776,14 @@ class Tape: return if self.if_states: - raise CompilerError('Unclosed if/else blocks') + print('Tracebacks for open blocks:') + for state in self.if_states: + try: + print(util.format_trace(state.caller)) + except AttributeError: + pass + print() + raise CompilerError('Unclosed if/else blocks, see tracebacks above') if self.program.verbose: print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) From 1cdc6824207061268af28f71fe1a57dae93bac9f Mon Sep 17 00:00:00 2001 From: Richard Hernandez <3848345+RHG101997@users.noreply.github.com> Date: Tue, 28 Jun 2022 11:28:42 -0400 Subject: [PATCH 067/265] Duplicated word --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd0e9781..81e4f88d 100644 --- a/README.md +++ b/README.md @@ -581,7 +581,7 @@ secure versions of LowGear and HighGear. In all relevant programs, option `-T` activates [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs in both. -Hemi and Soho denote the stripped version version of LowGear and +Hemi and Soho denote the stripped version of LowGear and HighGear, respectively, for semi-honest security similar to Semi, that is, generating additively shared Beaver triples using semi-homomorphic encryption. From 505d4838c18394e8bb87bc5bae5a8b9cc00d65ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 1 Jul 2022 12:18:30 +1000 Subject: [PATCH 068/265] Parameter for ring size in fake preprocessing. --- Utils/Fake-Offline.cpp | 51 +++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 823c318b..c1d14d20 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -61,6 +61,9 @@ public: { } + template + void generate_ring(); + template void make_with_mac_key(int nplayers, int default_num, bool zero); template @@ -394,7 +397,7 @@ int main(int argc, const char** argv) 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Bit length of GF(p) field (default: 128)", // Help description. + "Bit length of GF(p) field (default: 128) and Z_2^k rings (default: 64)", // Help description. "-lgp", // Flag token. "--lgp" // Flag token. ); @@ -729,22 +732,12 @@ int FakeParams::generate() // replicated secret sharing only for three parties if (nplayers == 3) { - make_bits>({}, nplayers, nbitsp, zero); - make_basic>({}, nplayers, default_num, - zero); - make_basic>({}, nplayers, - default_num, zero); - make_with_mac_key>(nplayers, - default_num, zero); - make_mult_triples({}, nplayers, ntrip2, zero, prep_data_prefix); make_bits({}, nplayers, nbits2, zero); } else if (nplayers == 4) make_basic>({}, nplayers, default_num, zero); - make_basic>>({}, nplayers, default_num, zero); - make_basic>>({}, nplayers, default_num, zero); make_minimal({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); @@ -778,6 +771,22 @@ int FakeParams::generate() generate_field(T::clear::prime_field); generate_field(true_type()); + // default + generate_ring<64>(); + + // reuse lgp for simplified interface + switch (lgp) + { + case 64: + break; +#define X(L) case L: generate_ring(); break; + X(128) X(192) X(256) + default: + cerr << "Not compiled for " << lgp << "-bit rings." << endl << "Add 'X(" + << lgp << "') to line " << (__LINE__ - 2) << " in " << __FILE__ << endl; + exit(1); + } + return 0; } @@ -803,3 +812,23 @@ void FakeParams::generate_field(true_type) default_num, zero); } } + +template +inline void FakeParams::generate_ring() +{ + if (nplayers == 3) + { + make_bits>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, + zero); + make_basic>({}, nplayers, + default_num, zero); + make_with_mac_key>(nplayers, + default_num, zero); + } + else if (nplayers == 4) + make_basic>({}, nplayers, default_num, zero); + + make_basic>>({}, nplayers, default_num, zero); + make_basic>>({}, nplayers, default_num, zero); +} From 7e2c0eda53289517eece67d8146a1c5cf689de23 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 4 Jul 2022 22:39:45 +1000 Subject: [PATCH 069/265] Splitting for any number of bits in Semi2k. --- Protocols/Semi2kShare.h | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index 3d98cf1b..679c6bc8 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -55,7 +55,6 @@ public: { auto& P = protocol.P; int my_num = P.my_num(); - assert(n_bits <= 64); int unit = GC::Clear::N_BITS; for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) { @@ -67,21 +66,27 @@ public: to_string(n) + "-way split not working with " + to_string(P.num_players()) + " parties"); - for (int i = 0; i < n_bits; i++) - for (int j = 0; j < n; j++) - dest.at(regs.at(n * i + j) + k) = {}; - - square64 square; - - for (int j = 0; j < m; j++) - square.rows[j] = Integer(source[j + start]).get(); - - square.transpose(m, n_bits); - - for (int j = 0; j < n_bits; j++) + for (int l = 0; l < n_bits; l += unit) { - auto& dest_reg = dest.at(regs.at(n * j + my_num) + k); - dest_reg = square.rows[j]; + int base = l; + int n_left = min(n_bits - base, unit); + for (int i = base; i < base + n_left; i++) + for (int j = 0; j < n; j++) + dest.at(regs.at(n * i + j) + k) = {}; + + square64 square; + + for (int j = 0; j < m; j++) + square.rows[j] = source[j + start].get_limb(l / unit); + + square.transpose(m, n_left); + + for (int j = 0; j < n_left; j++) + { + auto& dest_reg = dest.at( + regs.at(n * (base + j) + my_num) + k); + dest_reg = square.rows[j]; + } } } } From 2a1ca6ae74350aaaeee8671ea381cb78f46ce155 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 7 Jul 2022 14:39:26 +0200 Subject: [PATCH 070/265] Fix cryptic assert statement in oram.py --- Compiler/oram.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Compiler/oram.py b/Compiler/oram.py index d4b43438..f218dfdb 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1227,7 +1227,8 @@ class TreeORAM(AbstractORAM): """ Batch initalization. Obliviously shuffles and adds N entries to random leaf buckets. """ m = len(values) - assert((m & (m-1)) == 0) + if not (m & (m-1)) == 0: + raise CompilerError('Batch size must a power of 2.') if m != self.size: raise CompilerError('Batch initialization must have N values.') if self.value_type != sint: From 0d642822378bbd6f65bbf63478d3b076cd065d05 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 8 Jul 2022 14:53:21 +1000 Subject: [PATCH 071/265] Basic estimate for shuffling cost. --- Compiler/instructions.py | 41 ++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index aac0c34c..5d6bf5fc 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2408,8 +2408,36 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +class shuffle_base(base.DataInstruction): + n_relevant_parties = 2 + + @staticmethod + def logn(n): + return int(math.ceil(math.log(n, 2))) + + def add_gen_usage(self, req_node, n): + # hack for unknown usage + req_node.increment(('bit', 'inverse'), float('inf')) + # minimal usage with two relevant parties + logn = self.logn(n) + n_switches = logn * 2 ** logn + for i in range(self.n_relevant_parties): + req_node.increment((self.field_type, 'input', i), n_switches) + # multiplications for bit check + req_node.increment((self.field_type, 'triple'), + n_switches * self.n_relevant_parties) + + def add_apply_usage(self, req_node, n, record_size): + req_node.increment(('bit', 'inverse'), float('inf')) + logn = self.logn(n) + n_switches = logn * 2 ** logn * self.n_relevant_parties + if n != 2 ** logn: + record_size += 1 + req_node.increment((self.field_type, 'triple'), + n_switches * record_size) + @base.gf2n -class secshuffle(base.VectorInstruction, base.DataInstruction): +class secshuffle(base.VectorInstruction, shuffle_base): """ Secure shuffling. :param: destination (sint) @@ -2425,9 +2453,10 @@ class secshuffle(base.VectorInstruction, base.DataInstruction): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) -class gensecshuffle(base.DataInstruction): +class gensecshuffle(shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (regint) @@ -2439,9 +2468,9 @@ class gensecshuffle(base.DataInstruction): arg_format = ['ciw','int'] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, self.args[1]) -class applyshuffle(base.VectorInstruction, base.DataInstruction): +class applyshuffle(base.VectorInstruction, shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (sint) @@ -2461,7 +2490,7 @@ class applyshuffle(base.VectorInstruction, base.DataInstruction): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'triple', 0), float('inf')) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) class delshuffle(base.Instruction): """ Delete secure shuffle. From dce0b427d21e19821a15e351cf6c0a564d5138b7 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 14 Jul 2022 15:48:03 +1000 Subject: [PATCH 072/265] Missing vectorization. --- Compiler/types.py | 2 ++ SimpleOT | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 93991df1..03e84e9b 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5063,10 +5063,12 @@ class sfloat(_number, _secret_structure): """ Secret floating-point comparison. """ return 1 - (self < other) + @vectorize def __gt__(self, other): """ Secret floating-point comparison. """ return self.conv(other) < self + @vectorize def __le__(self, other): """ Secret floating-point comparison. """ return self.conv(other) >= self diff --git a/SimpleOT b/SimpleOT index 84d73522..96f8a97e 160000 --- a/SimpleOT +++ b/SimpleOT @@ -1 +1 @@ -Subproject commit 84d73522619f90ba2aabce8d660baef1442aa26d +Subproject commit 96f8a97e6c049e11059337fd33457d84cb730f4c From 6db0ed1bc59e577746e10dd56b956a136365a196 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 15 Jul 2022 11:49:21 +1000 Subject: [PATCH 073/265] Array and matrix sorting in binary circuits. --- Compiler/library.py | 4 ++++ Compiler/types.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index cd32b84b..524a55e1 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -460,6 +460,10 @@ def method_block(function): return wrapper def cond_swap(x,y): + from .types import SubMultiArray + if isinstance(x, (Array, SubMultiArray)): + b = x[0] > y[0] + return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)])) b = x < y if isinstance(x, sfloat): res = ([], []) diff --git a/Compiler/types.py b/Compiler/types.py index 03e84e9b..e74f2630 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -452,6 +452,10 @@ class _bit(Tape._no_truth): s = a ^ b return a ^ (s & (self ^ a)) + def cond_swap(self, a, b): + prod = self * (a ^ b) + return a ^ prod, b ^ prod + class _gf2n(_bit): """ :math:`\mathrm{GF}(2^n)` functionality. """ @@ -5646,7 +5650,8 @@ class Array(_vectorizable): :param batcher: use Batcher's odd-even mergesort in any case :param n_bits: number of bits in keys (default: global bit length) """ - if batcher or self.value_type.n_elements() > 1: + if batcher or self.value_type.n_elements() > 1 or \ + program.options.binary: library.loopy_odd_even_merge_sort(self, n_threads=n_threads) else: if n_threads or 1 > 1: @@ -5739,6 +5744,13 @@ class SubMultiArray(_vectorizable): def to_array(self): return Array(self.total_size(), self.value_type, address=self.address) + def maybe_get(self, condition, index): + return self[condition * index] + + def maybe_set(self, condition, index, value): + for i, x in enumerate(value): + self.maybe_get(condition, index).maybe_set(condition, i, x) + def assign_all(self, value): """ Assign the same value to all entries. @@ -6326,6 +6338,11 @@ class SubMultiArray(_vectorizable): :param n_bits: number of bits in keys (default: global bit length) """ + if program.options.binary: + assert key_indices is None + assert len(self.sizes) == 2 + library.loopy_odd_even_merge_sort(self) + return if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) key_indices = (None,) + util.tuplify(key_indices) From 1a9bcd25e4019a0994fd73ca3253137a05b342d2 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 16 Jul 2022 18:26:10 +1000 Subject: [PATCH 074/265] Correct SimpleOT version. --- SimpleOT | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SimpleOT b/SimpleOT index 96f8a97e..84d73522 160000 --- a/SimpleOT +++ b/SimpleOT @@ -1 +1 @@ -Subproject commit 96f8a97e6c049e11059337fd33457d84cb730f4c +Subproject commit 84d73522619f90ba2aabce8d660baef1442aa26d From 1961a78fa8c341281285be8b76ba28366c36d7ef Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 16 Jul 2022 18:27:09 +1000 Subject: [PATCH 075/265] Fixed bug in MMO with prime fields longer than 1024 bits. --- Tools/MMO.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tools/MMO.hpp b/Tools/MMO.hpp index 4309e1fe..2081df83 100644 --- a/Tools/MMO.hpp +++ b/Tools/MMO.hpp @@ -18,7 +18,7 @@ void MMO::zeroIV() { octet key[AES_BLK_SIZE]; memset(key, 0, AES_BLK_SIZE * sizeof(octet)); - key[i] = i; + key[0] = i; setIV(i, key); } } From 1bbbcd277044da826995a0cd73aa08ab667a8d94 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 18 Jul 2022 14:02:06 +1000 Subject: [PATCH 076/265] Fixed bug in Python client. --- ExternalIO/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ExternalIO/client.py b/ExternalIO/client.py index 819647ff..c0033275 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -30,7 +30,7 @@ class Client: self.specification.Receive(self.sockets[0]) def receive_triples(self, T, n): - triples = [[0, 0, 0] * n] + triples = [[0, 0, 0] for i in range(n)] os = octetStream() for socket in self.sockets: os.Receive(socket) @@ -51,6 +51,7 @@ class Client: T = type(values[0]) triples = self.receive_triples(T, len(values)) os = octetStream() + assert len(values) == len(triples) for value, triple in zip(values, triples): (value + triple[0]).pack(os) for socket in self.sockets: From ac252cd951fe0b6c4a99f0d1df899de8b7cf1b8d Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 19 Jul 2022 12:27:09 +1000 Subject: [PATCH 077/265] Fixed bug in MemValue of size larger than one. --- Compiler/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index e74f2630..52a17cf8 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6633,7 +6633,9 @@ class MemValue(_mem): :return: relevant basic type instance """ self.check() if program.curr_block != self.last_write_block: - self.register = self.value_type.load_mem(self.address) + self.register = self.value_type.load_mem( + self.address, size=self.size \ + if issubclass(self.value_type, _register) else None) self.last_write_block = program.curr_block return self.register From a1074ca69a654b190d8d4ae4810629c04869c0e3 Mon Sep 17 00:00:00 2001 From: prayforwind Date: Fri, 22 Jul 2022 09:27:42 +0800 Subject: [PATCH 078/265] Fix BMR's --input-file and --output-file --- BMR/RealProgramParty.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index ae69cb7f..64efc550 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -110,7 +110,8 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); - this->processor.open_input_file(N.my_num(), 0); + this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file); + this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out); shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); From a0f5bb258e5826fd46e664d6707f4f78503a4c77 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 09:38:39 -0700 Subject: [PATCH 079/265] Update Makefile for macs where Homebrew is installed in non-traditional locations --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 03366f89..f0d6581b 100644 --- a/Makefile +++ b/Makefile @@ -294,8 +294,8 @@ mpir: mpir-setup mac-setup: mac-machine-setup brew install openssl boost libsodium mpir yasm ntl - -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I/opt/homebrew/opt/openssl/include -I/opt/homebrew/include >> CONFIG.mine - -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/openssl/lib >> CONFIG.mine + -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine + -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine # -echo USE_NTL = 1 >> CONFIG.mine ifeq ($(MACHINE), aarch64) From d39ca280e5118f62fa3455898937b1fa45ecc8c1 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 09:48:21 -0700 Subject: [PATCH 080/265] fix sorting import bug in Compiler/types.py --- Compiler/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Compiler/types.py b/Compiler/types.py index 52a17cf8..10b2c424 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5657,7 +5657,7 @@ class Array(_vectorizable): if n_threads or 1 > 1: raise CompilerError('multi-threaded sorting only implemented ' 'with Batcher\'s odd-even mergesort') - import sorting + from . import sorting sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): @@ -6346,7 +6346,7 @@ class SubMultiArray(_vectorizable): if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) key_indices = (None,) + util.tuplify(key_indices) - import sorting + from . import sorting keys = self.get_vector_by_indices(*key_indices) sorting.radix_sort(keys, self, n_bits=n_bits) From 6db5f5d86187b238d927b04b3f9a445a1e90bd18 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 23 Jul 2022 10:03:53 -0700 Subject: [PATCH 081/265] update README to better represent running from other directories --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 81e4f88d..18278c25 100644 --- a/README.md +++ b/README.md @@ -466,21 +466,21 @@ for further examples. #### Compiling and running programs from external directories -Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: +Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all MP-SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: ``` -spdz$ cd ../ +MP-SPDZ$ cd ../ $ mkdir myprogs $ cd myprogs $ mkdir -p Programs/Source $ vi Programs/Source/test.mpc -$ ../spdz/compile.py test.mpc +$ ../MP-SPDZ/compile.py test.mpc $ ls Programs/ Bytecode Public-Input Schedules Source -$ ../spdz/Scripts/setup-online.sh +$ ../MP-SPDZ/Scripts/setup-online.sh $ ls Player-Data Programs -$ ../spdz/Scripts/run-online.sh test +$ ../MP-SPDZ/Scripts/run-online.sh test ``` ### TensorFlow inference From 101879f37a5164d0f18d51d52cdcda86a9c66b06 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 25 Jul 2022 13:40:18 +1000 Subject: [PATCH 082/265] Try loading dynamic library from root directory in scripts on Linux and macOS. --- Scripts/run-common.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 7e5e6d44..c6835069 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -57,3 +57,6 @@ run_player() { players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} + +export LD_LIBRARY_PATH="$SPDZROOT:$LD_LIBRARY_PATH" +export DYLD_LIBRARY_PATH="$SPDZROOT:$DYLD_LIBRARY_PATH" From 81419ba32180c62c07e3033f7eab7c4f810b7184 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 25 Jul 2022 18:12:04 +1000 Subject: [PATCH 083/265] Fix bugs in matrix multiplication with binary circuits. --- Compiler/GC/types.py | 16 ++++++++++++++-- Compiler/types.py | 4 +++- GC/TinySecret.h | 2 +- Processor/Instruction.hpp | 16 +++++++++++++++- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 4287844a..67410678 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -17,6 +17,7 @@ from Compiler import instructions_base import Compiler.GC.instructions as inst import operator import math +import itertools from functools import reduce class bits(Tape.Register, _structure, _bit): @@ -1182,9 +1183,20 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented + other_bits = util.bit_decompose(other) + m = float('inf') + for x in itertools.chain(self.v, other_bits): + try: + m = min(m, x.n) + except: + pass + if m == 1: + op = operator.mul + else: + op = operator.and_ matrix = [] - for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x & b for x in self.v[:len(self.v)-i]]) + for i, b in enumerate(other_bits): + matrix.append([op(x, b) for x in self.v[:len(self.v)-i]]) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ diff --git a/Compiler/types.py b/Compiler/types.py index 10b2c424..1531c49d 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -5263,7 +5263,8 @@ class Array(_vectorizable): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() - if size is not None and isinstance(base, _register): + if size is not None and isinstance(base, _register) \ + and not issubclass(self.value_type, _vec): base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ @@ -6063,6 +6064,7 @@ class SubMultiArray(_vectorizable): assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() + self.value_type.matrix_mul A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9cdde3dc..85098d18 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -146,7 +146,7 @@ public: if (this != &res) res.get_regs().assign(this->get_regs().begin(), this->get_regs().begin() - + max(size_t(n_bits), this->get_regs().size())); + + min(size_t(n_bits), this->get_regs().size())); res.resize_regs(n_bits); } diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 5b0589b6..1d7c883d 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -666,6 +666,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const return r[1] + size; else return 0; + case TRANS: + if (reg_type == SBIT) + { + int n_outputs = n; + auto& args = start; + int n_inputs = args.size() - n_outputs; + long long res = 0; + for (int i = 0; i < n_outputs; i++) + res = max(res, args[i] + DIV_CEIL(n_inputs, 64)); + for (int j = 0; j < n_inputs; j++) + res = max(res, args[n_outputs] + DIV_CEIL(n_outputs, 64)); + return res; + } + else + return 0; default: if (get_reg_type() != reg_type) return 0; @@ -731,7 +746,6 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case ANDM: case NOTS: case NOTCB: - case TRANS: size = DIV_CEIL(n, 64); break; case CONVCBIT2S: From 91960440f578909bc5f06f05ca42f7edc6d1c7ad Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 26 Jul 2022 16:04:56 +1000 Subject: [PATCH 084/265] Reveal sbitvec as list. --- Compiler/GC/types.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 67410678..9fdb5904 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -724,15 +724,7 @@ class sbitvec(_vec): for j, x in enumerate(v[i].bit_decompose()): x.store_in_mem(address + i + j * n) def reveal(self): - if len(self) > cbits.unit: - return self.elements()[0].reveal() - revealed = [cbit() for i in range(len(self))] - for i in range(len(self)): - try: - inst.reveal(1, revealed[i], self.v[i]) - except: - revealed[i] = cbit.conv(self.v[i]) - return cbits.get_type(len(self)).bit_compose(revealed) + return util.untuplify([x.reveal() for x in self.elements()]) @classmethod def two_power(cls, nn): return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) From 97efdbc01fa66adc91592995964e531387c370da Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 26 Jul 2022 17:05:38 +1000 Subject: [PATCH 085/265] Fix bug in preprocessing accounting. --- Compiler/library.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/library.py b/Compiler/library.py index 524a55e1..799f85d2 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -878,7 +878,7 @@ def range_loop(loop_body, start, stop=None, step=None): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) // step) * x[0] + lambda x: int(ceil(((stop - start) / step))) * x[0] def for_range(start, stop=None, step=None): """ From e1b45388768aab6465e2e6c914791903a6dabb48 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 11:17:41 -0400 Subject: [PATCH 086/265] refactor to add Compiler class --- Compiler/__init__.py | 27 --- Compiler/compilerLib.py | 435 ++++++++++++++++++++++++++++++++-------- compile.py | 106 ++-------- 3 files changed, 369 insertions(+), 199 deletions(-) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 9a22da46..6a0d6b1d 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -2,30 +2,3 @@ from . import compilerLib, program, instructions, types, library, floatingpoint from .GC import types as GC_types import inspect from .config import * -from .compilerLib import run - - -# add all instructions to the program VARS dictionary -compilerLib.VARS = {} -instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] - -for mod in (types, GC_types): - instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ - if t[1].__module__ == mod.__name__] - -instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ - if t[1].__module__ == library.__name__] - -for op in instr_classes: - compilerLib.VARS[op.__name__] = op - -# add open and input separately due to name conflict -compilerLib.VARS['open'] = instructions.asm_open -compilerLib.VARS['vopen'] = instructions.vasm_open -compilerLib.VARS['gopen'] = instructions.gasm_open -compilerLib.VARS['vgopen'] = instructions.vgasm_open -compilerLib.VARS['input'] = instructions.asm_input -compilerLib.VARS['ginput'] = instructions.gasm_input - -compilerLib.VARS['comparison'] = comparison -compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index b2898e21..591700c1 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,94 +1,361 @@ -from Compiler.program import Program -from .GC import types as GC_types - +import inspect +import os +import re import sys -import re, tempfile, os +import tempfile +from optparse import OptionParser + +from .GC import types as GC_types +from .program import Program, defaults -def run(args, options): - """ Compile a file and output a Program object. - - If options.merge_opens is set to True, will attempt to merge any - parallelisable open instructions. """ - - prog = Program(args, options) - VARS['program'] = prog - if options.binary: - VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfixvec - for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ - 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ - 'squant': - del VARS[i] - - print('Compiling file', prog.infile) - f = open(prog.infile, 'rb') +class Compiler: + def __init__(self): + self.usage = "usage: %prog [options] filename [args]" + self.build_option_parser() + self.VARS = {} - changed = False - if options.flow_optimization: - output = [] - if_stack = [] - for line in open(prog.infile): - if if_stack and not re.match(if_stack[-1][0], line): - if_stack.pop() - m = re.match( - '(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):', - line) - if m: - output.append('%s@for_range_opt(%s)\n' % (m.group(1), - m.group(3))) - output.append('%sdef _(%s):\n' % (m.group(1), m.group(2))) - changed = True - continue - m = re.match('(\s*)if(\W.*):', line) - if m: - if_stack.append((m.group(1), len(output))) - output.append('%s@if_(%s)\n' % (m.group(1), m.group(2))) - output.append('%sdef _():\n' % (m.group(1))) - changed = True - continue - m = re.match('(\s*)elif\s+', line) - if m: - raise CompilerError('elif not supported') - if if_stack: - m = re.match('%selse:' % if_stack[-1][0], line) - if m: - start = if_stack[-1][1] - ws = if_stack[-1][0] - output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws, - output[start]) - output.append('%s@else_\n' % ws) - output.append('%sdef _():\n' % ws) - continue - output.append(line) - if changed: - infile = tempfile.NamedTemporaryFile('w+', delete=False) - for line in output: - infile.write(line) - infile.seek(0) - else: - infile = open(prog.infile) - else: - infile = open(prog.infile) + def build_option_parser(self): + parser = OptionParser(usage=self.usage) + parser.add_option( + "-n", + "--nomerge", + action="store_false", + dest="merge_opens", + default=defaults.merge_opens, + help="don't attempt to merge open instructions", + ) + parser.add_option("-o", "--output", dest="outfile", help="specify output file") + parser.add_option( + "-a", + "--asm-output", + dest="asmoutfile", + help="asm output file for debugging", + ) + parser.add_option( + "-g", + "--galoissize", + dest="galois", + default=defaults.galois, + help="bit length of Galois field", + ) + parser.add_option( + "-d", + "--debug", + action="store_true", + dest="debug", + help="keep track of trace for debugging", + ) + parser.add_option( + "-c", + "--comparison", + dest="comparison", + default="log", + help="comparison variant: log|plain|inv|sinv", + ) + parser.add_option( + "-M", + "--preserve-mem-order", + action="store_true", + dest="preserve_mem_order", + default=defaults.preserve_mem_order, + help="preserve order of memory instructions; possible efficiency loss", + ) + parser.add_option( + "-O", + "--optimize-hard", + action="store_true", + dest="optimize_hard", + help="currently not in use", + ) + parser.add_option( + "-u", + "--noreallocate", + action="store_true", + dest="noreallocate", + default=defaults.noreallocate, + help="don't reallocate", + ) + parser.add_option( + "-m", + "--max-parallel-open", + dest="max_parallel_open", + default=defaults.max_parallel_open, + help="restrict number of parallel opens", + ) + parser.add_option( + "-D", + "--dead-code-elimination", + action="store_true", + dest="dead_code_elimination", + default=defaults.dead_code_elimination, + help="eliminate instructions with unused result", + ) + parser.add_option( + "-p", + "--profile", + action="store_true", + dest="profile", + help="profile compilation", + ) + parser.add_option( + "-s", + "--stop", + action="store_true", + dest="stop", + help="stop on register errors", + ) + parser.add_option( + "-R", + "--ring", + dest="ring", + default=defaults.ring, + help="bit length of ring (default: 0 for field)", + ) + parser.add_option( + "-B", + "--binary", + dest="binary", + default=defaults.binary, + help="bit length of sint in binary circuit (default: 0 for arithmetic)", + ) + parser.add_option( + "-F", + "--field", + dest="field", + default=defaults.field, + help="bit length of sint modulo prime (default: 64)", + ) + parser.add_option( + "-P", + "--prime", + dest="prime", + default=defaults.prime, + help="prime modulus (default: not specified)", + ) + parser.add_option( + "-I", + "--insecure", + action="store_true", + dest="insecure", + help="activate insecure functionality for benchmarking", + ) + parser.add_option( + "-b", + "--budget", + dest="budget", + default=defaults.budget, + help="set budget for optimized loop unrolling " "(default: 100000)", + ) + parser.add_option( + "-X", + "--mixed", + action="store_true", + dest="mixed", + help="mixing arithmetic and binary computation", + ) + parser.add_option( + "-Y", + "--edabit", + action="store_true", + dest="edabit", + help="mixing arithmetic and binary computation using edaBits", + ) + parser.add_option( + "-Z", + "--split", + default=defaults.split, + dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)", + ) + parser.add_option( + "-C", + "--CISC", + action="store_true", + dest="cisc", + help="faster CISC compilation mode", + ) + parser.add_option( + "-K", + "--keep-cisc", + dest="keep_cisc", + help="don't translate CISC instructions", + ) + parser.add_option( + "-l", + "--flow-optimization", + action="store_true", + dest="flow_optimization", + help="optimize control flow", + ) + parser.add_option( + "-v", + "--verbose", + action="store_true", + dest="verbose", + help="more verbose output", + ) + self.parser = parser - # make compiler modules directly accessible - sys.path.insert(0, 'Compiler') - # create the tapes - exec(compile(infile.read(), infile.name, 'exec'), VARS) + def parse_args(self): + self.options, self.args = self.parser.parse_args() + if len(self.args) < 1: + self.parser.print_help() + return - if changed and not options.debug: - os.unlink(infile.name) + if self.options.optimize_hard: + print("Note that -O/--optimize-hard currently has no effect") - prog.finalize() + def build_program(self): + self.prog = Program(self.args, self.options) - if prog.req_num: - print('Program requires at most:') - for x in prog.req_num.pretty(): - print(x) + def build_vars(self): + from . import comparison, floatingpoint, instructions, library, types - if prog.verbose: - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) + # add all instructions to the program VARS dictionary + instr_classes = [ + t[1] for t in inspect.getmembers(instructions, inspect.isclass) + ] - return prog + for mod in (types, GC_types): + instr_classes += [ + t[1] + for t in inspect.getmembers(mod, inspect.isclass) + if t[1].__module__ == mod.__name__ + ] + + instr_classes += [ + t[1] + for t in inspect.getmembers(library, inspect.isfunction) + if t[1].__module__ == library.__name__ + ] + + for op in instr_classes: + self.VARS[op.__name__] = op + + # add open and input separately due to name conflict + self.VARS["open"] = instructions.asm_open + self.VARS["vopen"] = instructions.vasm_open + self.VARS["gopen"] = instructions.gasm_open + self.VARS["vgopen"] = instructions.vgasm_open + self.VARS["input"] = instructions.asm_input + self.VARS["ginput"] = instructions.gasm_input + + self.VARS["comparison"] = comparison + self.VARS["floatingpoint"] = floatingpoint + + self.VARS["program"] = self.prog + if self.options.binary: + self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary)) + self.VARS["sfix"] = GC_types.sbitfixvec + for i in [ + "cint", + "cfix", + "cgf2n", + "sintbit", + "sgf2n", + "sgf2nint", + "sgf2nuint", + "sgf2nuint32", + "sgf2nfloat", + "sfloat", + "cfloat", + "squant", + ]: + del self.VARS[i] + + def prep_compile(self): + self.parse_args() + self.build_program() + self.build_vars() + + def compile_file(self): + """Compile a file and output a Program object. + + If options.merge_opens is set to True, will attempt to merge any + parallelisable open instructions.""" + print("Compiling file", self.prog.infile) + + with open(self.prog.infile, "rb") as f: + changed = False + if self.options.flow_optimization: + output = [] + if_stack = [] + for line in f: + if if_stack and not re.match(if_stack[-1][0], line): + if_stack.pop() + m = re.match( + r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_]+)\):", + line, + ) + if m: + output.append( + "%s@for_range_opt(%s)\n" % (m.group(1), m.group(3)) + ) + output.append("%sdef _(%s):\n" % (m.group(1), m.group(2))) + changed = True + continue + m = re.match(r"(\s*)if(\W.*):", line) + if m: + if_stack.append((m.group(1), len(output))) + output.append("%s@if_(%s)\n" % (m.group(1), m.group(2))) + output.append("%sdef _():\n" % (m.group(1))) + changed = True + continue + m = re.match(r"(\s*)elif\s+", line) + if m: + raise CompilerError("elif not supported") + if if_stack: + m = re.match("%selse:" % if_stack[-1][0], line) + if m: + start = if_stack[-1][1] + ws = if_stack[-1][0] + output[start] = re.sub( + r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start] + ) + output.append("%s@else_\n" % ws) + output.append("%sdef _():\n" % ws) + continue + output.append(line) + if changed: + infile = tempfile.NamedTemporaryFile("w+", delete=False) + for line in output: + infile.write(line) + infile.seek(0) + else: + infile = open(self.prog.infile) + else: + infile = open(self.prog.infile) + + # make compiler modules directly accessible + sys.path.insert(0, "Compiler") + # create the tapes + exec(compile(infile.read(), infile.name, "exec"), self.VARS) + + if changed and not self.options.debug: + os.unlink(infile.name) + + return self.finalize_compile() + + def compile_func(self, f): + self.prep_compile() + print(f"Compiling function: {f.__name__}") + f(self.VARS) + self.finalize_compile() + + def finalize_compile(self): + self.prog.finalize() + + if self.prog.req_num: + print("Program requires at most:") + for x in self.prog.req_num.pretty(): + print(x) + + if self.prog.verbose: + print("Program requires:", repr(self.prog.req_num)) + print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) + print("Memory size:", dict(self.prog.allocated_mem)) + + return self.prog diff --git a/compile.py b/compile.py index da1b69ee..50671a04 100755 --- a/compile.py +++ b/compile.py @@ -12,100 +12,30 @@ # # See the compiler documentation at https://mp-spdz.readthedocs.io # for details on the Compiler package +from Compiler.compilerLib import Compiler -from optparse import OptionParser -from Compiler.program import defaults -import Compiler +def compilation(compiler): + prog = compiler.compile_file() -def main(): - usage = "usage: %prog [options] filename [args]" - parser = OptionParser(usage=usage) - parser.add_option("-n", "--nomerge", - action="store_false", dest="merge_opens", - default=defaults.merge_opens, - help="don't attempt to merge open instructions") - parser.add_option("-o", "--output", dest="outfile", - help="specify output file") - parser.add_option("-a", "--asm-output", dest="asmoutfile", - help="asm output file for debugging") - parser.add_option("-g", "--galoissize", dest="galois", - default=defaults.galois, - help="bit length of Galois field") - parser.add_option("-d", "--debug", action="store_true", dest="debug", - help="keep track of trace for debugging") - parser.add_option("-c", "--comparison", dest="comparison", default="log", - help="comparison variant: log|plain|inv|sinv") - parser.add_option("-M", "--preserve-mem-order", action="store_true", - dest="preserve_mem_order", - default=defaults.preserve_mem_order, - help="preserve order of memory instructions; possible efficiency loss") - parser.add_option("-O", "--optimize-hard", action="store_true", - dest="optimize_hard", help="currently not in use") - parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate", - default=defaults.noreallocate, help="don't reallocate") - parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open", - default=defaults.max_parallel_open, - help="restrict number of parallel opens") - parser.add_option("-D", "--dead-code-elimination", action="store_true", - dest="dead_code_elimination", - default=defaults.dead_code_elimination, - help="eliminate instructions with unused result") - parser.add_option("-p", "--profile", action="store_true", dest="profile", - help="profile compilation") - parser.add_option("-s", "--stop", action="store_true", dest="stop", - help="stop on register errors") - parser.add_option("-R", "--ring", dest="ring", default=defaults.ring, - help="bit length of ring (default: 0 for field)") - parser.add_option("-B", "--binary", dest="binary", default=defaults.binary, - help="bit length of sint in binary circuit (default: 0 for arithmetic)") - parser.add_option("-F", "--field", dest="field", default=defaults.field, - help="bit length of sint modulo prime (default: 64)") - parser.add_option("-P", "--prime", dest="prime", default=defaults.prime, - help="prime modulus (default: not specified)") - parser.add_option("-I", "--insecure", action="store_true", dest="insecure", - help="activate insecure functionality for benchmarking") - parser.add_option("-b", "--budget", dest="budget", default=defaults.budget, - help="set budget for optimized loop unrolling " - "(default: 100000)") - parser.add_option("-X", "--mixed", action="store_true", dest="mixed", - help="mixing arithmetic and binary computation") - parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", - help="mixing arithmetic and binary computation using edaBits") - parser.add_option("-Z", "--split", default=defaults.split, dest="split", - help="mixing arithmetic and binary computation " - "using direct conversion if supported " - "(number of parties as argument)") - parser.add_option("-C", "--CISC", action="store_true", dest="cisc", - help="faster CISC compilation mode") - parser.add_option("-K", "--keep-cisc", dest="keep_cisc", - help="don't translate CISC instructions") - parser.add_option("-l", "--flow-optimization", action="store_true", - dest="flow_optimization", help="optimize control flow") - parser.add_option("-v", "--verbose", action="store_true", dest="verbose", - help="more verbose output") - options,args = parser.parse_args() - if len(args) < 1: - parser.print_help() - return + if prog.public_input_file is not None: + print( + "WARNING: %s is required to run the program" % prog.public_input_file.name + ) - if options.optimize_hard: - print('Note that -O/--optimize-hard currently has no effect') - def compilation(): - prog = Compiler.run(args, options) - - if prog.public_input_file is not None: - print('WARNING: %s is required to run the program' % \ - prog.public_input_file.name) - - if options.profile: +def main(compiler): + compiler.prep_compile() + if compiler.options.profile: import cProfile - p = cProfile.Profile().runctx('compilation()', globals(), locals()) - p.dump_stats(args[0] + '.prof') + + p = cProfile.Profile().runctx("compilation(compiler)", globals(), locals()) + p.dump_stats(compiler.args[0] + ".prof") p.print_stats(2) else: - compilation() + compilation(compiler) -if __name__ == '__main__': - main() + +if __name__ == "__main__": + compiler = Compiler() + main(compiler) From 1c6c75886f0332d3c0c6f4baefce6bdebf46b2d8 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 14:02:21 -0400 Subject: [PATCH 087/265] allow for name to be passed in for function compiler --- Compiler/compilerLib.py | 24 +- Compiler/program.py | 814 ++++++++++++++++++++++------------------ 2 files changed, 471 insertions(+), 367 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 591700c1..bd9368ba 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -7,11 +7,15 @@ from optparse import OptionParser from .GC import types as GC_types from .program import Program, defaults +from Compiler.exceptions import CompilerError class Compiler: - def __init__(self): - self.usage = "usage: %prog [options] filename [args]" + def __init__(self, usage=None): + if usage: + self.usage = usage + else: + self.usage = "usage: %prog [options] filename [args]" self.build_option_parser() self.VARS = {} @@ -201,15 +205,11 @@ class Compiler: def parse_args(self): self.options, self.args = self.parser.parse_args() - if len(self.args) < 1: - self.parser.print_help() - return - if self.options.optimize_hard: print("Note that -O/--optimize-hard currently has no effect") - def build_program(self): - self.prog = Program(self.args, self.options) + def build_program(self, name=None): + self.prog = Program(self.args, self.options, name=name) def build_vars(self): from . import comparison, floatingpoint, instructions, library, types @@ -266,9 +266,9 @@ class Compiler: ]: del self.VARS[i] - def prep_compile(self): + def prep_compile(self, name=None): self.parse_args() - self.build_program() + self.build_program(name=name) self.build_vars() def compile_file(self): @@ -339,8 +339,8 @@ class Compiler: return self.finalize_compile() - def compile_func(self, f): - self.prep_compile() + def compile_func(self, f, name): + self.prep_compile(name) print(f"Compiling function: {f.__name__}") f(self.VARS) self.finalize_compile() diff --git a/Compiler/program.py b/Compiler/program.py index e06418f3..a94e44d3 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -4,39 +4,40 @@ blocks and registers. Most relevant is the central :py:class:`Program` object that holds various properties of the computation. """ -from Compiler.config import * -from Compiler.exceptions import * -from Compiler.instructions_base import RegType +import inspect +import itertools +import math +import os +import re +import sys +from collections import defaultdict, deque +from functools import reduce + import Compiler.instructions import Compiler.instructions_base import Compiler.instructions_base as inst_base +from Compiler.config import REG_MAX, USER_MEM, COST +from Compiler.exceptions import CompilerError +from Compiler.instructions_base import RegType + from . import allocator as al from . import util -import random -import time -import sys, os, errno -import inspect -from collections import defaultdict, deque -import itertools -import math -from functools import reduce -import re - data_types = dict( - triple = 0, - square = 1, - bit = 2, - inverse = 3, - dabit = 4, + triple=0, + square=1, + bit=2, + inverse=3, + dabit=4, ) field_types = dict( - modp = 0, - gf2n = 1, - bit = 2, + modp=0, + gf2n=1, + bit=2, ) + class defaults: debug = False verbose = False @@ -62,8 +63,9 @@ class defaults: insecure = False keep_cisc = False + class Program(object): - """ A program consists of a list of tapes representing the whole + """A program consists of a list of tapes representing the whole computation. When compiling an :file:`.mpc` file, the single instances is @@ -71,20 +73,22 @@ class Program(object): from Python code, an instance has to be created before running any instructions. """ - def __init__(self, args, options=defaults): - from .non_linear import Ring, Prime, KnownPrime + + def __init__(self, args, options=defaults, name=None): + from .non_linear import KnownPrime, Prime + self.options = options self.verbose = options.verbose self.args = args + self.name = name self.init_names(args) self._security = 40 self.prime = None self.tapes = [] - if sum(x != 0 for x in(options.ring, options.field, - options.binary)) > 1: - raise CompilerError('can only use one out of -B, -R, -F') + if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: + raise CompilerError("can only use one out of -B, -R, -F") if options.prime and (options.ring or options.binary): - raise CompilerError('can only use one out of -B, -R, -p') + raise CompilerError("can only use one out of -B, -R, -p") if options.ring: self.set_ring_size(int(options.ring)) else: @@ -93,19 +97,20 @@ class Program(object): self.prime = int(options.prime) max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: - raise CompilerError('integer bit length can be maximal %s' % - max_bit_length) + raise CompilerError( + "integer bit length can be maximal %s" % max_bit_length + ) self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 - print('Default bit length:', self.bit_length) - print('Default security parameter:', self.security) + print("Default bit length:", self.bit_length) + print("Default security parameter:", self.security) self.galois_length = int(options.galois) if self.verbose: - print('Galois length:', self.galois_length) + print("Galois length:", self.galois_length) self.tape_counter = 0 self._curr_tape = None self.DEBUG = options.debug @@ -119,24 +124,36 @@ class Program(object): self.public_input_file = None self.types = {} self.budget = int(self.options.budget) - self.to_merge = [Compiler.instructions.asm_open_class, \ - Compiler.instructions.gasm_open_class, \ - Compiler.instructions.muls_class, \ - Compiler.instructions.gmuls_class, \ - Compiler.instructions.mulrs_class, \ - Compiler.instructions.gmulrs, \ - Compiler.instructions.dotprods_class, \ - Compiler.instructions.gdotprods_class, \ - Compiler.instructions.asm_input_class, \ - Compiler.instructions.gasm_input_class, - Compiler.instructions.inputfix_class, - Compiler.instructions.inputfloat_class, - Compiler.instructions.inputmixed_class, - Compiler.instructions.trunc_pr_class, - Compiler.instructions_base.Mergeable] + self.to_merge = [ + Compiler.instructions.asm_open_class, + Compiler.instructions.gasm_open_class, + Compiler.instructions.muls_class, + Compiler.instructions.gmuls_class, + Compiler.instructions.mulrs_class, + Compiler.instructions.gmulrs, + Compiler.instructions.dotprods_class, + Compiler.instructions.gdotprods_class, + Compiler.instructions.asm_input_class, + Compiler.instructions.gasm_input_class, + Compiler.instructions.inputfix_class, + Compiler.instructions.inputfloat_class, + Compiler.instructions.inputmixed_class, + Compiler.instructions.trunc_pr_class, + Compiler.instructions_base.Mergeable, + ] import Compiler.GC.instructions as gc - self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ - gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] + + self.to_merge += [ + gc.ldmsdi, + gc.stmsdi, + gc.ldmsd, + gc.stmsd, + gc.stmsdci, + gc.xors, + gc.andrs, + gc.ands, + gc.inputb, + ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed @@ -153,7 +170,8 @@ class Program(object): self.n_running_threads = None self.input_files = {} Program.prog = self - from . import instructions_base, instructions, types, comparison + from . import comparison, instructions, instructions_base, types + instructions.program = self instructions_base.program = self types.program = self @@ -164,53 +182,53 @@ class Program(object): return self.args def max_par_tapes(self): - """ Upper bound on number of tapes that will be run in parallel. - (Excludes empty tapes) """ + """Upper bound on number of tapes that will be run in parallel. + (Excludes empty tapes)""" return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source - if 'Programs' in os.listdir(os.getcwd()): + if "Programs" in os.listdir(os.getcwd()): # compile prog in ./Programs/Source directory - self.programs_dir = os.getcwd() + '/Programs' + self.programs_dir = os.getcwd() + "/Programs" else: # assume source is in main SPDZ directory - self.programs_dir = sys.path[0] + '/Programs' + self.programs_dir = sys.path[0] + "/Programs" if self.verbose: - print('Compiling program in', self.programs_dir) - + print("Compiling program in", self.programs_dir) + # create extra directories if needed - for dirname in ['Public-Input', 'Bytecode', 'Schedules']: - if not os.path.exists(self.programs_dir + '/' + dirname): - os.mkdir(self.programs_dir + '/' + dirname) - - progname = args[0].split('/')[-1] - if progname.endswith('.mpc'): - progname = progname[:-4] - - if os.path.exists(args[0]): - self.infile = args[0] - else: - self.infile = self.programs_dir + '/Source/' + progname + '.mpc' + for dirname in ["Public-Input", "Bytecode", "Schedules"]: + if not os.path.exists(self.programs_dir + "/" + dirname): + os.mkdir(self.programs_dir + "/" + dirname) + + if self.name is None: + self.name = args[0].split("/")[-1] + if self.name.endswith(".mpc"): + self.name = self.name[:-4] + + if os.path.exists(args[0]): + self.infile = args[0] + else: + self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ if self.options.outfile: - self.name = self.options.outfile + '-' + progname + self.name = self.options.outfile + "-" + self.name else: - self.name = progname + self.name = self.name if len(args) > 1: - self.name += '-' + '-'.join(re.sub('/', '_', arg) - for arg in args[1:]) - self.progname = progname + self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:]) def set_ring_size(self, ring_size): from .non_linear import Ring + for tape in self.tapes: - prev = tape.req_bit_length['p'] + prev = tape.req_bit_length["p"] if prev and prev != ring_size: - raise CompilerError('cannot have different ring sizes') + raise CompilerError("cannot have different ring sizes") self.bit_length = ring_size - 1 self.non_linear = Ring(ring_size) self.options.ring = str(ring_size) @@ -234,7 +252,8 @@ class Program(object): :param function: Python function defining the thread :param args: arguments to the function :param name: name used for files - :param single_thread: Boolean indicating whether tape will never be run in parallel to itself + :param single_thread: Boolean indicating whether tape will + never be run in parallel to itself :returns: tape handle """ @@ -258,20 +277,22 @@ class Program(object): return self.run_tapes([[tape_index, arg]])[0] def run_tapes(self, args): - """ Run tapes in parallel. See :py:func:`new_tape` for an example. + """Run tapes in parallel. See :py:func:`new_tape` for an example. - :param args: list of tape handles or tuples of tape handle and extra argument (for :py:func:`~Compiler.library.get_arg`) + :param args: list of tape handles or tuples of tape handle and extra + argument (for :py:func:`~Compiler.library.get_arg`) :returns: list of thread numbers """ if not self.curr_tape.singular: - raise CompilerError('Compiler does not support ' \ - 'recursive spawning of threads') + raise CompilerError( + "Compiler does not support " "recursive spawning of threads" + ) args = [list(util.tuplify(arg)) for arg in args] singular_tapes = set() for arg in args: if self.tapes[arg[0]].singular: if arg[0] in singular_tapes: - raise CompilerError('cannot run singular tape in parallel') + raise CompilerError("cannot run singular tape in parallel") singular_tapes.add(arg[0]) assert len(arg) assert len(arg) <= 2 @@ -286,59 +307,59 @@ class Program(object): else: thread_numbers.append(self.n_threads) self.n_threads += 1 - self.curr_tape.start_new_basicblock(name='pre-run_tape') - Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in - zip(thread_numbers, args)), [])) - self.curr_tape.start_new_basicblock(name='post-run_tape') + self.curr_tape.start_new_basicblock(name="pre-run_tape") + Compiler.instructions.run_tape( + *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), []) + ) + self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: - self.curr_tape.req_node.children.append( - self.tapes[arg[0]].req_tree) + self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree) return thread_numbers def join_tape(self, thread_number): self.join_tapes([thread_number]) def join_tapes(self, thread_numbers): - """ Wait for completion of tapes. See :py:func:`new_tape` for an example. + """Wait for completion of tapes. See :py:func:`new_tape` for an example. :param thread_numbers: list of thread numbers """ - self.curr_tape.start_new_basicblock(name='pre-join_tape') + self.curr_tape.start_new_basicblock(name="pre-join_tape") for thread_number in thread_numbers: Compiler.instructions.join_tape(thread_number) self.curr_tape.free_threads.add(thread_number) - self.curr_tape.start_new_basicblock(name='post-join_tape') + self.curr_tape.start_new_basicblock(name="post-join_tape") def update_req(self, tape): if self.req_num is None: self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): - """ Write all non-empty threads and schedule to files. """ + """Write all non-empty threads and schedule to files.""" nonempty_tapes = [t for t in self.tapes] - sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name - sch_file = open(sch_filename, 'w') - print('Writing to', sch_filename) - sch_file.write(str(self.max_par_tapes()) + '\n') - sch_file.write(str(len(nonempty_tapes)) + '\n') - sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') - sch_file.write('1 0\n') - sch_file.write('0\n') - sch_file.write(' '.join(sys.argv) + '\n') - req = max(x.req_bit_length['p'] for x in self.tapes) + sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name + sch_file = open(sch_filename, "w") + print("Writing to", sch_filename) + sch_file.write(str(self.max_par_tapes()) + "\n") + sch_file.write(str(len(nonempty_tapes)) + "\n") + sch_file.write(" ".join(tape.name for tape in nonempty_tapes) + "\n") + sch_file.write("1 0\n") + sch_file.write("0\n") + sch_file.write(" ".join(sys.argv) + "\n") + req = max(x.req_bit_length["p"] for x in self.tapes) if self.options.ring: - sch_file.write('R:%s' % self.options.ring) + sch_file.write("R:%s" % self.options.ring) elif self.options.prime: - sch_file.write('p:%s' % self.options.prime) + sch_file.write("p:%s" % self.options.prime) else: - sch_file.write('lgp:%s' % req) - sch_file.write('\n') - sch_file.write('opts: %s\n' % ' '.join(self.relevant_opts)) + sch_file.write("lgp:%s" % req) + sch_file.write("\n") + sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) for tape in self.tapes: tape.write_bytes() @@ -347,12 +368,12 @@ class Program(object): tape.optimize(self.options) tape.write_bytes() if self.options.asmoutfile: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) tape.purge() - + @property def curr_tape(self): - """ The tape that is currently running.""" + """The tape that is currently running.""" if self._curr_tape is None: assert not self.tapes self._curr_tape = Tape(self.name, self) @@ -365,13 +386,13 @@ class Program(object): @property def curr_block(self): - """ The basic block that is currently being created. """ + """The basic block that is currently being created.""" return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): - """ Allocate memory from the top """ + """Allocate memory from the top""" if not isinstance(size, int): - raise CompilerError('size must be known at compile time') + raise CompilerError("size must be known at compile time") if size == 0: return if isinstance(mem_type, type): @@ -389,8 +410,7 @@ class Program(object): single_size = size size *= self.n_running_threads else: - raise CompilerError('cannot allocate memory ' - 'outside main thread') + raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] addr = blocks.pop(size) if addr is not None: @@ -400,24 +420,23 @@ class Program(object): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) - if addr + size >= 2 ** 32: - raise CompilerError("allocation exceeded for type '%s'" % - mem_type) - self.allocated_mem_blocks[addr,mem_type] = size + if addr + size >= 2**32: + raise CompilerError("allocation exceeded for type '%s'" % mem_type) + self.allocated_mem_blocks[addr, mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if + tn = get_thread_number() - runtime_error_if(tn > self.n_running_threads, 'malloc') + runtime_error_if(tn > self.n_running_threads, "malloc") return addr + single_size * (tn - 1) else: return addr def free(self, addr, mem_type): - """ Free memory """ - if self.curr_block.alloc_pool \ - is not self.curr_tape.basicblocks[0].alloc_pool: - raise CompilerError('Cannot free memory within function block') - size = self.allocated_mem_blocks.pop((addr,mem_type)) + """Free memory""" + if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool: + raise CompilerError("Cannot free memory within function block") + size = self.allocated_mem_blocks.pop((addr, mem_type)) self.free_mem_blocks[mem_type].push(addr, size) def finalize(self): @@ -435,47 +454,48 @@ class Program(object): if self.options.asmoutfile: for tape in self.tapes: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) def finalize_memory(self): - from . import library - self.curr_tape.start_new_basicblock(None, 'memory-usage') + self.curr_tape.start_new_basicblock(None, "memory-usage") # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() - for mem_type,size in sorted(self.allocated_mem.items()): + for mem_type, size in sorted(self.allocated_mem.items()): if size: - #print "Memory of type '%s' of size %d" % (mem_type, size) + # print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: self.types[mem_type].load_mem(size - 1, mem_type) else: from Compiler.types import _get_type + _get_type(mem_type).load_mem(size - 1, mem_type) if self.verbose: if self.saved: - print('Saved %s memory units through reallocation' % self.saved) + print("Saved %s memory units through reallocation" % self.saved) def public_input(self, x): - """ Append a value to the public input file. """ + """Append a value to the public input file.""" if self.public_input_file is None: - self.public_input_file = open(self.programs_dir + - '/Public-Input/%s' % self.name, 'w') - self.public_input_file.write('%s\n' % str(x)) + self.public_input_file = open( + self.programs_dir + "/Public-Input/%s" % self.name, "w" + ) + self.public_input_file.write("%s\n" % str(x)) def set_bit_length(self, bit_length): - """ Change the integer bit length for non-linear functions. """ + """Change the integer bit length for non-linear functions.""" self.bit_length = bit_length - print('Changed bit length for comparisons etc. to', bit_length) + print("Changed bit length for comparisons etc. to", bit_length) def set_security(self, security): self._security = security self.non_linear.set_security(security) - print('Changed statistical security for comparison etc. to', security) + print("Changed statistical security for comparison etc. to", security) @property def security(self): - """ The statistical security parameter for non-linear - functions. """ + """The statistical security parameter for non-linear + functions.""" return self._security @security.setter @@ -493,7 +513,7 @@ class Program(object): @property def use_trunc_pr(self): if not self._use_trunc_pr: - self.relevant_opts.add('trunc_pr') + self.relevant_opts.add("trunc_pr") return self._use_trunc_pr @use_trunc_pr.setter @@ -501,7 +521,7 @@ class Program(object): self._use_trunc_pr = change def use_edabit(self, change=None): - """ Setting whether to use edaBits for non-linear + """Setting whether to use edaBits for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -509,7 +529,7 @@ class Program(object): """ if change is None: if not self._edabit: - self.relevant_opts.add('edabit') + self.relevant_opts.add("edabit") return self._edabit else: self._edabit = change @@ -518,7 +538,7 @@ class Program(object): return True def use_split(self, change=None): - """ Setting whether to use local arithmetic-binary share + """Setting whether to use local arithmetic-binary share conversion for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -526,16 +546,16 @@ class Program(object): """ if change is None: if not self._split: - self.relevant_opts.add('split') + self.relevant_opts.add("split") return self._split else: if change and not self.options.ring: - raise CompilerError('splitting only supported for rings') - assert change > 1 or change == False + raise CompilerError("splitting only supported for rings") + assert change > 1 or change is False self._split = change def use_square(self, change=None): - """ Setting whether to use preprocessed square tuples + """Setting whether to use preprocessed square tuples (default: false). :param change: change setting if not :py:obj:`None` @@ -559,22 +579,22 @@ class Program(object): self._linear_rounds = change def options_from_args(self): - """ Set a number of options from the command-line arguments. """ - if 'trunc_pr' in self.args: + """Set a number of options from the command-line arguments.""" + if "trunc_pr" in self.args: self.use_trunc_pr = True - if 'signed_trunc_pr' in self.args: + if "signed_trunc_pr" in self.args: self.use_trunc_pr = -1 - if 'split' in self.args or 'split3' in self.args: + if "split" in self.args or "split3" in self.args: self.use_split(3) for arg in self.args: - m = re.match('split([0-9]+)', arg) + m = re.match("split([0-9]+)", arg) if m: self.use_split(int(m.group(1))) - if 'raw' in self.args: + if "raw" in self.args: self.always_raw(True) - if 'edabit' in self.args: + if "edabit" in self.args: self.use_edabit(True) - if 'linear_rounds' in self.args: + if "linear_rounds" in self.args: self.linear_rounds(True) def disable_memory_warnings(self): @@ -583,28 +603,32 @@ class Program(object): @staticmethod def read_tapes(schedule): - m = re.search(r'([^/]*)\.mpc', schedule) + m = re.search(r"([^/]*)\.mpc", schedule) if m: schedule = m.group(1) if not os.path.exists(schedule): - schedule = 'Programs/Schedules/%s.sch' % schedule + schedule = "Programs/Schedules/%s.sch" % schedule try: lines = open(schedule).readlines() except FileNotFoundError: - print('%s not found, have you compiled the program?' % schedule, - file=sys.stderr) + print( + "%s not found, have you compiled the program?" % schedule, + file=sys.stderr, + ) sys.exit(1) - for tapename in lines[2].split(' '): + for tapename in lines[2].split(" "): yield tapename.strip() + class Tape: - """ A tape contains a list of basic blocks, onto which instructions are added. """ + """A tape contains a list of basic blocks, onto which instructions are added.""" + def __init__(self, name, program): - """ Set prime p and the initial instructions and registers. """ + """Set prime p and the initial instructions and registers.""" self.program = program - name += '-%d' % program.get_tape_counter() + name += "-%d" % program.get_tape_counter() self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) @@ -658,9 +682,9 @@ class Tape: def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): - """ Sets the block which we start from next, depending on the condition. + """Sets the block which we start from next, depending on the condition. (Default is to go to next block in the list) """ @@ -668,34 +692,33 @@ class Tape: self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): - """ Add the jump for this block's exit condition to list of - instructions (must be done after merging) """ + """Add the jump for this block's exit condition to list of + instructions (must be done after merging)""" self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): - """ Set the correct relative jump offset """ + """Set the correct relative jump offset""" offset = self.get_offset(self.exit_block) self.exit_condition.set_relative_jump(offset) - #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self, retain_usage=True): def relevant(inst): - req_node = Tape.ReqNode('') + req_node = Tape.ReqNode("") req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} + if retain_usage: - self.usage_instructions = list(filter(relevant, - self.instructions)) + self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] if len(self.usage_instructions) > 1000: - print('Retaining %d instructions' % len(self.usage_instructions)) + print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True @@ -706,14 +729,14 @@ class Tape: instructions = self.instructions for inst in instructions: inst.add_usage(req_node) - req_node.num['all', 'round'] += self.n_rounds - req_node.num['all', 'inv'] += self.n_to_merge + req_node.num["all", "round"] += self.n_rounds + req_node.num["all", "inv"] += self.n_to_merge def expand_cisc(self): new_instructions = [] - if self.parent.program.options.keep_cisc != None: - skip = ['LTZ', 'Trunc'] - skip += self.parent.program.options.keep_cisc.split(',') + if self.parent.program.options.keep_cisc is not None: + skip = ["LTZ", "Trunc"] + skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] for inst in self.instructions: @@ -726,38 +749,38 @@ class Tape: return self.name def is_empty(self): - """ Returns True if the list of basic blocks is empty. + """Returns True if the list of basic blocks is empty. Note: False is returned even when tape only contains basic blocks with no instructions. However, these are removed when - optimize is called. """ + optimize is called.""" if not self.purged: - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 return self._is_empty - def start_new_basicblock(self, scope=False, name=''): + def start_new_basicblock(self, scope=False, name=""): # use False because None means no scope if scope is False: scope = self.active_basicblock - suffix = '%s-%d' % (name, self.block_counter) + suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 - sub = self.BasicBlock(self, self.name + '-' + suffix, scope) + sub = self.BasicBlock(self, self.name + "-" + suffix, scope) self.basicblocks.append(sub) self.active_basicblock = sub self.req_node.add_block(sub) - #print 'Compiling basic block', sub.name + # print 'Compiling basic block', sub.name def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name - self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' + self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" def purge(self): for block in self.basicblocks: block.purge() - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 del self.basicblocks del self.active_basicblock self.purged = True @@ -767,26 +790,29 @@ class Tape: if self.purged: return return function(self, *args, **kwargs) + return wrapper @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print('Tape %s is empty' % self.name) + print("Tape %s is empty" % self.name) return if self.if_states: - print('Tracebacks for open blocks:') + print("Tracebacks for open blocks:") for state in self.if_states: try: print(util.format_trace(state.caller)) except AttributeError: pass print() - raise CompilerError('Unclosed if/else blocks, see tracebacks above') + raise CompilerError("Unclosed if/else blocks, see tracebacks above") if self.program.verbose: - print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) + print( + "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) + ) for block in self.basicblocks: al.determine_scope(block, options) @@ -794,41 +820,56 @@ class Tape: # merge open instructions # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: - for i,block in enumerate(self.basicblocks): + for i, block in enumerate(self.basicblocks): if len(block.instructions) > 0 and self.program.verbose: - print('Processing basic block %s, %d/%d, %d instructions' % \ - (block.name, i, len(self.basicblocks), \ - len(block.instructions))) + print( + "Processing basic block %s, %d/%d, %d instructions" + % ( + block.name, + i, + len(self.basicblocks), + len(block.instructions), + ) + ) # the next call is necessary for allocation later even without merging - merger = al.Merger(block, options, \ - tuple(self.program.to_merge)) + merger = al.Merger(block, options, tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 1000000: - print('Eliminate dead code...') + print("Eliminate dead code...") merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue if len(block.instructions) > 1000000: - print('Merging instructions...') + print("Merging instructions...") numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) if merger.counter and self.program.verbose: - print('Block requires', \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.counter.items()))) + print( + "Block requires", + ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.counter.items()) + ), + ) if merger.counter and self.program.verbose: - print('Block requires %s rounds' % \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.rounds.items()))) + print( + "Block requires %s rounds" + % ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.rounds.items()) + ) + ) # free memory merger = None if options.dead_code_elimination: - block.instructions = [x for x in block.instructions if x is not None] + block.instructions = [ + x for x in block.instructions if x is not None + ] if not (options.merge_opens and self.merge_opens): - print('Not merging instructions in tape %s' % self.name) + print("Not merging instructions in tape %s" % self.name) if options.cisc: self.expand_cisc() @@ -853,19 +894,27 @@ class Tape: reg_counts = self.count_regs() if options.noreallocate: if self.program.verbose: - print('Tape register usage:', dict(reg_counts)) + print("Tape register usage:", dict(reg_counts)) else: if self.program.verbose: - print('Tape register usage before re-allocation:', - dict(reg_counts)) - print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) - print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) - print('Re-allocating...') + print("Tape register usage before re-allocation:", dict(reg_counts)) + print( + "modp: %d clear, %d secret" + % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + ) + print( + "GF2N: %d clear, %d secret" + % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + ) + print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) + def alloc(block): - for reg in sorted(block.used_from_scope, - key=lambda x: (x.reg_type, x.i)): + for reg in sorted( + block.used_from_scope, key=lambda x: (x.reg_type, x.i) + ): allocator.alloc_reg(reg, block.alloc_pool) + def alloc_loop(block): left = deque([block]) while left: @@ -873,73 +922,84 @@ class Tape: alloc(block) for child in block.children: left.append(child) - for i,block in enumerate(reversed(self.basicblocks)): + + for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: - print('Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks))) + print( + "Allocating %s, %d/%d" % (block.name, i, len(self.basicblocks)) + ) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, int) and jump < 0 and \ - block.exit_block.scope is not None: + if ( + isinstance(jump, int) + and jump < 0 + and block.exit_block.scope is not None + ): alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) allocator.finalize(options) if self.program.verbose: - print('Tape register usage:', dict(allocator.usage)) + print("Tape register usage:", dict(allocator.usage)) # offline data requirements if self.program.verbose: - print('Compile offline data requirements...') + print("Compile offline data requirements...") self.req_num = self.req_tree.aggregate() if self.program.verbose: - print('Tape requires', self.req_num) - for req,num in sorted(self.req_num.items()): - if num == float('inf') or num >= 2 ** 32: + print("Tape requires", self.req_num) + for req, num in sorted(self.req_num.items()): + if num == float("inf") or num >= 2**32: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( - Compiler.instructions.use(field_types[req[0]], \ - data_types[req[1]], num, \ - add_to_prog=False)) - elif req[1] == 'input': + Compiler.instructions.use( + field_types[req[0]], data_types[req[1]], num, add_to_prog=False + ) + ) + elif req[1] == "input": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_inp(field_types[req[0]], \ - req[2], num, \ - add_to_prog=False)) - elif req[0] == 'modp': + Compiler.instructions.use_inp( + field_types[req[0]], req[2], num, add_to_prog=False + ) + ) + elif req[0] == "modp": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'gf2n': + Compiler.instructions.use_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "gf2n": self.basicblocks[-1].instructions.append( - Compiler.instructions.guse_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'edabit': + Compiler.instructions.guse_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "edabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(False, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'sedabit': + Compiler.instructions.use_edabit( + False, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "sedabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(True, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'matmul': + Compiler.instructions.use_edabit( + True, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "matmul": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_matmul(*req[1], num, \ - add_to_prog=False)) + Compiler.instructions.use_matmul(*req[1], num, add_to_prog=False) + ) if not self.is_empty(): # bit length requirement - for x in ('p', '2'): + for x in ("p", "2"): if self.req_bit_length[x]: bl = self.req_bit_length[x] if self.program.options.ring: bl = -int(self.program.options.ring) self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(bl, - add_to_prog=False)) + Compiler.instructions.reqbl(bl, add_to_prog=False) + ) if self.program.verbose: - print('Tape requires prime bit length', self.req_bit_length['p']) - print('Tape requires galois bit length', self.req_bit_length['2']) + print("Tape requires prime bit length", self.req_bit_length["p"]) + print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged def expand_cisc(self): @@ -948,93 +1008,99 @@ class Tape: @unpurged def _get_instructions(self): - return itertools.chain.\ - from_iterable(b.instructions for b in self.basicblocks) + return itertools.chain.from_iterable(b.instructions for b in self.basicblocks) @unpurged def get_encoding(self): - """ Get the encoding of the program, in human-readable format. """ + """Get the encoding of the program, in human-readable format.""" return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): - """ Get the byte encoding of the program as an actual string of bytes. """ - return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + """Get the byte encoding of the program as an actual string of bytes.""" + return b"".join( + i.get_bytes() for i in self._get_instructions() if i is not None + ) + @unpurged def write_encoding(self, filename): - """ Write the readable encoding to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the readable encoding to a file.""" + print("Writing to", filename) + f = open(filename, "w") for line in self.get_encoding(): - f.write(str(line) + '\n') + f.write(str(line) + "\n") f.close() - + @unpurged def write_str(self, filename): - """ Write the sequence of instructions to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the sequence of instructions to a file.""" + print("Writing to", filename) + f = open(filename, "w") n = 0 for block in self.basicblocks: if block.instructions: - f.write('# %s\n' % block.name) + f.write("# %s\n" % block.name) for line in block.instructions: - f.write('%s # %d\n' % (line, n)) + f.write("%s # %d\n" % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): - """ Write the program's byte encoding to a file. """ + """Write the program's byte encoding to a file.""" if filename is None: filename = self.outfile - if not filename.endswith('.bc'): - filename += '.bc' - if not 'Bytecode' in filename: - filename = self.program.programs_dir + '/Bytecode/' + filename - print('Writing to', filename) - f = open(filename, 'wb') + if not filename.endswith(".bc"): + filename += ".bc" + if "Bytecode" not in filename: + filename = self.program.programs_dir + "/Bytecode/" + filename + print("Writing to", filename) + f = open(filename, "wb") for i in self._get_instructions(): if i is not None: f.write(i.get_bytes()) f.close() - + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name class ReqNum(defaultdict): def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) + def __add__(self, other): res = Tape.ReqNum() - for i,count in list(self.items()): - res[i] += count - for i,count in list(other.items()): + for i, count in list(self.items()): + res[i] += count + for i, count in list(other.items()): res[i] += count return res + def __mul__(self, other): res = Tape.ReqNum() for i in self: res[i] = other * self[i] return res + __rmul__ = __mul__ + def set_all(self, value): - if value == float('inf') and self['all', 'inv'] > 0: - print('Going to unknown from %s' % self) + if value == float("inf") and self["all", "inv"] > 0: + print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: res[i] = value return res + def max(self, other): res = Tape.ReqNum() for i in self: @@ -1042,82 +1108,103 @@ class Tape: for i in other: res[i] = max(self[i], other[i]) return res + def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ - if req[1] != 'input' and req[0] != 'edabit') + return sum( + num * COST[req[0]][req[1]] + for req, num in list(self.items()) + if req[1] != "input" and req[0] != "edabit" + ) + def pretty(self): - t = lambda x: 'integer' if x == 'modp' else x + def t(x): + return "integer" if x == "modp" else x + res = [] for req, num in self.items(): domain = t(req[0]) - n = '%12.0f' % num - if req[1] == 'input': - res += ['%s %s inputs from player %d' \ - % (n, domain, req[2])] - elif domain.endswith('edabit'): - if domain == 'sedabit': - eda = 'strict edabits' + n = "%12.0f" % num + if req[1] == "input": + res += ["%s %s inputs from player %d" % (n, domain, req[2])] + elif domain.endswith("edabit"): + if domain == "sedabit": + eda = "strict edabits" else: - eda = 'loose edabits' - res += ['%s %s of length %d' % (n, eda, req[1])] - elif domain == 'matmul': - res += ['%s matrix multiplications (%dx%d * %dx%d)' % - (n, req[1][0], req[1][1], req[1][1], req[1][2])] - elif req[0] != 'all': - res += ['%s %s %ss' % (n, domain, req[1])] - if self['all','round']: - res += ['% 12.0f virtual machine rounds' % self['all','round']] + eda = "loose edabits" + res += ["%s %s of length %d" % (n, eda, req[1])] + elif domain == "matmul": + res += [ + "%s matrix multiplications (%dx%d * %dx%d)" + % (n, req[1][0], req[1][1], req[1][1], req[1][2]) + ] + elif req[0] != "all": + res += ["%s %s %ss" % (n, domain, req[1])] + if self["all", "round"]: + res += ["% 12.0f virtual machine rounds" % self["all", "round"]] return res + def __str__(self): - return ', '.join(self.pretty()) + return ", ".join(self.pretty()) + def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ['num', 'children', 'name', 'blocks'] + __slots__ = ["num", "children", "name", "blocks"] + def __init__(self, name): self.children = [] self.name = name self.blocks = [] + def aggregate(self, *args): self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) - res = reduce(lambda x,y: x + y.aggregate(self.name), - self.children, self.num) + res = reduce( + lambda x, y: x + y.aggregate(self.name), self.children, self.num + ) return res + def increment(self, data_type, num=1): self.num[data_type] += num + def add_block(self, block): self.blocks.append(block) class ReqChild(object): - __slots__ = ['aggregator', 'nodes', 'parent'] + __slots__ = ["aggregator", "nodes", "parent"] + def __init__(self, aggregator, parent): self.aggregator = aggregator self.nodes = [] self.parent = parent + def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) try: n_reps = self.aggregator([1]) - n_rounds = res['all', 'round'] - n_invs = res['all', 'inv'] + n_rounds = res["all", "round"] + n_invs = res["all", "inv"] if (n_invs / n_rounds) * 1000 < n_reps: - print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) - except: + print( + self.nodes[0].blocks[0].name, + "blowing up rounds: ", + "(%d / %d) ** 3 < %d" % (n_rounds, n_reps, n_invs), + ) + except Exception: pass return res + def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) tape.req_node = new_node - def open_scope(self, aggregator, scope=False, name=''): + def open_scope(self, aggregator, scope=False, name=""): child = self.ReqChild(aggregator, self.req_node) self.req_node.children.append(child) - child.add_node(self, '%s-%d' % (name, len(self.basicblocks))) + child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) self.start_new_basicblock(name=name) return child @@ -1125,21 +1212,21 @@ class Tape: self.req_node = parent_req_node self.start_new_basicblock(outer_scope, name) - def require_bit_length(self, bit_length, t='p'): - if t == 'p': + def require_bit_length(self, bit_length, t="p"): + if t == "p": if self.program.prime: - if (bit_length >= self.program.prime.bit_length() - 1): + if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( - 'required bit length %d too much for %d' % \ - (bit_length, self.program.prime)) - self.req_bit_length[t] = max(bit_length + 1, \ - self.req_bit_length[t]) + "required bit length %d too much for %d" + % (bit_length, self.program.prime) + ) + self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @staticmethod def read_instructions(tapename): - tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + tape = open("Programs/Bytecode/%s.bc" % tapename, "rb") while tape.peek(): yield inst_base.ParsedInstruction(tape) @@ -1147,23 +1234,35 @@ class Tape: __slots__ = [] def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") + raise CompilerError( + "Cannot derive truth value from register, " + "consider using 'compile.py -l'" + ) class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. """ - __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ - "size", "vector", "vectorbase", "caller", \ - "can_eliminate", "duplicates"] + + __slots__ = [ + "reg_type", + "program", + "absolute_i", + "relative_i", + "size", + "vector", + "vectorbase", + "caller", + "can_eliminate", + "duplicates", + ] maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): - """ Creates a new register. - reg_type must be one of those defined in RegType. """ - if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': + """Creates a new register. + reg_type must be one of those defined in RegType.""" + if Compiler.instructions_base.get_global_instruction_type() == "gf2n": if reg_type == RegType.ClearModp: reg_type = RegType.ClearGF2N elif reg_type == RegType.SecretModp: @@ -1173,7 +1272,7 @@ class Tape: if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: - raise CompilerError('vector too large: %d' % size) + raise CompilerError("vector too large: %d" % size) self.size = size self.vectorbase = self self.relative_i = 0 @@ -1183,7 +1282,7 @@ class Tape: self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size else: - self.i = float('inf') + self.i = float("inf") self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) @@ -1204,13 +1303,14 @@ class Tape: if self.size == size: return else: - raise CompilerError('Mismatch of instruction and register size:' - ' %s != %s' % (self.size, size)) + raise CompilerError( + "Mismatch of instruction and register size:" + " %s != %s" % (self.size, size) + ) def set_vectorbase(self, vectorbase): if self.vectorbase is not self: - raise CompilerError('Cannot assign one register' \ - 'to several vectors') + raise CompilerError("Cannot assign one register" "to several vectors") self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase @@ -1218,7 +1318,7 @@ class Tape: return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base=0, size=None): - if size == None: + if size is None: size = self.size if base == 0 and size == self.size: return self @@ -1227,7 +1327,7 @@ class Tape: res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() - res.vector = self.vector[base:base+size] + res.vector = self.vector[base : base + size] return res def create_vector_elements(self): @@ -1265,14 +1365,18 @@ class Tape: @property def is_gf2n(self): - return self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.SecretGF2N - + return ( + self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.SecretGF2N + ) + @property def is_clear(self): - return self.reg_type == RegType.ClearModp or \ - self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.ClearInt + return ( + self.reg_type == RegType.ClearModp + or self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.ClearInt + ) def __str__(self): return self.reg_type + str(self.i) From 4859a09633f4696040f8a6800a7bd35ec69b9622 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 14:53:32 -0400 Subject: [PATCH 088/265] update to use decorator --- Compiler/compilerLib.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index bd9368ba..d695c506 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -5,9 +5,10 @@ import sys import tempfile from optparse import OptionParser +from Compiler.exceptions import CompilerError + from .GC import types as GC_types from .program import Program, defaults -from Compiler.exceptions import CompilerError class Compiler: @@ -339,10 +340,34 @@ class Compiler: return self.finalize_compile() - def compile_func(self, f, name): - self.prep_compile(name) - print(f"Compiling function: {f.__name__}") - f(self.VARS) + def register_function(self, name=None): + """ + To register a function to be compiled, use this as a decorator. + Example: + + @compiler.register_function('test-mpc') + def test_mpc(compiler): + ... + """ + + def inner(func): + self.compile_name = name or func.__name__ + self.compile_function = func + return func + + return inner + + def compile_func(self): + if not hasattr(self, "compile_name") and hasattr(self, "compile_func"): + raise CompilerError( + "No function to compile. " + "Did you decorate a function with @register_fuction(name)?" + ) + self.prep_compile(self.compile_name) + print( + f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" + ) + self.compile_function(self) self.finalize_compile() def finalize_compile(self): From 24a7b4f69d0618bebbe29cc31533c1ca9f829061 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 15:12:40 -0400 Subject: [PATCH 089/265] add setup.py to and an example mpc program --- Programs/Source/test_args.mpc | 31 +++++++++++++++++++++++++++++++ setup.py | 7 +++++++ 2 files changed, 38 insertions(+) create mode 100644 Programs/Source/test_args.mpc create mode 100644 setup.py diff --git a/Programs/Source/test_args.mpc b/Programs/Source/test_args.mpc new file mode 100644 index 00000000..88a3a803 --- /dev/null +++ b/Programs/Source/test_args.mpc @@ -0,0 +1,31 @@ +from Compiler.library import print_ln +from Compiler.types import Matrix, sint +from Compiler.compilerLib import Compiler + + +usage = "usage: %prog [options] [args]" +compiler = Compiler(usage=usage) +compiler.parser.add_option("--rows", dest="rows") +compiler.parser.add_option("--columns", dest="columns") +compiler.parse_args() +if not compiler.options.rows: + compiler.parser.error("--rows required") +if not compiler.options.columns: + compiler.parser.error("--columns required") + + +@compiler.register_function('testmpc') +def main(compiler): + numrows = int(compiler.options.rows) + numcolumns = int(compiler.options.columns) + rows = range(numrows) + reports = Matrix(numrows, numcolumns, sint) + reports.assign_vector( + sint.get_input_from(0, size=numrows * numcolumns) + ) + for row in rows: + print_ln(f"report[{row}]: %s", reports[row].reveal()) + + +if __name__ == "__main__": + compiler.compile_func() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..5e850bc5 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name='mp-spdz-compiler', + version='0.1.0', + packages=find_packages(include=['Compiler', 'Compiler.*']) +) From 7005ba4eaec426714a43892a95a0fea9bee90549 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Wed, 3 Aug 2022 15:40:14 -0400 Subject: [PATCH 090/265] remove unneed compiler parameter --- Compiler/compilerLib.py | 2 +- Programs/Source/test_args.mpc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index d695c506..113f4a8e 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -367,7 +367,7 @@ class Compiler: print( f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" ) - self.compile_function(self) + self.compile_function() self.finalize_compile() def finalize_compile(self): diff --git a/Programs/Source/test_args.mpc b/Programs/Source/test_args.mpc index 88a3a803..a9cee12e 100644 --- a/Programs/Source/test_args.mpc +++ b/Programs/Source/test_args.mpc @@ -15,7 +15,7 @@ if not compiler.options.columns: @compiler.register_function('testmpc') -def main(compiler): +def main(): numrows = int(compiler.options.rows) numcolumns = int(compiler.options.columns) rows = range(numrows) From 497dd79ab4e90aa4cdc24e98d292756c882b9b92 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 3 Aug 2022 18:48:51 +1000 Subject: [PATCH 091/265] Fix bug in LSB extraction. --- Compiler/comparison.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 23bee219..84bdd22b 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -637,6 +637,7 @@ def Mod2(a_0, a, k, kappa, signed): t = [program.curr_block.new_reg('s') for i in range(6)] c2k1 = program.curr_block.new_reg('c') PRandM(r_dprime, r_prime, [r_0], k, 1, kappa) + r_0 = r_prime mulsi(t[0], r_dprime, 2) if signed: ld2i(c2k1, k - 1) From c4c167fac7e772090941a0445902eea02848f2bd Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 5 Aug 2022 15:09:03 +1000 Subject: [PATCH 092/265] Flow optimization test. --- Programs/Source/test_flow_optimization.mpc | 23 ++++++++++++++++++++++ Scripts/test_flow_optimization.sh | 4 ++++ 2 files changed, 27 insertions(+) create mode 100644 Programs/Source/test_flow_optimization.mpc create mode 100755 Scripts/test_flow_optimization.sh diff --git a/Programs/Source/test_flow_optimization.mpc b/Programs/Source/test_flow_optimization.mpc new file mode 100644 index 00000000..ba7af650 --- /dev/null +++ b/Programs/Source/test_flow_optimization.mpc @@ -0,0 +1,23 @@ +n = 10 ** 7 +a = regint.Array(n) +b = regint.Array(n) + +for i in range(n): + if i > 1000: + a[i] = i + + if i < 1000: + b[i] = -1 + else: + b[i] = 2 * i + +def test(a, index, value): + print_ln('expected %s got %s at %s', value, a[index], index) + crash(a[index] != value) + +test(a, 999, 0) +test(b, 999, -1) +test(a, 10000, 10000) +test(b, 10000, 20000) +test(a, 1000000, 1000000) +test(b, 1000000, 2000000) diff --git a/Scripts/test_flow_optimization.sh b/Scripts/test_flow_optimization.sh new file mode 100755 index 00000000..b9ec62f6 --- /dev/null +++ b/Scripts/test_flow_optimization.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +./compile.py -l test_flow_optimization || exit 1 +Scripts/rep-field.sh test_flow_optimization || exit 1 From 5e4e3dd1a981e202b4eb99aa2feadd81d3f6fadb Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 5 Aug 2022 10:22:01 -0400 Subject: [PATCH 093/265] load mpc file as a string, not bytes --- Compiler/compilerLib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 113f4a8e..9e36e9f9 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -279,7 +279,7 @@ class Compiler: parallelisable open instructions.""" print("Compiling file", self.prog.infile) - with open(self.prog.infile, "rb") as f: + with open(self.prog.infile, "r") as f: changed = False if self.options.flow_optimization: output = [] From a1658819cd6d3faa4ea98dcbef255bb0cf01ab47 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Sat, 6 Aug 2022 12:47:36 +1000 Subject: [PATCH 094/265] Fix bug in sintbit. --- Compiler/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/types.py b/Compiler/types.py index 1531c49d..128fa203 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2821,7 +2821,9 @@ class sintbit(sint): elif util.is_zero(other): return self elif util.is_one(other): - return 1 + res = sintbit() + submr(res, cint(1), self) + return res else: return NotImplemented From 80c7250ec8c64cbd7ec8d463fab774624ac7c519 Mon Sep 17 00:00:00 2001 From: hernan232 Date: Mon, 8 Aug 2022 11:18:23 -0500 Subject: [PATCH 095/265] Add documentation about SPDZ2k non-interactive execution and correct typos in README. --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 18278c25..8eefe82d 100644 --- a/README.md +++ b/README.md @@ -630,6 +630,12 @@ e.g. if this machine is name `diffie` on the local network: The software uses TCP ports around 5000 by default, use the `-pn` argument to change that. +If you are using the SPDZ2k protocol in non-interactive mode to run a +program compiled with a ring size different from 64, you must specify +the ring size in the script to run the program as follows: + +`Scripts/spdz2k.sh -R tutorial` + ### Yao's garbled circuits We use half-gate garbling as described by [Zahur et @@ -796,7 +802,7 @@ with three parties overall, Party 0 and 1 run the online phase. ## BMR -BMR (Bellare-Micali-Rogaway) is a method of generating a garbled circuit +BMR (Beaver-Micali-Rogaway) is a method of generating a garbled circuit using another secure computation protocol. We have implemented BMR based on all available implementations using GF(2^128) because the nature of this field particularly suits the Free-XOR optimization for garbled From d6f843f5cf480281681f7cedfd90b7f75352056a Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 8 Aug 2022 18:14:55 +1000 Subject: [PATCH 096/265] Fix bugs in Python client. --- ExternalIO/client.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ExternalIO/client.py b/ExternalIO/client.py index c0033275..a6fd0b03 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -73,20 +73,21 @@ class octetStream: self.ptr = 0 def Send(self, socket): - socket.send(struct.pack(' Date: Wed, 10 Aug 2022 12:45:19 +1000 Subject: [PATCH 097/265] Use edaBits for equality test with rings. --- Compiler/floatingpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index c596240b..d3d3f8c5 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,7 +28,9 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_dabit: + if program.Program.prog.use_edabit: + r_prime, r = types.sint.get_edabit(k) + elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) r_prime = types.sint.bit_compose(rr) else: From 3f90cc3e7c7a573447687066003155d084b71bce Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 11 Aug 2022 11:11:00 +1000 Subject: [PATCH 098/265] Fix bugs in sorting with binary circuits. --- Compiler/GC/types.py | 9 +++++++++ Compiler/types.py | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 9fdb5904..5530432b 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -236,6 +236,11 @@ class bits(Tape.Register, _structure, _bit): This will output 1. """ return result_conv(x, y)(self & (x ^ y) ^ y) + def zero_if_not(self, condition): + if util.is_constant(condition): + return self * condition + else: + return self * cbit.conv(condition) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -491,6 +496,8 @@ class sbits(bits): if isinstance(other, int): return self.mul_int(other) try: + if (self.n, other.n) == (1, 1): + return self & other if min(self.n, other.n) != 1: raise NotImplementedError('high order multiplication') n = max(self.n, other.n) @@ -740,6 +747,8 @@ class sbitvec(_vec): bits += [0] * (n - len(bits)) assert len(bits) == n return cls.from_vec(bits) + def zero_if_not(self, condition): + return self.from_vec(x.zero_if_not(condition) for x in self.v) def __str__(self): return 'sbitvec(%d)' % n return sbitvecn diff --git a/Compiler/types.py b/Compiler/types.py index 128fa203..75e1f58f 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -324,6 +324,9 @@ class _number(Tape._no_truth): def popcnt_bits(bits): return sum(bits) + def zero_if_not(self, condition): + return condition * self + class _int(Tape._no_truth): """ Integer functionality. """ @@ -5331,7 +5334,7 @@ class Array(_vectorizable): :param condition: 0/1 (regint/cint/int) :param index: regint/cint/int """ - return condition * self[condition * index] + return self[condition * index].zero_if_not(condition) def maybe_set(self, condition, index, value): """ Change entry if condition is true. From f469dfc4735d2cb3d8293a4cc5c8c2dcb4ec9171 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Tue, 26 Jul 2022 14:07:35 +0200 Subject: [PATCH 099/265] Add the INVPERM instruction The IVMPERM instruction takes in a secret shared vector representing a permutation, and returns the corresponding secret shared inverse permutation. --- Compiler/instructions.py | 20 +++++ Compiler/instructions_base.py | 1 + Compiler/types.py | 5 ++ Processor/Instruction.h | 1 + Processor/Instruction.hpp | 7 ++ Processor/Processor.h | 3 + Processor/Processor.hpp | 16 ++++ Protocols/SecureShuffle.h | 50 +++++++++++- Protocols/SecureShuffle.hpp | 143 +++++++++++++++++++++++++++------- 9 files changed, 216 insertions(+), 30 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 5d6bf5fc..058b6ff4 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2501,6 +2501,26 @@ class delshuffle(base.Instruction): code = base.opcodes['DELSHUFFLE'] arg_format = ['ci'] +class inverse_permutation(base.VectorInstruction, shuffle_base): + """ Calculate the inverse permutation of a secret permutation. + + :param: destination (sint) + :param: source (sint) + + """ + __slots__ = [] + code = base.opcodes['INVPERM'] + arg_format = ['sw', 's'] + + def __init__(self, *args, **kwargs): + super(inverse_permutation, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + + def add_usage(self, req_node): + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), 1) + + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 3a56e604..f7aa48f9 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -111,6 +111,7 @@ opcodes = dict( GENSECSHUFFLE = 0xFB, APPLYSHUFFLE = 0xFC, DELSHUFFLE = 0xFD, + INVPERM = 0xFE, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/types.py b/Compiler/types.py index 75e1f58f..a69e5e52 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2776,6 +2776,11 @@ class sint(_secret, _int): applyshuffle(res, self, unit_size, shuffle, reverse) return res + def inverse_permutation(self): + res = sint(size=self.size) + inverse_permutation(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ diff --git a/Processor/Instruction.h b/Processor/Instruction.h index f3caf565..1de58c99 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -113,6 +113,7 @@ enum GENSECSHUFFLE = 0xFB, APPLYSHUFFLE = 0xFC, DELSHUFFLE = 0xFD, + INVPERM = 0xFE, // Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 1d7c883d..7763c837 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -283,6 +283,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) n = get_int(s); get_vector(2, start, s); break; + // instructions with 2 register operands + case INVPERM: + get_vector(2, start, s); + break; // open instructions + read/write instructions with variable length args case OPEN: case GOPEN: @@ -1076,6 +1080,9 @@ inline void Instruction::execute(Processor& Proc) const case DELSHUFFLE: Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); return; + case INVPERM: + Proc.Procp.inverse_permutation(*this); + return; case CHECK: { CheckJob job; diff --git a/Processor/Processor.h b/Processor/Processor.h index 927e9327..e29f6eb4 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -77,6 +77,7 @@ public: size_t generate_secure_shuffle(const Instruction& instruction); void apply_shuffle(const Instruction& instruction, int handle); void delete_shuffle(int handle); + void inverse_permutation(const Instruction& instruction); void input_personal(const vector& args); void send_personal(const vector& args); @@ -101,6 +102,8 @@ public: { return C[i]; } + + void inverse_permutation(const Instruction &instruction, int handle); }; class ArithmeticProcessor : public ProcessorBase diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 861e8cfe..e80df0d0 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -668,6 +668,12 @@ void SubProcessor::delete_shuffle(int handle) shuffler.del(handle); } +template +void SubProcessor::inverse_permutation(const Instruction& instruction) { + shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0], + instruction.get_start()[1]); +} + template void SubProcessor::input_personal(const vector& args) { @@ -686,6 +692,16 @@ void SubProcessor::input_personal(const vector& args) S[args[i + 2] + j] = input.finalize(args[i + 1]); } +/** + * + * @tparam T + * @param args Args contains four arguments + * a[0] = the size of the input (and output) vector + * a[1] = the player to which to reveal the output + * a[2] = the memory address of the input vector (sint) (i.e. the value to reveal) + * a[3] = the memory address of the output vector (cint) (i.e. the register to store the revealed value) + * // TODO: When would there be multiple sets of arguments? (for ... i < args.size(); i += 4 ... ) + */ template void SubProcessor::private_output(const vector& args) { diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index a90c6e64..c1c265ea 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -24,8 +24,28 @@ class SecureShuffle size_t n_shuffle; bool exact; + /** + * Generates and returns a newly generated random permutation. This permutation is generated locally. + * + * @param n The size of the permutation to generate. + * @return A vector representing a permutation, a shuffled array of integers 0 through n-1. + */ + vector generate_random_permutation(int n); + + /** + * Configure a shared waksman network from a permutation known only to config_player. + * Note that although the configuration bits of the waksman network are secret shared, + * the player that generated the permutation (config_player) knows the value of these bits. + * + * A permutation is a mapping represented as a vector. + * Each item in the vector represents the output of mapping(i) where i is the index of that item. + * e.g. [2, 4, 0, 3, 1] -> perm(1) = 4 + * + * @param config_player The player tasked with generating the random permutation from which to configure the waksman network. + * @param n_shuffle The size of the permutation to generate. + */ + void configure(int config_player, vector* perm, int n); void player_round(int config_player); - void generate(int config_player, int n_shuffle); void waksman(vector& a, int depth, int start); void cond_swap(T& x, T& y, const T& b); @@ -44,9 +64,37 @@ public: int generate(int n_shuffle); + /** + * + * @param a The vector of registers representing the stack // TODO: Is this correct? + * @param n The size of the input vector to shuffle + * @param unit_size Determines how many vector items constitute a single block with regards to permutation: + * i.e. input vector [1,2,3,4] with unit_size=2 under permutation map [1,0] + * would result in [3,4,1,2] + * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) + * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) + * @param handle The integer identifying the preconfigured waksman network (shuffle) to use. Such a handle can be obtained from calling + * @param reverse Boolean indicating whether to apply the inverse of the permutation + * @see SecureShuffle::generate for obtaining a shuffle handle + */ void apply(vector& a, size_t n, int unit_size, size_t output_base, size_t input_base, int handle, bool reverse); + /** + * Calculate the secret inverse permutation of stack given secret permutation. + * + * This method is given in [1], based on stack technique in [2]. It is used in the Compiler (high-level) implementation of Square-Root ORAM. + * + * [1] Samee Zahur, Xiao Wang, Mariana Raykova, Adrià Gascón, Jack Doerner, David Evans, and Jonathan Katz. 2016. Revisiting Square Root ORAM: Efficient Random Access in Multi-Party Computation. In IEEE S&P. + * [2] Ivan Damgård, Matthias Fitzi, Eike Kiltz, Jesper Buus Nielsen, and Tomas Toft. Unconditionally Secure Constant-rounds Multi-Party Computation for Equality, Comparison, Bits and Exponentiation. In Theory of Cryptography, 2006. + * + * @param stack The vector or registers representing the stack (?) + * @param n The size of the input vector for which to calculate the inverse permutation + * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) + * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) + */ + void inverse_permutation(vector& stack, size_t n, size_t output_base, size_t input_base); + void del(int handle); }; diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index d2b0676a..5d713066 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -58,6 +58,82 @@ void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t outpu post(a, n, output_base); } + +template +void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t output_base, + size_t input_base) { + int alice = 0; + int bob = 1; + + auto &P = proc.P; + auto &input = proc.input; + + // This method only supports two players + assert(proc.protocol.get_relevant_players().size() == 2); + // The current implementation assumes a semi-honest environment + assert(!T::malicious); + + // We are dealing directly with permutations, so the unit_size will always be 1. + this->unit_size = 1; + // We need to account for sizes which are not a power of 2 + size_t n_pow2 = (1u << int(ceil(log2(n)))); + + // Copy over the input registers + pre(stack, n, input_base); + // Alice generates stack local permutation and shares the waksman configuration bits secretly to Bob. + vector perm_alice(n_pow2); + if (P.my_num() == alice) + perm_alice = generate_random_permutation(n); + configure(alice, &perm_alice, n); + // Apply perm_alice to perm_alice to get perm_bob, + // stack permutation that we can reveal to Bob without Bob learning anything about perm_alice (since it is masked by perm_a) + iter_waksman(true); + // Store perm_bob at stack[output_base] + post(stack, n, output_base); + + // Reveal permutation perm_bob = perm_a * perm_alice + // Since this permutation is masked by perm_a, Bob learns nothing about perm + vector perm_bob(n_pow2); + typename T::PrivateOutput output(proc); + for (size_t i = 0; i < n; i++) + output.prepare_sending(stack[output_base + i], bob); + output.exchange(); + for (size_t i = 0; i < n_pow2; i++) { + // TODO: Is there a better way to convert a T::clear to int? + bigint val; + output.finalize(bob).to(val); + perm_bob[i] = (int) val.get_si(); + } + + vector perm_bob_inv(n_pow2); + if (P.my_num() == bob) { + for (int i = 0; i < (int) n; i++) + perm_bob_inv[perm_bob[i]] = i; + // Pad the permutation to n_pow2 + // Required when using waksman networks + for (int i = (int) n; i < (int) n_pow2; i++) + perm_bob_inv[i] = i; + } + + // Alice secret shares perm_a with bob + // perm_a is stored in the stack at output_base + input.reset_all(P); + if (P.my_num() == alice) { + for (int i = 0; i < (int) n; i++) + input.add_mine(perm_alice[i]); + } + input.exchange(); + for (int i = 0; i < (int) n; i++) + stack[output_base + i] = input.finalize(alice); + + // The two parties now jointly compute perm_a * perm_bob_inv to obtain perm_inv + pre(stack, n, output_base); + configure(bob, &perm_bob_inv, n); + iter_waksman(true); + // perm_inv is written back to stack[output_base] + post(stack, n, output_base); +} + template void SecureShuffle::del(int handle) { @@ -129,9 +205,27 @@ void SecureShuffle::post(vector& a, size_t n, size_t output_base) } template -void SecureShuffle::player_round(int config_player) -{ - generate(config_player, n_shuffle); +vector SecureShuffle::generate_random_permutation(int n) { + vector perm; + int n_pow2 = 1 << int(ceil(log2(n))); + int shuffle_size = n; + for (int j = 0; j < n_pow2; j++) + perm.push_back(j); + SeededPRNG G; + for (int i = 0; i < shuffle_size; i++) { + int j = G.get_uint(shuffle_size - i); + swap(perm[i], perm[i + j]); + } + + return perm; +} + +template +void SecureShuffle::player_round(int config_player) { + vector random_perm(n_shuffle); + if (proc.P.my_num() == config_player) + random_perm = generate_random_permutation(n_shuffle); + configure(config_player, &random_perm, n_shuffle); iter_waksman(); } @@ -142,9 +236,12 @@ int SecureShuffle::generate(int n_shuffle) shuffles.push_back({}); auto& shuffle = shuffles.back(); - for (auto i : proc.protocol.get_relevant_players()) - { - generate(i, n_shuffle); + for (auto i: proc.protocol.get_relevant_players()) { + vector perm; + if (proc.P.my_num() == i) + perm = generate_random_permutation(n_shuffle); + configure(i, &perm, n_shuffle); + shuffle.push_back(config); } @@ -152,39 +249,27 @@ int SecureShuffle::generate(int n_shuffle) } template -void SecureShuffle::generate(int config_player, int n) -{ - auto& P = proc.P; - auto& input = proc.input; +void SecureShuffle::configure(int config_player, vector *perm, int n) { + auto &P = proc.P; + auto &input = proc.input; input.reset_all(P); int n_pow2 = 1 << int(ceil(log2(n))); Waksman waksman(n_pow2); - if (P.my_num() == config_player) - { - vector perm; - int shuffle_size = n; - for (int j = 0; j < n_pow2; j++) - perm.push_back(j); - SeededPRNG G; - for (int i = 0; i < shuffle_size; i++) - { - int j = G.get_uint(shuffle_size - i); - swap(perm[i], perm[i + j]); - } - - auto config_bits = waksman.configure(perm); - for (size_t i = 0; i < config_bits.size(); i++) - { - auto& x = config_bits[i]; + // The player specified by config_player configures the shared waksman network + // using its personal permutation + if (P.my_num() == config_player) { + auto config_bits = waksman.configure(*perm); + for (size_t i = 0; i < config_bits.size(); i++) { + auto &x = config_bits[i]; for (size_t j = 0; j < x.size(); j++) if (waksman.matters(i, j)) input.add_mine(int(x[j])); else assert(x[j] == 0); } - } - else + // The other player waits for its share of the configured waksman network + } else for (size_t i = 0; i < waksman.n_bits(); i++) input.add_other(config_player); From 70135dd2fecd638abf8d14a9af1ce39b7e35c43a Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Tue, 26 Jul 2022 16:45:03 +0200 Subject: [PATCH 100/265] Fix segfault in INVPERM instruction --- Protocols/SecureShuffle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 5d713066..920ccf3a 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -98,7 +98,7 @@ void SecureShuffle::inverse_permutation(vector &stack, size_t n, size_t ou for (size_t i = 0; i < n; i++) output.prepare_sending(stack[output_base + i], bob); output.exchange(); - for (size_t i = 0; i < n_pow2; i++) { + for (size_t i = 0; i < n; i++) { // TODO: Is there a better way to convert a T::clear to int? bigint val; output.finalize(bob).to(val); From 52ac60beb03dc02b85894f347cfd775ad56fca51 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 11 Aug 2022 15:35:51 +0200 Subject: [PATCH 101/265] Add --invperm flag for the INVPERM instruction --- Compiler/program.py | 63 +++++++++++++++++++++++++++++---------------- Compiler/types.py | 16 ++++++++++-- compile.py | 2 ++ 3 files changed, 57 insertions(+), 24 deletions(-) diff --git a/Compiler/program.py b/Compiler/program.py index e06418f3..435bf7e5 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -49,6 +49,7 @@ class defaults: budget = 100000 mixed = False edabit = False + invperm = False split = None cisc = False comparison = None @@ -142,6 +143,8 @@ class Program(object): self.use_dabit = options.mixed """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit + """ Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)""" + self._invperm = options.invperm self._split = False if options.split: self.use_split(int(options.split)) @@ -167,7 +170,7 @@ class Program(object): """ Upper bound on number of tapes that will be run in parallel. (Excludes empty tapes) """ return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source if 'Programs' in os.listdir(os.getcwd()): @@ -178,16 +181,16 @@ class Program(object): self.programs_dir = sys.path[0] + '/Programs' if self.verbose: print('Compiling program in', self.programs_dir) - + # create extra directories if needed for dirname in ['Public-Input', 'Bytecode', 'Schedules']: if not os.path.exists(self.programs_dir + '/' + dirname): os.mkdir(self.programs_dir + '/' + dirname) - + progname = args[0].split('/')[-1] if progname.endswith('.mpc'): progname = progname[:-4] - + if os.path.exists(args[0]): self.infile = args[0] else: @@ -314,7 +317,7 @@ class Program(object): self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): """ Write all non-empty threads and schedule to files. """ @@ -349,7 +352,7 @@ class Program(object): if self.options.asmoutfile: tape.write_str(self.options.asmoutfile + '-' + tape.name) tape.purge() - + @property def curr_tape(self): """ The tape that is currently running.""" @@ -367,7 +370,7 @@ class Program(object): def curr_block(self): """ The basic block that is currently being created. """ return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): """ Allocate memory from the top """ if not isinstance(size, int): @@ -514,6 +517,20 @@ class Program(object): else: self._edabit = change + def use_invperm(self, change=None): + """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used. + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ + if change is None: + if not self._invperm: + self.relevant_opts.add('invperm') + return self._invperm + else: + self._invperm = change + + def use_edabit_for(self, *args): return True @@ -574,6 +591,8 @@ class Program(object): self.always_raw(True) if 'edabit' in self.args: self.use_edabit(True) + if 'invperm' in self.args: + self.use_invperm(True) if 'linear_rounds' in self.args: self.linear_rounds(True) @@ -658,7 +677,7 @@ class Tape: def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): """ Sets the block which we start from next, depending on the condition. @@ -668,15 +687,15 @@ class Tape: self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): """ Add the jump for this block's exit condition to list of instructions (must be done after merging) """ self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): """ Set the correct relative jump offset """ offset = self.get_offset(self.exit_block) @@ -749,7 +768,7 @@ class Tape: def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' @@ -863,7 +882,7 @@ class Tape: print('Re-allocating...') allocator = al.StraightlineAllocator(REG_MAX, self.program) def alloc(block): - for reg in sorted(block.used_from_scope, + for reg in sorted(block.used_from_scope, key=lambda x: (x.reg_type, x.i)): allocator.alloc_reg(reg, block.alloc_pool) def alloc_loop(block): @@ -955,12 +974,12 @@ class Tape: def get_encoding(self): """ Get the encoding of the program, in human-readable format. """ return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): """ Get the byte encoding of the program as an actual string of bytes. """ return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + @unpurged def write_encoding(self, filename): """ Write the readable encoding to a file. """ @@ -969,7 +988,7 @@ class Tape: for line in self.get_encoding(): f.write(str(line) + '\n') f.close() - + @unpurged def write_str(self, filename): """ Write the sequence of instructions to a file. """ @@ -983,7 +1002,7 @@ class Tape: f.write('%s # %d\n' % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): """ Write the program's byte encoding to a file. """ @@ -999,16 +1018,16 @@ class Tape: if i is not None: f.write(i.get_bytes()) f.close() - + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name @@ -1018,7 +1037,7 @@ class Tape: def __add__(self, other): res = Tape.ReqNum() for i,count in list(self.items()): - res[i] += count + res[i] += count for i,count in list(other.items()): res[i] += count return res @@ -1267,7 +1286,7 @@ class Tape: def is_gf2n(self): return self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.SecretGF2N - + @property def is_clear(self): return self.reg_type == RegType.ClearModp or \ diff --git a/Compiler/types.py b/Compiler/types.py index a69e5e52..d63295c8 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2777,8 +2777,20 @@ class sint(_secret, _int): return res def inverse_permutation(self): - res = sint(size=self.size) - inverse_permutation(res, self) + if program.use_invperm(): + # If enabled, we use the low-level INVPERM instruction. + # This instruction has only been implemented for a semi-honest two-party environement. + res = sint(size=self.size) + inverse_permutation(res, self) + else: + shuffle = sint.get_secure_shuffle(len(self)) + shuffled = self.secure_permute(shuffle).reveal() + idx = Array.create_from(shuffled) + res = Array.create_from(sint(regint.inc(len(self)))) + res.secure_permute(shuffle, reverse=False) + res.assign_slice_vector(idx, res.get_vector()) + library.break_point() + res = res.get_vector() return res class sintbit(sint): diff --git a/compile.py b/compile.py index da1b69ee..2455946b 100755 --- a/compile.py +++ b/compile.py @@ -72,6 +72,8 @@ def main(): help="mixing arithmetic and binary computation") parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", help="mixing arithmetic and binary computation using edaBits") + parser.add_option("--invperm", action="store_true", dest="invperm", + help="speedup inverse permutation (only use in two-party, semi-honest environment)") parser.add_option("-Z", "--split", default=defaults.split, dest="split", help="mixing arithmetic and binary computation " "using direct conversion if supported " From 9e9210e683d29f77919c164040ba773b3154a0a3 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 28 Jul 2022 12:28:11 +0200 Subject: [PATCH 102/265] Add bit_not to MemValue --- Compiler/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Compiler/types.py b/Compiler/types.py index d63295c8..5b7441a5 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6708,6 +6708,7 @@ class MemValue(_mem): if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) bit_and = lambda self,other: self.read().bit_and(other) + bit_not = lambda self: self.read().bit_not() def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: From 06520ea7a11451bbe77fa279d7c7abc12208d786 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 28 Jul 2022 12:28:47 +0200 Subject: [PATCH 103/265] Add SqrtORAM to Compiler --- Compiler/sqrt_oram.py | 476 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 476 insertions(+) create mode 100644 Compiler/sqrt_oram.py diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py new file mode 100644 index 00000000..5525aad4 --- /dev/null +++ b/Compiler/sqrt_oram.py @@ -0,0 +1,476 @@ +from __future__ import annotations +from abc import abstractmethod +from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar +from Compiler import library as lib +from Compiler.GC.types import cbit, sbit, sbitint, sbits +from Compiler.oram import AbstractORAM, get_n_threads +from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint +import numpy as np + +debug = True +reveal = True +n_parallel = 1024 + +def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): + if isinstance(array, MultiArray): + temp = array[pos_b][:] + array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) + array[pos_a].assign(cond.if_else(temp, array[pos_a][:])) + if isinstance(array, Array): + temp = array[pos_b] + array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) + array[pos_a] = cond.if_else(temp, array[pos_a]) + +T = TypeVar("T", sint, sbitint) +B = TypeVar("B", sintbit, sbit) + +class SqrtOram(Generic[T, B]): + # TODO: Preferably this is an Array of vectors, but this is currently not supported + # One should regard these structures as Arrays where an entry may hold more + # than one value (which is a nice property to have when using the ORAM in + # practise). + shuffle: MultiArray + stash: MultiArray + # A block has an index and data + # `shuffle` and `stash` store the data, + # `shufflei` and `stashi` store the index + shufflei: Array + stashi: Array + + shuffle_used: Array + position_map: PositionMap + + # The size of the ORAM, i.e. how many elements it stores + n: int + # The period, i.e. how many calls can be made to the ORAM before it needs to be refreshed + T: int + # Keep track of how far we are in the period, and coincidentally how large + # the stash is (each access results in a fake or real block being put on + # the stash) + t: cint + + def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + """Initialize a new Oblivious RAM using the "Square-Root" algorithm. + + Args: + data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array). + value_type (sint): The secret type to use, defaults to sint. + k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. + period (int): Leave at None, this parameter is used to recursively pass down the top-level period. + """ + self.n = len(data) + + self.value_type = value_type + if value_type != sint and value_type != sbitint: + raise Exception("The value_type must be either sint or sbitint") + self.bit_type: Type[B] = value_type.bit_type + self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) )) + self.entry_length = entry_length + + if debug: + lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) + self.shuffle_used = cint.Array(self.n) + # Random permutation on the data + self.shuffle = data + self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) + permutation = Array.create_from(self.shuffle_the_shuffle()) + # Calculate the period if not given + # upon recursion, the period should stay the same ("in sync"), + # therefore it can be passed as a constructor parameter + self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1)) + ) if not period else period + if debug and not period: + lib.print_ln('Period set to %s', self.T) + # Initialize position map (recursive oram) + self.position_map = PositionMap.create(permutation, k + 1, self.T) + + # Initialize stash + self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type) + self.stashi = Array(self.T, value_type=value_type) + self.t = MemValue(cint(0)) + + + def read(self, index: T): + data = self.value_type.Array(self.entry_length) + return self.access(index, self.bit_type(False), data) + + def write(self, index: T, value: Array): + self.access(index, self.bit_type(True), value) + + __getitem__ = read + __setitem__ = write + + def access(self, index: T, write: B, value: Array): + if len(value) != self.entry_length: + raise Exception("A block must be of size entry_length={}".format(self.entry_length)) + # Method Blocks do not accepts arrays as arguments + # workaround by temporarily storing it as a class field + # arrays are stored in memory so this is fine + index = MemValue(index) + return Array.create_from(self._access(index, write, value[:])) + + @lib.method_block + def _access(self, index: T, write: B, *value: list[T]): + item: T = self.value_type(*value) + + if debug: + @lib.if_e(write.reveal() == 1) + def _(): + lib.print_ln('Writing to secret index %s', index.reveal()) + @lib.else_ + def __(): + lib.print_ln('Reading from secret index %s', index.reveal()) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + + # Scan through the stash + @lib.if_(self.t > 0) + def _(): + nonlocal found + found |= index == self.stashi[0] + # We ensure that if the item is found in stash, it ends up in the first + # position (more importantly, a fixed position) of the stash + # This allows us to keep track of it in an oblivious manner + @lib.for_range_opt(self.t) + def _(i): + nonlocal found + found_: B = index == self.stashi[i + 1] + swap(self.stash, 0, i, found_) + swap(self.stashi, 0, i, found_) + found |= found_ + # found = self.bit_type(found.bit_or(found_)) + # If the item was not in the stash, we move the unknown and unimportant + # stash[0] out of the way (to the end of the stash) + swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + + if debug: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln(' Found item in stash') + @lib.else_ + def __(): + lib.print_ln(' Item not in stash') + lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + self.shuffle_used[physical_address] = cbit(True) + + # If the item was in the stash (thus currently residing in stash[0]), + # we place the random item retrieved from the shuffle at the end of the stash + self.stash[self.t].assign(found.if_else( + self.shuffle[physical_address][:], + self.stash[self.t][:])) + self.stashi[self.t] = found.if_else( + self.shufflei[physical_address], + self.stashi[self.t]) + # If the item was not found in the stash, + # we place the item retrieved from the shuffle in stash[0] + self.stash[0].assign(found.bit_not().if_else( + self.shuffle[physical_address][:], + self.stash[0][:])) + self.stashi[0] = found.bit_not().if_else( + self.shufflei[physical_address], + self.stashi[0]) + if debug: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + @lib.else_ + def __(): + lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + self.stash[0].assign(write.if_else(item, self.stash[0][:])) + item=write.bit_not().if_else(self.stash[0][:], item) + return item + + + @lib.method_block + def shuffle_the_shuffle(self): + """Permute the memory using a newly generated permutation and return + the permutation that would generate this particular shuffling. + + This permutation is needed to know how to map logical addresses to + physical addresses, and is used as such by the postition map.""" + + # Random permutation on n elements + random_shuffle = sint.get_secure_shuffle(self.n) + # Apply the random permutation + lib.print_ln('\tGenerated shuffle') + self.shuffle.secure_permute(random_shuffle) + lib.print_ln('\tShuffled shuffle') + self.shufflei.secure_permute(random_shuffle) + lib.print_ln('\tShuffled shuffle indexes') + # Calculate the permutation that would have produced the newly produced + # shuffle order. This can be calculated by regarding the logical + # indexes (shufflei) as a permutation and calculating its inverse, + # i.e. find P such that P([1,2,3,...]) = shufflei. + # this is not necessarily equal to the inverse of the above generated + # random_shuffle, as the shuffle may already be out of order (e.g. when + # refreshing). + permutation = MemValue(self.shufflei[:].inverse_permutation()) + lib.print_ln('\tCalculated inverse permutation') + return permutation + + @lib.method_block + def refresh(self): + """Refresh the ORAM by reinserting the stash back into the shuffle, and + reshuffling the shuffle. + + This must happen after T (period) accesses to the ORAM.""" + lib.print_ln('Refreshing SqrtORAM') + + # Shuffle and emtpy the stash, and store elements back into shuffle + j = MemValue(cint(0,size=1)) + @lib.for_range_opt(self.n) + def _(i): + @lib.if_(self.shuffle_used[i]) + def _(): + nonlocal j + self.shuffle[i] = self.stash[j] + self.shufflei[i] = self.stashi[j] + j += 1 + + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self.shuffle_used.assign_all(0) + + # Reinitialize position map + permutation = self.shuffle_the_shuffle() + # Note that we skip here the step of "packing" the permutation. + # Since the underlying memory of the position map is already aligned in + # this packed structure, we can simply overwrite the memory while + # maintaining the structure. + self.position_map.reinitialize(*permutation) + + @lib.method_block + def reinitialize(self, *data: T): + # Note that this method is only used during refresh, and as such is + # only called with a permutation as data. + + # The logical addresses of some previous permutation are irrelevant and must be reset + self.shufflei.assign([self.index_type(i) for i in range(self.n)]) + # Reset the clock + self.t.write(0) + # Reset shuffle_used + self.shuffle_used.assign_all(0) + + # Note that the self.shuffle is actually a MultiArray + # This structure is preserved while overwriting the values using + # assign_vector + self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length)) + permutation = self.shuffle_the_shuffle() + self.position_map.reinitialize(*permutation) + + +class PositionMap(Generic[T, B]): + PACK_LOG: int = 2 + PACK: int = 1 << PACK_LOG + + n: int # n in the paper + depth: int # k in the paper + value_type: Type[T] + + def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: + self.n = n + self.depth=MemValue(cint(k)) + self.value_type = value_type + self.bit_type = value_type.bit_type + self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n)))) + + @abstractmethod + def get_position(self, logical_address: _secret, fake: B) -> Any: + """Retrieve the block at the given (secret) logical address.""" + if debug: + lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) + + def reinitialize(self, *permutation: T): + """Reinitialize this PositionMap. + + Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`), + we cannot simply call __init__ on self. Instead, we must take care to + reuse and overwrite the same memory. + """ + ... + + @classmethod + def create(cls, permutation: Array, k: int, period: int, value_type: Type[T] = sint) -> PositionMap: + """Creates a new PositionMap. This is the method one should call when + needing a new position map. Depending on the size of the given data, it + will either instantiate a RecursivePositionMap or + a LinearPositionMap.""" + n = len(permutation) + + if n / PositionMap.PACK <= period: + if debug: + lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n) + res = LinearPositionMap(permutation, value_type, k=k) + else: + if debug: + lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n) + res = RecursivePositionMap(permutation, period, value_type, k=k) + + return res + + +class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): + + def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None: + PositionMap.__init__(self, len(permutation), k=k) + pack = PositionMap.PACK + + # We pack the permutation into a smaller structure, index with a new permutation + packed_size = int(np.ceil(self.n / pack)) + packed_structure = MultiArray( + (packed_size, pack), value_type=value_type) + for i in range(packed_size): + packed_structure[i] = Array.create_from( + permutation[i*pack:(i+1)*pack]) + + # TODO: Should this be n or packed_size? + SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + super().get_position(logical_address, fake) + + pack = PositionMap.PACK + pack_log = PositionMap.PACK_LOG + + # The item at logical_address + # will be in block with index h (block.) + # at position l in block.data (block.data) + h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length))) + l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) + + # The resulting physical address + p = MemValue(self.index_type(0)) + found: B = MemValue(self.bit_type(False)) + + # First we try and retrieve the item from the stash + + # We retrieve stash[h] + # Since h is secret, we do this by scanning the entire stash + @lib.for_range(self.t) + def _(j): + nonlocal found + condition = self.stashi[j] == h + found |= condition + # block = stash[h] + # block is itself an array (it holds a permutation) + # we need to grab block[l] + @lib.for_range(pack) + def _(i): + nonlocal condition + condition &= l == i + p.write(condition.if_else(self.stash[j][i], p)) + + if debug: + @lib.if_(condition.reveal() == 1) + def _(): + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal()) + + # Then we try and retrieve the item from the shuffle (the actual memory) + + if debug: + @lib.if_(found.reveal() == 0) + def _(): + lib.print_ln('\t%s Position not in stash', self.depth) + + + p_prime = self.position_map.get_position(h, found) + self.shuffle_used[p_prime] = cbit(True) + # The block retrieved from the shuffle + # Depending on whether the block has already been `found`, this block + # is either the desired block (found=False) or a random block + # (found=True) + block_p_prime: Array = self.shuffle[p_prime] + + if debug: + @lib.if_e(found.reveal() == 0) + def _(): + lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + @lib.else_ + def __(): + lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + + # We add the retrieved block from the shuffle to the stash + self.stash[self.t].assign(block_p_prime[:]) + self.stashi[self.t] = self.shufflei[p_prime] + # Increase t + self.t += 1 + + # if found or not fake + condition = self.bit_type(fake.bit_or(found.bit_not())) + # Retrieve l'th item from block + # l is secret, so we must use linear scan + @lib.for_range_opt(pack) + def _(i): + hit: B = self.bit_type(i == l) + p.write((condition & hit).if_else(block_p_prime[i], p)) + + return p.reveal() + + @lib.method_block + def reinitialize(self, *permutation: T): + SqrtOram.reinitialize(self, *permutation) + +class LinearPositionMap(PositionMap): + physical: Array + used: Array + + def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None: + PositionMap.__init__(self, len(data), value_type, k=k) + self.physical = data + self.used = self.bit_type.Array(self.n) + + @lib.method_block + def get_position(self, logical_address: T, fake: B) -> _clear: + """ + This method corresponds to GetPosBase in the paper. + """ + super().get_position(logical_address, fake) + fake = self.bit_type(fake) + + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + linear_scan = self.bit_type.Array(self.n) + @lib.for_range_opt(self.n) + def _(i): + linear_scan[i] = logical_address == i + + p: MemValue = MemValue(self.index_type(-1)) + done: B = self.bit_type(False) + + @lib.for_range_opt(self.n) + def _(j): + nonlocal done, fake + condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ + .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) + p.write(condition.if_else(self.physical[j], p)) + self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) + done = self.bit_type(condition.if_else(self.bit_type(True), done)) + + if debug: + @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) + def _(): + lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) + + return p.reveal() + + @lib.method_block + def reinitialize(self, *data: T): + self.physical.assign_vector(data) + self.used.assign_all(False) From b070c23a26fce2685e1871576a78d914f877b528 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 11:24:07 +0200 Subject: [PATCH 104/265] Optimize performance of SqrtORAM --- Compiler/sqrt_oram.py | 190 +++++++++++++++++++++++++----------------- 1 file changed, 114 insertions(+), 76 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5525aad4..e949e72e 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -1,17 +1,41 @@ from __future__ import annotations from abc import abstractmethod -from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar +import math +from typing import Any, Generic, Type, TypeVar + +from Compiler.program import Program +from Compiler import util from Compiler import library as lib from Compiler.GC.types import cbit, sbit, sbitint, sbits -from Compiler.oram import AbstractORAM, get_n_threads -from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint -import numpy as np +from Compiler.types import ( + Array, + MemValue, + MultiArray, + _clear, + _secret, + cint, + sint, + sintbit, + regint +) -debug = True -reveal = True +program = Program.prog + +debug = False n_parallel = 1024 +n_threads = 8 + +multithreading = True def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): + """Swap two positions in an Array if a condition is met. + + Args: + array (Array | MultiArray): The array in which to swap the first and second position + pos_a (int | cint): The first position + pos_b (int | cint): The second position + cond (sintbit | sbit): The condition determining whether to swap + """ if isinstance(array, MultiArray): temp = array[pos_b][:] array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) @@ -49,7 +73,7 @@ class SqrtOram(Generic[T, B]): # the stash) t: cint - def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: @@ -64,55 +88,51 @@ class SqrtOram(Generic[T, B]): if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") self.bit_type: Type[B] = value_type.bit_type - self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) )) + self.index_type = value_type.get_type(util.log2(self.n)) self.entry_length = entry_length if debug: lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - self.shuffle = data + if isinstance(data, MultiArray): + self.shuffle = data + elif isinstance(data, sint): + self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type) + self.shuffle.assign_vector(data.get_vector()) + else: + raise Exception("Incorrect format.") self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter - self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1)) - ) if not period else period + self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) # Initialize position map (recursive oram) self.position_map = PositionMap.create(permutation, k + 1, self.T) # Initialize stash - self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type) + self.stash = MultiArray((self.T, entry_length), value_type=value_type) self.stashi = Array(self.T, value_type=value_type) self.t = MemValue(cint(0)) - + @lib.method_block def read(self, index: T): - data = self.value_type.Array(self.entry_length) - return self.access(index, self.bit_type(False), data) + value = self.value_type(0, size=self.entry_length) + return self.access(index, self.bit_type(False), *value) - def write(self, index: T, value: Array): - self.access(index, self.bit_type(True), value) + @lib.method_block + def write(self, index: T, value: T): + lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length") + self.access(index, self.bit_type(True), *value) __getitem__ = read __setitem__ = write - def access(self, index: T, write: B, value: Array): - if len(value) != self.entry_length: - raise Exception("A block must be of size entry_length={}".format(self.entry_length)) - # Method Blocks do not accepts arrays as arguments - # workaround by temporarily storing it as a class field - # arrays are stored in memory so this is fine - index = MemValue(index) - return Array.create_from(self._access(index, write, value[:])) - @lib.method_block - def _access(self, index: T, write: B, *value: list[T]): - item: T = self.value_type(*value) - + def access(self, index: T, write: B, *value: T): if debug: @lib.if_e(write.reveal() == 1) def _(): @@ -120,6 +140,7 @@ class SqrtOram(Generic[T, B]): @lib.else_ def __(): lib.print_ln('Reading from secret index %s', index.reveal()) + value = self.value_type(value) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) @@ -136,14 +157,24 @@ class SqrtOram(Generic[T, B]): # We ensure that if the item is found in stash, it ends up in the first # position (more importantly, a fixed position) of the stash # This allows us to keep track of it in an oblivious manner - @lib.for_range_opt(self.t) - def _(i): - nonlocal found - found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i, found_) - swap(self.stashi, 0, i, found_) - found |= found_ - # found = self.bit_type(found.bit_or(found_)) + if multithreading: + found_ = self.bit_type.Array(size=self.T) + @lib.multithread(8, self.T) + def _(base, size): + found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base) + @lib.for_range_opt(self.t - 1) + def _(i): + swap(self.stash, 0, i, found_[i]) + swap(self.stashi, 0, i, found_[i]) + found.write(sum(found_)) + else: + @lib.for_range_opt(self.t - 1) + def _(i): + nonlocal found + found_: B = index == self.stashi[i + 1] + swap(self.stash, 0, i, found_) + swap(self.stashi, 0, i, found_) + found |= found_ # If the item was not in the stash, we move the unknown and unimportant # stash[0] out of the way (to the end of the stash) swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) @@ -156,7 +187,7 @@ class SqrtOram(Generic[T, B]): @lib.else_ def __(): lib.print_ln(' Item not in stash') - lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -171,14 +202,15 @@ class SqrtOram(Generic[T, B]): self.stashi[self.t] = found.if_else( self.shufflei[physical_address], self.stashi[self.t]) - # If the item was not found in the stash, - # we place the item retrieved from the shuffle in stash[0] + # If the item was not found in the stash, we place the item retrieved + # from the shuffle (the item we are actually looking for) in stash[0] self.stash[0].assign(found.bit_not().if_else( self.shuffle[physical_address][:], self.stash[0][:])) self.stashi[0] = found.bit_not().if_else( self.shufflei[physical_address], self.stashi[0]) + if debug: @lib.if_e(found.reveal() == 1) def _(): @@ -191,9 +223,9 @@ class SqrtOram(Generic[T, B]): # Increase the "time" (i.e. access count in current period) self.t.iadd(1) - self.stash[0].assign(write.if_else(item, self.stash[0][:])) - item=write.bit_not().if_else(self.stash[0][:], item) - return item + self.stash[0].assign(write.if_else(value, self.stash[0][:])) + value=write.bit_not().if_else(self.stash[0][:], value) + return value @lib.method_block @@ -206,12 +238,12 @@ class SqrtOram(Generic[T, B]): # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) + if debug: lib.print_ln('\tGenerated shuffle') # Apply the random permutation - lib.print_ln('\tGenerated shuffle') self.shuffle.secure_permute(random_shuffle) - lib.print_ln('\tShuffled shuffle') + if debug: lib.print_ln('\tShuffled shuffle') self.shufflei.secure_permute(random_shuffle) - lib.print_ln('\tShuffled shuffle indexes') + if debug: lib.print_ln('\tShuffled shuffle indexes') # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -220,7 +252,7 @@ class SqrtOram(Generic[T, B]): # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). permutation = MemValue(self.shufflei[:].inverse_permutation()) - lib.print_ln('\tCalculated inverse permutation') + if debug: lib.print_ln('\tCalculated inverse permutation') return permutation @lib.method_block @@ -229,7 +261,8 @@ class SqrtOram(Generic[T, B]): reshuffling the shuffle. This must happen after T (period) accesses to the ORAM.""" - lib.print_ln('Refreshing SqrtORAM') + + if debug: lib.print_ln('Refreshing SqrtORAM') # Shuffle and emtpy the stash, and store elements back into shuffle j = MemValue(cint(0,size=1)) @@ -276,7 +309,7 @@ class SqrtOram(Generic[T, B]): class PositionMap(Generic[T, B]): - PACK_LOG: int = 2 + PACK_LOG: int = 3 PACK: int = 1 << PACK_LOG n: int # n in the paper @@ -288,7 +321,7 @@ class PositionMap(Generic[T, B]): self.depth=MemValue(cint(k)) self.value_type = value_type self.bit_type = value_type.bit_type - self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n)))) + self.index_type = self.value_type.get_type(util.log2(n)) @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: @@ -332,7 +365,7 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): pack = PositionMap.PACK # We pack the permutation into a smaller structure, index with a new permutation - packed_size = int(np.ceil(self.n / pack)) + packed_size = int(math.ceil(self.n / pack)) packed_structure = MultiArray( (packed_size, pack), value_type=value_type) for i in range(packed_size): @@ -359,28 +392,33 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): p = MemValue(self.index_type(0)) found: B = MemValue(self.bit_type(False)) - # First we try and retrieve the item from the stash + # First we try and retrieve the item from the stash at position stash[h][l] + # Since h and l are secret, we do this by scanning the entire stash - # We retrieve stash[h] - # Since h is secret, we do this by scanning the entire stash + # First we scan the stash for the block we need + condition1 = self.bit_type.Array(self.T) + @lib.for_range_opt_multithread(8, self.T) + def _(i): + condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) + found = sum(condition1) + # Once a block is found, we use condition2 to pick the correct item from that block + condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack)) + # condition3 combines condition1 & condition2, only returning true at stash[h][l] + condition3 = self.bit_type.Array(self.T * pack) + @lib.for_range_opt_multithread(8, [self.T, pack]) + def _(i, j): + condition3[i*pack + j] = condition1[i] & condition2[j] + # Finally we use condition3 to conditionally write p @lib.for_range(self.t) - def _(j): - nonlocal found - condition = self.stashi[j] == h - found |= condition - # block = stash[h] - # block is itself an array (it holds a permutation) - # we need to grab block[l] + def _(i): @lib.for_range(pack) - def _(i): - nonlocal condition - condition &= l == i - p.write(condition.if_else(self.stash[j][i], p)) + def _(j): + p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) if debug: - @lib.if_(condition.reveal() == 1) + @lib.if_(condition1[i].reveal() == 1) def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal()) + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal()) # Then we try and retrieve the item from the shuffle (the actual memory) @@ -389,22 +427,22 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): def _(): lib.print_ln('\t%s Position not in stash', self.depth) - + # Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle p_prime = self.position_map.get_position(h, found) self.shuffle_used[p_prime] = cbit(True) + # The block retrieved from the shuffle - # Depending on whether the block has already been `found`, this block - # is either the desired block (found=False) or a random block - # (found=True) block_p_prime: Array = self.shuffle[p_prime] if debug: @lib.if_e(found.reveal() == 0) def _(): - lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) @lib.else_ def __(): - lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash self.stash[self.t].assign(block_p_prime[:]) @@ -413,13 +451,13 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): self.t += 1 # if found or not fake - condition = self.bit_type(fake.bit_or(found.bit_not())) + condition: B = self.bit_type(fake.bit_or(found.bit_not())) # Retrieve l'th item from block # l is secret, so we must use linear scan + hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack)) @lib.for_range_opt(pack) def _(i): - hit: B = self.bit_type(i == l) - p.write((condition & hit).if_else(block_p_prime[i], p)) + p.write((hit[i]).if_else(block_p_prime[i], p)) return p.reveal() From 3a4cceeedf9c4e2f6555d4b39b595972e5d4aa56 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 14:41:51 +0200 Subject: [PATCH 105/265] Fix misaligned stashi bug in sqrt_oram --- Compiler/sqrt_oram.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index e949e72e..c4b14f99 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -141,6 +141,7 @@ class SqrtOram(Generic[T, B]): def __(): lib.print_ln('Reading from secret index %s', index.reveal()) value = self.value_type(value) + index = MemValue(index) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) @@ -159,21 +160,23 @@ class SqrtOram(Generic[T, B]): # This allows us to keep track of it in an oblivious manner if multithreading: found_ = self.bit_type.Array(size=self.T) - @lib.multithread(8, self.T) + @lib.multithread(1, self.T) def _(base, size): - found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base) + found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size)) + & self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) @lib.for_range_opt(self.t - 1) def _(i): - swap(self.stash, 0, i, found_[i]) - swap(self.stashi, 0, i, found_[i]) + swap(self.stash, 0, i + 1, found_[i+1]) + swap(self.stashi, 0, i + 1, found_[i+1]) found.write(sum(found_)) else: @lib.for_range_opt(self.t - 1) def _(i): nonlocal found found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i, found_) - swap(self.stashi, 0, i, found_) + swap(self.stash, 0, i + 1, found_) + swap(self.stashi, 0, i + 1, found_) found |= found_ # If the item was not in the stash, we move the unknown and unimportant # stash[0] out of the way (to the end of the stash) @@ -183,11 +186,11 @@ class SqrtOram(Generic[T, B]): if debug: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln(' Found item in stash') + lib.print_ln('\tFound item in stash') @lib.else_ def __(): - lib.print_ln(' Item not in stash') - lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln('\tItem not in stash') + lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash From 8af345a7138cebbc2ba120b9b8b840f4f8dd688c Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 17:35:22 +0200 Subject: [PATCH 106/265] Fix improper multi-dimensionality in SqrtORAM --- Compiler/sqrt_oram.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index c4b14f99..5a5e24dd 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -82,7 +82,17 @@ class SqrtOram(Generic[T, B]): k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. """ - self.n = len(data) + if isinstance(data, MultiArray): + self.shuffle = data + self.n = len(data) + elif isinstance(data, sint): + self.n = math.ceil(len(data) // entry_length) + if (len(data) % entry_length != 0): + raise Exception('Data incorrectly padded.') + self.shuffle = MultiArray((self.n, entry_length), value_type=value_type) + self.shuffle.assign_part_vector(data.get_vector()) + else: + raise Exception("Incorrect format.") self.value_type = value_type if value_type != sint and value_type != sbitint: @@ -95,13 +105,6 @@ class SqrtOram(Generic[T, B]): lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - if isinstance(data, MultiArray): - self.shuffle = data - elif isinstance(data, sint): - self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type) - self.shuffle.assign_vector(data.get_vector()) - else: - raise Exception("Incorrect format.") self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given @@ -124,8 +127,8 @@ class SqrtOram(Generic[T, B]): return self.access(index, self.bit_type(False), *value) @lib.method_block - def write(self, index: T, value: T): - lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length") + def write(self, index: T, *value: T): + lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length") self.access(index, self.bit_type(True), *value) __getitem__ = read From 33299e78a58553c8d61706a40fcb257161db0323 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Fri, 29 Jul 2022 17:35:48 +0200 Subject: [PATCH 107/265] Add multithreading to LinearPositionMap in SqrtORAM --- Compiler/sqrt_oram.py | 55 +++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5a5e24dd..5de1174f 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -22,9 +22,7 @@ from Compiler.types import ( program = Program.prog debug = False -n_parallel = 1024 n_threads = 8 - multithreading = True def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): @@ -486,26 +484,47 @@ class LinearPositionMap(PositionMap): This method corresponds to GetPosBase in the paper. """ super().get_position(logical_address, fake) - fake = self.bit_type(fake) - - # In order to get an address at secret logical_address, - # we need to perform a linear scan. - linear_scan = self.bit_type.Array(self.n) - @lib.for_range_opt(self.n) - def _(i): - linear_scan[i] = logical_address == i + fake = MemValue(self.bit_type(fake)) + logical_address = MemValue(logical_address) p: MemValue = MemValue(self.index_type(-1)) done: B = self.bit_type(False) - @lib.for_range_opt(self.n) - def _(j): - nonlocal done, fake - condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ - .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) - p.write(condition.if_else(self.physical[j], p)) - self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) - done = self.bit_type(condition.if_else(self.bit_type(True), done)) + if multithreading: + conditions:Array = self.bit_type.Array(self.n) + conditions.assign_all(0) + + @lib.for_range_opt_multithread(8, self.n) + def condition_i(i): + conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i) + + @lib.for_range_opt(self.n) + def _(i): + nonlocal done + conditions[i] &= done.bit_not() + done |= conditions[i] + @lib.map_sum_opt(8, self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * conditions[i] + p.write(calc_p()) + + self.used.assign(self.used[:] | conditions[:]) + else: + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + linear_scan = self.bit_type.Array(self.n) + @lib.for_range_opt(self.n) + def _(i): + linear_scan[i] = logical_address == i + + @lib.for_range_opt(self.n) + def __(j): + nonlocal done, fake + condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ + .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) + p.write(condition.if_else(self.physical[j], p)) + self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) + done = self.bit_type(condition.if_else(self.bit_type(True), done)) if debug: @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) From 2cd263dad0344355ff4837d04c26fbf93bd4abc7 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Mon, 1 Aug 2022 17:58:45 +0200 Subject: [PATCH 108/265] Improve multithreading and remove non-multithreaded code --- Compiler/sqrt_oram.py | 477 +++++++++++++++++++++++++++++------------- 1 file changed, 332 insertions(+), 145 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5de1174f..a757c045 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -3,10 +3,10 @@ from abc import abstractmethod import math from typing import Any, Generic, Type, TypeVar -from Compiler.program import Program from Compiler import util from Compiler import library as lib from Compiler.GC.types import cbit, sbit, sbitint, sbits +from Compiler.program import Program from Compiler.types import ( Array, MemValue, @@ -14,16 +14,28 @@ from Compiler.types import ( _clear, _secret, cint, + regint, sint, sintbit, - regint ) +from oram import get_n_threads program = Program.prog -debug = False +debug = True +trace = True n_threads = 8 -multithreading = True +n_parallel = 1 + +def get_n_threads(n_loops): + if n_threads is None: + if n_loops > 2048: + return 8 + else: + return None + else: + return n_threads + def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): """Swap two positions in an Array if a condition is met. @@ -43,9 +55,11 @@ def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) array[pos_a] = cond.if_else(temp, array[pos_a]) + T = TypeVar("T", sint, sbitint) B = TypeVar("B", sintbit, sbit) + class SqrtOram(Generic[T, B]): # TODO: Preferably this is an Array of vectors, but this is currently not supported # One should regard these structures as Arrays where an entry may hold more @@ -75,7 +89,7 @@ class SqrtOram(Generic[T, B]): """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: - data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array). + data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). value_type (sint): The secret type to use, defaults to sint. k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. @@ -87,7 +101,8 @@ class SqrtOram(Generic[T, B]): self.n = math.ceil(len(data) // entry_length) if (len(data) % entry_length != 0): raise Exception('Data incorrectly padded.') - self.shuffle = MultiArray((self.n, entry_length), value_type=value_type) + self.shuffle = MultiArray( + (self.n, entry_length), value_type=value_type) self.shuffle.assign_part_vector(data.get_vector()) else: raise Exception("Incorrect format.") @@ -96,19 +111,23 @@ class SqrtOram(Generic[T, B]): if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") self.bit_type: Type[B] = value_type.bit_type - self.index_type = value_type.get_type(util.log2(self.n)) + self.index_size = util.log2(self.n) + self.index_type = value_type.get_type(self.index_size) self.entry_length = entry_length if debug: - lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k) + lib.print_ln( + 'Initializing SqrtORAM of size %s at depth %s', self.n, k) self.shuffle_used = cint.Array(self.n) # Random permutation on the data - self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)]) + self.shufflei = Array.create_from( + [self.index_type(i) for i in range(self.n)]) permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter - self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period + self.T = int(math.ceil( + math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) # Initialize position map (recursive oram) @@ -119,28 +138,108 @@ class SqrtOram(Generic[T, B]): self.stashi = Array(self.T, value_type=value_type) self.t = MemValue(cint(0)) - @lib.method_block - def read(self, index: T): - value = self.value_type(0, size=self.entry_length) - return self.access(index, self.bit_type(False), *value) - - @lib.method_block - def write(self, index: T, *value: T): - lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length") - self.access(index, self.bit_type(True), *value) - - __getitem__ = read - __setitem__ = write + # Initialize temp variables needed during the computation + self.found_ = self.bit_type.Array(size=self.T) @lib.method_block def access(self, index: T, write: B, *value: T): - if debug: + if trace: @lib.if_e(write.reveal() == 1) def _(): lib.print_ln('Writing to secret index %s', index.reveal()) + @lib.else_ def __(): lib.print_ln('Reading from secret index %s', index.reveal()) + + value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # If we are writing, we need to add the value + self.stash[i] += write * access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('\tFound item in stash') + + @lib.else_ + def __(): + lib.print_ln('\tDid not find item in stash') + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + @lib.if_((write * found.bit_not()).reveal()) + def _(): + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) + + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + @lib.method_block + def write(self, index: T, *value: T): + if trace: + lib.print_ln('Writing to secret index %s', index.reveal()) + value = self.value_type(value) index = MemValue(index) @@ -150,87 +249,159 @@ class SqrtOram(Generic[T, B]): self.refresh() found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) - # Scan through the stash - @lib.if_(self.t > 0) - def _(): - nonlocal found - found |= index == self.stashi[0] - # We ensure that if the item is found in stash, it ends up in the first - # position (more importantly, a fixed position) of the stash - # This allows us to keep track of it in an oblivious manner - if multithreading: - found_ = self.bit_type.Array(size=self.T) - @lib.multithread(1, self.T) - def _(base, size): - found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size)) - & self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), - base=base) - @lib.for_range_opt(self.t - 1) - def _(i): - swap(self.stash, 0, i + 1, found_[i+1]) - swap(self.stashi, 0, i + 1, found_[i+1]) - found.write(sum(found_)) - else: - @lib.for_range_opt(self.t - 1) - def _(i): - nonlocal found - found_: B = index == self.stashi[i + 1] - swap(self.stash, 0, i + 1, found_) - swap(self.stashi, 0, i + 1, found_) - found |= found_ - # If the item was not in the stash, we move the unknown and unimportant - # stash[0] out of the way (to the end of the stash) - swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) - swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0))) + # First we scan the stash for the item + self.found_.assign_all(0) - if debug: + # This will result in an bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + # We update the stash value + self.stash[i] += access_here * (value - entry) + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: @lib.if_e(found.reveal() == 1) def _(): lib.print_ln('\tFound item in stash') + @lib.else_ def __(): - lib.print_ln('\tItem not in stash') - lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal()) + lib.print_ln('\tDid not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. self.shuffle_used[physical_address] = cbit(True) - # If the item was in the stash (thus currently residing in stash[0]), - # we place the random item retrieved from the shuffle at the end of the stash - self.stash[self.t].assign(found.if_else( - self.shuffle[physical_address][:], - self.stash[self.t][:])) - self.stashi[self.t] = found.if_else( - self.shufflei[physical_address], - self.stashi[self.t]) - # If the item was not found in the stash, we place the item retrieved - # from the shuffle (the item we are actually looking for) in stash[0] - self.stash[0].assign(found.bit_not().if_else( - self.shuffle[physical_address][:], - self.stash[0][:])) - self.stashi[0] = found.bit_not().if_else( - self.shufflei[physical_address], - self.stashi[0]) + # If the item was not found in the stash + # ...we update the item in the shuffle + self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:]) + # ...and the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] - if debug: - @lib.if_e(found.reveal() == 1) + if trace: + @lib.if_(found.bit_not().reveal()) def _(): - lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) - @lib.else_ - def __(): - lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal()) + lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) self.t.iadd(1) - self.stash[0].assign(write.if_else(value, self.stash[0][:])) - value=write.bit_not().if_else(self.stash[0][:], value) - return value + return result + @lib.method_block + def read(self, index: T, *value: T): + if trace: + lib.print_ln('Reading from secret index %s', index.reveal()) + value = self.value_type(value) + index = MemValue(index) + + # Refresh if we have performed T (period) accesses + @lib.if_(self.t == self.T) + def _(): + self.refresh() + + found: B = MemValue(self.bit_type(False)) + result: T = MemValue(self.value_type(0, size=self.entry_length)) + + # First we scan the stash for the item + self.found_.assign_all(0) + + # This will result in a bit array with at most one True, + # indicating where in the stash 'index' is found + @lib.multithread(get_n_threads(self.T), self.T) + def _(base, size): + self.found_.assign_vector( + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ + self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + base=base) + + # To determine whether the item is found in the stash, we simply + # check wheterh the demuxed array contains a True + # TODO: What if the index=0? + found.write(sum(self.found_)) + + # Store the stash item into the result if found + # If the item is not in the stash, the result will simple remain 0 + @lib.map_sum(get_n_threads(self.T), n_parallel, self.T, + self.entry_length, [self.value_type] * self.entry_length) + def stash_item(i): + entry = self.stash[i][:] + access_here = self.found_[i] + # This is a bit unfortunate + # We should loop from 0 to self.t, but t is dynamic thus this is impossible. + # Therefore we loop till self.T (the max value of self.t) + # is_in_time = i < self.t + + return (entry * access_here)[:] + result += self.value_type(stash_item(), size=self.entry_length) + + if trace: + @lib.if_e(found.reveal() == 1) + def _(): + lib.print_ln('\tFound item in stash') + + @lib.else_ + def __(): + lib.print_ln('\tDid not find item in stash') + + # Possible fake lookup of the item in the shuffle, + # depending on whether we already found the item in the stash + physical_address = self.position_map.get_position(index, found) + # We set shuffle_used to True, to track that this shuffle item needs to be refreshed + # with its equivalent on the stash once the period is up. + self.shuffle_used[physical_address] = cbit(True) + + # If the item was not found in the stash + # the item retrieved from the shuffle is our result + result += self.shuffle[physical_address] * found.bit_not() + # We append the newly retrieved item to the stash + self.stash[self.t].assign(self.shuffle[physical_address][:]) + self.stashi[self.t] = self.shufflei[physical_address] + + if trace: + lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + + # Increase the "time" (i.e. access count in current period) + self.t.iadd(1) + + return result + + __getitem__ = read + __setitem__ = write @lib.method_block def shuffle_the_shuffle(self): @@ -242,12 +413,15 @@ class SqrtOram(Generic[T, B]): # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) - if debug: lib.print_ln('\tGenerated shuffle') + if trace: + lib.print_ln('\tGenerated shuffle') # Apply the random permutation self.shuffle.secure_permute(random_shuffle) - if debug: lib.print_ln('\tShuffled shuffle') + if trace: + lib.print_ln('\tShuffled shuffle') self.shufflei.secure_permute(random_shuffle) - if debug: lib.print_ln('\tShuffled shuffle indexes') + if trace: + lib.print_ln('\tShuffled shuffle indexes') # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -256,7 +430,8 @@ class SqrtOram(Generic[T, B]): # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). permutation = MemValue(self.shufflei[:].inverse_permutation()) - if debug: lib.print_ln('\tCalculated inverse permutation') + if trace: + lib.print_ln('\tCalculated inverse permutation') return permutation @lib.method_block @@ -266,10 +441,12 @@ class SqrtOram(Generic[T, B]): This must happen after T (period) accesses to the ORAM.""" - if debug: lib.print_ln('Refreshing SqrtORAM') + if trace: + lib.print_ln('Refreshing SqrtORAM') # Shuffle and emtpy the stash, and store elements back into shuffle - j = MemValue(cint(0,size=1)) + j = MemValue(cint(0, size=1)) + @lib.for_range_opt(self.n) def _(i): @lib.if_(self.shuffle_used[i]) @@ -301,13 +478,14 @@ class SqrtOram(Generic[T, B]): self.shufflei.assign([self.index_type(i) for i in range(self.n)]) # Reset the clock self.t.write(0) - # Reset shuffle_used + # Reset shuffle_used self.shuffle_used.assign_all(0) # Note that the self.shuffle is actually a MultiArray # This structure is preserved while overwriting the values using # assign_vector - self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length)) + self.shuffle.assign_vector(self.value_type( + data, size=self.n * self.entry_length)) permutation = self.shuffle_the_shuffle() self.position_map.reinitialize(*permutation) @@ -316,13 +494,13 @@ class PositionMap(Generic[T, B]): PACK_LOG: int = 3 PACK: int = 1 << PACK_LOG - n: int # n in the paper - depth: int # k in the paper + n: int # n in the paper + depth: int # k in the paper value_type: Type[T] - def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None: + def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None: self.n = n - self.depth=MemValue(cint(k)) + self.depth = MemValue(cint(k)) self.value_type = value_type self.bit_type = value_type.bit_type self.index_type = self.value_type.get_type(util.log2(n)) @@ -330,8 +508,9 @@ class PositionMap(Generic[T, B]): @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: """Retrieve the block at the given (secret) logical address.""" - if debug: - lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) + if trace: + lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, + self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) def reinitialize(self, *permutation: T): """Reinitialize this PositionMap. @@ -352,11 +531,13 @@ class PositionMap(Generic[T, B]): if n / PositionMap.PACK <= period: if debug: - lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n) + lib.print_ln( + 'Initializing LinearPositionMap at depth %s of size %s', k, n) res = LinearPositionMap(permutation, value_type, k=k) else: if debug: - lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n) + lib.print_ln( + 'Initializing RecursivePositionMap at depth %s of size %s', k, n) res = RecursivePositionMap(permutation, period, value_type, k=k) return res @@ -364,7 +545,7 @@ class PositionMap(Generic[T, B]): class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): - def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None: + def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None: PositionMap.__init__(self, len(permutation), k=k) pack = PositionMap.PACK @@ -377,7 +558,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): permutation[i*pack:(i+1)*pack]) # TODO: Should this be n or packed_size? - SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + SqrtOram.__init__(self, packed_structure, value_type=value_type, + period=period, entry_length=pack, k=self.depth) @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: @@ -389,7 +571,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): # The item at logical_address # will be in block with index h (block.) # at position l in block.data (block.data) - h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length))) + h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)( + logical_address).right_shift(pack_log, program.bit_length))) l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) # The resulting physical address @@ -401,32 +584,37 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): # First we scan the stash for the block we need condition1 = self.bit_type.Array(self.T) + @lib.for_range_opt_multithread(8, self.T) def _(i): condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) found = sum(condition1) # Once a block is found, we use condition2 to pick the correct item from that block - condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack)) + condition2 = Array.create_from( + regint.inc(pack) == l.expand_to_vector(pack)) # condition3 combines condition1 & condition2, only returning true at stash[h][l] condition3 = self.bit_type.Array(self.T * pack) + @lib.for_range_opt_multithread(8, [self.T, pack]) def _(i, j): condition3[i*pack + j] = condition1[i] & condition2[j] # Finally we use condition3 to conditionally write p + @lib.for_range(self.t) def _(i): @lib.for_range(pack) def _(j): p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) - if debug: + if trace: @lib.if_(condition1[i].reveal() == 1) def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal()) + lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal( + ), self.stash[i].reveal()) # Then we try and retrieve the item from the shuffle (the actual memory) - if debug: + if trace: @lib.if_(found.reveal() == 0) def _(): lib.print_ln('\t%s Position not in stash', self.depth) @@ -438,15 +626,16 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): # The block retrieved from the shuffle block_p_prime: Array = self.shuffle[p_prime] - if debug: + if trace: @lib.if_e(found.reveal() == 0) def _(): lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', - self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + @lib.else_ def __(): lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', - self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) + self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash self.stash[self.t].assign(block_p_prime[:]) @@ -458,7 +647,9 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): condition: B = self.bit_type(fake.bit_or(found.bit_not())) # Retrieve l'th item from block # l is secret, so we must use linear scan - hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack)) + hit = Array.create_from((regint.inc(pack) == l.expand_to_vector( + pack)) & condition.expand_to_vector(pack)) + @lib.for_range_opt(pack) def _(i): p.write((hit[i]).if_else(block_p_prime[i], p)) @@ -469,67 +660,63 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): def reinitialize(self, *permutation: T): SqrtOram.reinitialize(self, *permutation) + class LinearPositionMap(PositionMap): physical: Array used: Array - def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None: + def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None: PositionMap.__init__(self, len(data), value_type, k=k) self.physical = data self.used = self.bit_type.Array(self.n) + # Initialize random temp variables needed during the computation + self.physical_demux: Array = self.bit_type.Array(self.n) + @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: """ This method corresponds to GetPosBase in the paper. """ super().get_position(logical_address, fake) + fake = MemValue(self.bit_type(fake)) logical_address = MemValue(logical_address) p: MemValue = MemValue(self.index_type(-1)) done: B = self.bit_type(False) - if multithreading: - conditions:Array = self.bit_type.Array(self.n) - conditions.assign_all(0) + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + self.physical_demux.assign_all(0) + @lib.for_range_opt_multithread(8, self.n) + def condition_i(i): + self.physical_demux.assign((self.bit_type(fake).bit_not() + & self.bit_type(logical_address == i)) | (fake + & self.used[i].bit_not()), base=i) - @lib.for_range_opt_multithread(8, self.n) - def condition_i(i): - conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i) + # In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False) + # We only need once, so we pick the first one we find + @lib.for_range_opt(self.n) + def _(i): + nonlocal done + self.physical_demux[i] &= done.bit_not() + done |= self.physical_demux[i] - @lib.for_range_opt(self.n) - def _(i): - nonlocal done - conditions[i] &= done.bit_not() - done |= conditions[i] - @lib.map_sum_opt(8, self.n, [self.value_type]) - def calc_p(i): - return self.physical[i] * conditions[i] - p.write(calc_p()) + # Retrieve the value from the physical memory obliviously + @lib.map_sum_opt(8, self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * self.physical_demux[i] + p.write(calc_p()) - self.used.assign(self.used[:] | conditions[:]) - else: - # In order to get an address at secret logical_address, - # we need to perform a linear scan. - linear_scan = self.bit_type.Array(self.n) - @lib.for_range_opt(self.n) - def _(i): - linear_scan[i] = logical_address == i + # Update self.used + self.used.assign(self.used[:] | self.physical_demux[:]) - @lib.for_range_opt(self.n) - def __(j): - nonlocal done, fake - condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ - .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) - p.write(condition.if_else(self.physical[j], p)) - self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) - done = self.bit_type(condition.if_else(self.bit_type(True), done)) - - if debug: + if trace: @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical))) def _(): - lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) + lib.runtime_error( + '%s Did not find requested logical_address in shuffle, something went wrong.', self.depth) return p.reveal() From fb5871a2f8a54bae8af1ea64114496b962781e82 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Thu, 11 Aug 2022 14:41:20 +0200 Subject: [PATCH 109/265] Add allow_memory_allocation option to SqrtORAM Also remove unused swap function in SqrtORAM --- Compiler/sqrt_oram.py | 307 ++++++++++++++++++++++++------------------ 1 file changed, 177 insertions(+), 130 deletions(-) diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index a757c045..d732e2f2 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -1,32 +1,34 @@ from __future__ import annotations -from abc import abstractmethod + import math +from abc import abstractmethod from typing import Any, Generic, Type, TypeVar -from Compiler import util from Compiler import library as lib +from Compiler import util from Compiler.GC.types import cbit, sbit, sbitint, sbits from Compiler.program import Program -from Compiler.types import ( - Array, - MemValue, - MultiArray, - _clear, - _secret, - cint, - regint, - sint, - sintbit, -) -from oram import get_n_threads +from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint, + regint, sint, sintbit) +from oram import demux_array, get_n_threads program = Program.prog -debug = True -trace = True -n_threads = 8 +# Adds messages on completion of heavy computation steps +debug = False +# Finer grained trace of steps that the ORAM performs +# + runtime error checks +# Warning: reveals information and makes the computation insecure +trace = False + +n_threads = 16 n_parallel = 1 +# Avoids any memory allocation +# This prevents some optimizations but allows for using the ORAMs outside of the main tape +allow_memory_allocation = True + + def get_n_threads(n_loops): if n_threads is None: if n_loops > 2048: @@ -37,25 +39,6 @@ def get_n_threads(n_loops): return n_threads -def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): - """Swap two positions in an Array if a condition is met. - - Args: - array (Array | MultiArray): The array in which to swap the first and second position - pos_a (int | cint): The first position - pos_b (int | cint): The second position - cond (sintbit | sbit): The condition determining whether to swap - """ - if isinstance(array, MultiArray): - temp = array[pos_b][:] - array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:])) - array[pos_a].assign(cond.if_else(temp, array[pos_a][:])) - if isinstance(array, Array): - temp = array[pos_b] - array[pos_b] = cond.if_else(array[pos_a], array[pos_b]) - array[pos_a] = cond.if_else(temp, array[pos_a]) - - T = TypeVar("T", sint, sbitint) B = TypeVar("B", sintbit, sbit) @@ -85,7 +68,7 @@ class SqrtOram(Generic[T, B]): # the stash) t: cint - def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None: + def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True) -> None: """Initialize a new Oblivious RAM using the "Square-Root" algorithm. Args: @@ -94,6 +77,9 @@ class SqrtOram(Generic[T, B]): k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. period (int): Leave at None, this parameter is used to recursively pass down the top-level period. """ + global debug, allow_memory_allocation + + # Correctly initialize the shuffle (memory) depending on the type of data if isinstance(data, MultiArray): self.shuffle = data self.n = len(data) @@ -107,9 +93,12 @@ class SqrtOram(Generic[T, B]): else: raise Exception("Incorrect format.") - self.value_type = value_type + # Only sint is supported if value_type != sint and value_type != sbitint: raise Exception("The value_type must be either sint or sbitint") + + # Set derived constants + self.value_type = value_type self.bit_type: Type[B] = value_type.bit_type self.index_size = util.log2(self.n) self.index_type = value_type.get_type(self.index_size) @@ -118,11 +107,11 @@ class SqrtOram(Generic[T, B]): if debug: lib.print_ln( 'Initializing SqrtORAM of size %s at depth %s', self.n, k) + self.shuffle_used = cint.Array(self.n) # Random permutation on the data self.shufflei = Array.create_from( [self.index_type(i) for i in range(self.n)]) - permutation = Array.create_from(self.shuffle_the_shuffle()) # Calculate the period if not given # upon recursion, the period should stay the same ("in sync"), # therefore it can be passed as a constructor parameter @@ -130,8 +119,21 @@ class SqrtOram(Generic[T, B]): math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period if debug and not period: lib.print_ln('Period set to %s', self.T) + + # Here we allocate the memory for the permutation + # Note that self.shuffle_the_shuffle mutates this field + # Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading + self.permutation = Array.create_from( + [self.index_type(i) for i in range(self.n)]) + # We allow the caller to postpone the initialization of the shuffle + # This is the most expensive operation, and can be done in a thread (only if you know what you're doing) + # Note that if you do not initialize, the ORAM is insecure + if initialize: + self.shuffle_the_shuffle() + else: + print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.') # Initialize position map (recursive oram) - self.position_map = PositionMap.create(permutation, k + 1, self.T) + self.position_map = PositionMap.create(self.permutation, k + 1, self.T) # Initialize stash self.stash = MultiArray((self.T, entry_length), value_type=value_type) @@ -140,19 +142,28 @@ class SqrtOram(Generic[T, B]): # Initialize temp variables needed during the computation self.found_ = self.bit_type.Array(size=self.T) + self.j = MemValue(cint(0, size=1)) + + # To prevent the compiler from recompiling the same code over and over again, we should use @method_block + # However, @method_block requires allocation (of return address), which is not allowed when not in the main thread + # Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread + self.shuffle_the_shuffle = lib.method_block(self.shuffle_the_shuffle) if allow_memory_allocation else self.shuffle_the_shuffle + self.refresh = lib.method_block(self.refresh) if allow_memory_allocation else self.refresh @lib.method_block def access(self, index: T, write: B, *value: T): + global trace,n_parallel if trace: @lib.if_e(write.reveal() == 1) def _(): - lib.print_ln('Writing to secret index %s', index.reveal()) + lib.print_ln(' Writing to secret index %s', index.reveal()) @lib.else_ def __(): - lib.print_ln('Reading from secret index %s', index.reveal()) + lib.print_ln(' Reading from secret index %s', index.reveal()) - value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length) + value = self.value_type(value, size=self.entry_length).get_vector( + 0, size=self.entry_length) index = MemValue(index) # Refresh if we have performed T (period) accesses @@ -171,8 +182,9 @@ class SqrtOram(Generic[T, B]): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -200,11 +212,11 @@ class SqrtOram(Generic[T, B]): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -215,7 +227,8 @@ class SqrtOram(Generic[T, B]): # If the item was not found in the stash # ...we update the item in the shuffle - self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:]) + self.shuffle[physical_address] += write * \ + found.bit_not() * (value - self.shuffle[physical_address][:]) # ...and the item retrieved from the shuffle is our result result += self.shuffle[physical_address] * found.bit_not() # We append the newly retrieved item to the stash @@ -225,10 +238,8 @@ class SqrtOram(Generic[T, B]): if trace: @lib.if_((write * found.bit_not()).reveal()) def _(): - lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) - - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, - self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) + lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) # Increase the "time" (i.e. access count in current period) self.t.iadd(1) @@ -237,8 +248,9 @@ class SqrtOram(Generic[T, B]): @lib.method_block def write(self, index: T, *value: T): + global trace, n_parallel if trace: - lib.print_ln('Writing to secret index %s', index.reveal()) + lib.print_ln(' Writing to secret index %s', index.reveal()) value = self.value_type(value) index = MemValue(index) @@ -259,8 +271,9 @@ class SqrtOram(Generic[T, B]): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -286,11 +299,11 @@ class SqrtOram(Generic[T, B]): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -301,7 +314,8 @@ class SqrtOram(Generic[T, B]): # If the item was not found in the stash # ...we update the item in the shuffle - self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:]) + self.shuffle[physical_address] += found.bit_not() * \ + (value - self.shuffle[physical_address][:]) # ...and the item retrieved from the shuffle is our result result += self.shuffle[physical_address] * found.bit_not() # We append the newly retrieved item to the stash @@ -311,9 +325,10 @@ class SqrtOram(Generic[T, B]): if trace: @lib.if_(found.bit_not().reveal()) def _(): - lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address) + lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal( + ), self.shuffle[physical_address].reveal(), physical_address) - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) @@ -323,14 +338,20 @@ class SqrtOram(Generic[T, B]): @lib.method_block def read(self, index: T, *value: T): + global debug, trace, n_parallel if trace: - lib.print_ln('Reading from secret index %s', index.reveal()) + lib.print_ln(' Reading from secret index %s', index.reveal()) + value = self.value_type(value) index = MemValue(index) # Refresh if we have performed T (period) accesses @lib.if_(self.t == self.T) def _(): + if debug: + lib.print_ln(' Refreshing SqrtORAM') + lib.print_ln(' t=%s according to me', self.t) + self.refresh() found: B = MemValue(self.bit_type(False)) @@ -344,8 +365,9 @@ class SqrtOram(Generic[T, B]): @lib.multithread(get_n_threads(self.T), self.T) def _(base, size): self.found_.assign_vector( - (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \ - self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)), + (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & + self.bit_type(regint.inc(size, base=base) < + self.t.expand_to_vector(size)), base=base) # To determine whether the item is found in the stash, we simply @@ -371,11 +393,11 @@ class SqrtOram(Generic[T, B]): if trace: @lib.if_e(found.reveal() == 1) def _(): - lib.print_ln('\tFound item in stash') + lib.print_ln(' Found item in stash') @lib.else_ def __(): - lib.print_ln('\tDid not find item in stash') + lib.print_ln(' Did not find item in stash') # Possible fake lookup of the item in the shuffle, # depending on whether we already found the item in the stash @@ -392,7 +414,7 @@ class SqrtOram(Generic[T, B]): self.stashi[self.t] = self.shufflei[physical_address] if trace: - lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, + lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t) # Increase the "time" (i.e. access count in current period) @@ -403,25 +425,36 @@ class SqrtOram(Generic[T, B]): __getitem__ = read __setitem__ = write - @lib.method_block - def shuffle_the_shuffle(self): + def shuffle_the_shuffle(self) -> None: """Permute the memory using a newly generated permutation and return the permutation that would generate this particular shuffling. This permutation is needed to know how to map logical addresses to physical addresses, and is used as such by the postition map.""" + global trace # Random permutation on n elements random_shuffle = sint.get_secure_shuffle(self.n) if trace: - lib.print_ln('\tGenerated shuffle') + lib.print_ln(' Generated shuffle') # Apply the random permutation self.shuffle.secure_permute(random_shuffle) if trace: - lib.print_ln('\tShuffled shuffle') + lib.print_ln(' Shuffled shuffle') self.shufflei.secure_permute(random_shuffle) if trace: - lib.print_ln('\tShuffled shuffle indexes') + lib.print_ln(' Shuffled shuffle indexes') + + if trace: + # If shufflei does not contain exactly the indices [i for i in + # range(self.n)], the underlying waksman network of + # 'inverse_permutation' will hang. + tmp_shuffli = Array.create_from(self.shufflei[:]) + @lib.if_(sum(lib.sort(tmp_shuffli)[:] == Array.create_from([cint(i) for i in range(self.n)])[:]).reveal() != self.n) + def _(): + lib.print_ln( + 'Shufflei is corrupted! You have found a bug in the implementation :c\nThe computation will now hang...') + # Calculate the permutation that would have produced the newly produced # shuffle order. This can be calculated by regarding the logical # indexes (shufflei) as a permutation and calculating its inverse, @@ -429,45 +462,45 @@ class SqrtOram(Generic[T, B]): # this is not necessarily equal to the inverse of the above generated # random_shuffle, as the shuffle may already be out of order (e.g. when # refreshing). - permutation = MemValue(self.shufflei[:].inverse_permutation()) + self.permutation.assign(self.shufflei[:].inverse_permutation()) if trace: - lib.print_ln('\tCalculated inverse permutation') - return permutation + lib.print_ln(' Calculated inverse permutation') - @lib.method_block def refresh(self): """Refresh the ORAM by reinserting the stash back into the shuffle, and reshuffling the shuffle. - This must happen after T (period) accesses to the ORAM.""" - - if trace: - lib.print_ln('Refreshing SqrtORAM') + This must happen on the T'th (period) accesses to the ORAM.""" + self.j.write(0) # Shuffle and emtpy the stash, and store elements back into shuffle - j = MemValue(cint(0, size=1)) @lib.for_range_opt(self.n) def _(i): @lib.if_(self.shuffle_used[i]) def _(): - nonlocal j - self.shuffle[i] = self.stash[j] - self.shufflei[i] = self.stashi[j] - j += 1 + self.shuffle[i] = self.stash[self.j] + self.shufflei[i] = self.stashi[self.j] + self.j += 1 # Reset the clock self.t.write(0) # Reset shuffle_used - self.shuffle_used.assign_all(0) + global allow_memory_allocation + if allow_memory_allocation: + self.shuffle_used.assign_all(0) + else: + @lib.for_range_opt(self.n) + def _(i): + self.shuffle_used[i] = cint(0) # Reinitialize position map - permutation = self.shuffle_the_shuffle() + self.shuffle_the_shuffle() # Note that we skip here the step of "packing" the permutation. # Since the underlying memory of the position map is already aligned in # this packed structure, we can simply overwrite the memory while # maintaining the structure. - self.position_map.reinitialize(*permutation) + self.position_map.reinitialize(*self.permutation) @lib.method_block def reinitialize(self, *data: T): @@ -478,7 +511,7 @@ class SqrtOram(Generic[T, B]): self.shufflei.assign([self.index_type(i) for i in range(self.n)]) # Reset the clock self.t.write(0) - # Reset shuffle_used + # Reset shuffle_used self.shuffle_used.assign_all(0) # Note that the self.shuffle is actually a MultiArray @@ -486,8 +519,10 @@ class SqrtOram(Generic[T, B]): # assign_vector self.shuffle.assign_vector(self.value_type( data, size=self.n * self.entry_length)) - permutation = self.shuffle_the_shuffle() - self.position_map.reinitialize(*permutation) + # Note that this updates self.permutation (see constructor for explanation) + self.shuffle_the_shuffle() + self.position_map.reinitialize(*self.permutation) + class PositionMap(Generic[T, B]): @@ -508,8 +543,9 @@ class PositionMap(Generic[T, B]): @abstractmethod def get_position(self, logical_address: _secret, fake: B) -> Any: """Retrieve the block at the given (secret) logical address.""" + global trace if trace: - lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, + lib.print_ln(' %s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal()) def reinitialize(self, *permutation: T): @@ -529,6 +565,7 @@ class PositionMap(Generic[T, B]): a LinearPositionMap.""" n = len(permutation) + global debug if n / PositionMap.PACK <= period: if debug: lib.print_ln( @@ -561,6 +598,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth) + # Initialize random temp variables needed during the computation + self.block_index_demux: Array = self.bit_type.Array(self.T) + self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK) + @lib.method_block def get_position(self, logical_address: T, fake: B) -> _clear: super().get_position(logical_address, fake) @@ -576,50 +617,42 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) # The resulting physical address - p = MemValue(self.index_type(0)) + p = MemValue(self.index_type(-1)) found: B = MemValue(self.bit_type(False)) # First we try and retrieve the item from the stash at position stash[h][l] # Since h and l are secret, we do this by scanning the entire stash # First we scan the stash for the block we need - condition1 = self.bit_type.Array(self.T) + self.block_index_demux.assign_all(0) - @lib.for_range_opt_multithread(8, self.T) + @lib.for_range_opt_multithread(get_n_threads(self.T), self.T) def _(i): - condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t) - found = sum(condition1) - # Once a block is found, we use condition2 to pick the correct item from that block - condition2 = Array.create_from( - regint.inc(pack) == l.expand_to_vector(pack)) - # condition3 combines condition1 & condition2, only returning true at stash[h][l] - condition3 = self.bit_type.Array(self.T * pack) + self.block_index_demux[i] = ( + self.stashi[i] == h) & self.bit_type(i < self.t) + # We can determine if the 'index' is in the stash by checking the + # block_index_demux array + found = sum(self.block_index_demux) + # Once a block is found, we use the following condition to pick the correct item from that block + demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux) - @lib.for_range_opt_multithread(8, [self.T, pack]) - def _(i, j): - condition3[i*pack + j] = condition1[i] & condition2[j] - # Finally we use condition3 to conditionally write p - - @lib.for_range(self.t) - def _(i): - @lib.for_range(pack) - def _(j): - p.write(condition3[i*pack + j].if_else(self.stash[i][j], p)) - - if trace: - @lib.if_(condition1[i].reveal() == 1) - def _(): - lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal( - ), self.stash[i].reveal()) - - # Then we try and retrieve the item from the shuffle (the actual memory) + # Finally we use the conditions to conditionally write p + @lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type]) + def p_(i): + # We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum + # Therefore we include the check (i < self.t) + return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i < self.t) + p.write(p_()) + global trace if trace: @lib.if_(found.reveal() == 0) def _(): - lib.print_ln('\t%s Position not in stash', self.depth) + lib.print_ln(' %s Position not in stash', self.depth) - # Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle + # Then we try and retrieve the item from the shuffle (the actual memory) + # Depending on whether we found the item in the stash, we either + # block 'h' in which 'index' resides, or a random block from the shuffle p_prime = self.position_map.get_position(h, found) self.shuffle_used[p_prime] = cbit(True) @@ -629,12 +662,12 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): if trace: @lib.if_e(found.reveal() == 0) def _(): - lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', + lib.print_ln(' %s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) @lib.else_ def __(): - lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)', + lib.print_ln(' %s Retrieved dummy stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal()) # We add the retrieved block from the shuffle to the stash @@ -680,6 +713,13 @@ class LinearPositionMap(PositionMap): """ super().get_position(logical_address, fake) + global trace + if trace: + @lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal()) + def _(): + lib.runtime_error( + 'logical_address must lie between 0 and self.n - 1') + fake = MemValue(self.bit_type(fake)) logical_address = MemValue(logical_address) @@ -689,11 +729,12 @@ class LinearPositionMap(PositionMap): # In order to get an address at secret logical_address, # we need to perform a linear scan. self.physical_demux.assign_all(0) - @lib.for_range_opt_multithread(8, self.n) + + @lib.for_range_opt_multithread(get_n_threads(self.n), self.n) def condition_i(i): - self.physical_demux.assign((self.bit_type(fake).bit_not() - & self.bit_type(logical_address == i)) | (fake - & self.used[i].bit_not()), base=i) + self.physical_demux[i] = \ + (self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \ + | (fake & self.used[i].bit_not()) # In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False) # We only need once, so we pick the first one we find @@ -704,7 +745,7 @@ class LinearPositionMap(PositionMap): done |= self.physical_demux[i] # Retrieve the value from the physical memory obliviously - @lib.map_sum_opt(8, self.n, [self.value_type]) + @lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type]) def calc_p(i): return self.physical[i] * self.physical_demux[i] p.write(calc_p()) @@ -720,7 +761,13 @@ class LinearPositionMap(PositionMap): return p.reveal() - @lib.method_block def reinitialize(self, *data: T): self.physical.assign_vector(data) - self.used.assign_all(False) + + global allow_memory_allocation + if allow_memory_allocation: + self.used.assign_all(False) + else: + @lib.for_range_opt(self.n) + def _(i): + self.used[i] = self.bit_type(0) From a4e4baddeedefbf2a6328caf72e738a03f3930ef Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Fri, 12 Aug 2022 10:45:36 -0700 Subject: [PATCH 110/265] add new in python compile to docs --- README.md | 31 ++++++++++++++++++++++++++++++- doc/index.rst | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 18278c25..79bf5339 100644 --- a/README.md +++ b/README.md @@ -464,6 +464,35 @@ See the [documentation](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.circuit) for further examples. +#### Compiling programs directly in Python + +You may prefer to not have an entirely static `.mpc` file to compile, and may want to compile based on dynamic inputs. For example, you may want to be able to compile with different sizes of input data without making a code change to the `.mpc` file. To handle this, the compiler an also be directly imported, and a function can be compiled with the following interface: + +```python +# hello_world.mpc +from Compiler.library import print_ln +from Compiler.compilerLib import Compiler + +compiler = Compiler() + +@compiler.register_function('helloworld') +def hello_world(): + print_ln('hello world') + +if __name__ == "__main__": + compiler.compile_func() +``` + +You could then run this with: + +```bash +python hello_world.mpc +``` + +This is particularly useful if want to add new command line arguements specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) for more details on this use case. + +Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`. + #### Compiling and running programs from external directories Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all MP-SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example: @@ -966,7 +995,7 @@ After compiling the mpc file: You can benchmark the ORAM implementation as follows: 1) Edit `Program/Source/gc_oram.mpc` to change size and to choose -Circuit ORAM or linear scan without ORAM. +Circuit ORAM or linear scan without ORAM. 2) Run `./compile.py -D gc_oram`. The `-D` argument instructs the compiler to remove dead code. This is useful for more complex programs such as this one. diff --git a/doc/index.rst b/doc/index.rst index 59caa58d..61acd045 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -17,8 +17,7 @@ Compilation process The easiest way of using MP-SPDZ is using ``compile.py`` as described below. If you would like to run compilation directly from -Python, see ``Scripts/direct_compilation_example.py``. It contains all -the necessary setup steps. +Python, see :ref:`Direct Compilation in Python`. After putting your code in ``Program/Source/.mpc``, run the compiler from the root directory as follows @@ -140,6 +139,43 @@ computation: to the run time. +Direct Compilation in Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You may prefer to not have an entirely static `.mpc` file to compile, +and may want to compile based on dynamic inputs. For example, you may +want to be able to compile with different sizes of input data without +making a code change to the `.mpc` file. To handle this, the compiler +an also be directly imported, and a function can be compiled with the +following interface: + +.. code-block:: python + # hello_world.mpc + from Compiler.library import print_ln + from Compiler.compilerLib import Compiler + + compiler = Compiler() + + @compiler.register_function('helloworld') + def hello_world(): + print_ln('hello world') + + if __name__ == "__main__": + compiler.compile_func() + + +You could then run this with the same args as used with `compile.py`: + +.. code-block:: bash + python hello_world.mpc + +This is particularly useful if want to add new command line arguements +specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) +for more details on this use case. + +Note that when using this approach, all objects provided in the high level +interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file +is interpreted directly by Python (instead of being read by `compile.py`.) + Compilation vs run time ~~~~~~~~~~~~~~~~~~~~~~~ From f83476ab2a90e2f44f4661d6f845cc4b560306d2 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 13 Aug 2022 12:23:12 -0700 Subject: [PATCH 111/265] fix small typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 79bf5339..8e03710c 100644 --- a/README.md +++ b/README.md @@ -483,7 +483,7 @@ if __name__ == "__main__": compiler.compile_func() ``` -You could then run this with: +You could then run this with the same args as used with `compile.py`: ```bash python hello_world.mpc @@ -491,7 +491,7 @@ python hello_world.mpc This is particularly useful if want to add new command line arguements specifically for your `.mpc` file. See [test_args.mpc](Programs/Source/test_args.mpc) for more details on this use case. -Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`. +Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`.) #### Compiling and running programs from external directories From 39b6d7e22dedc5798823d306741d33ff6635bc25 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Sat, 13 Aug 2022 17:56:24 -0700 Subject: [PATCH 112/265] allow custom_args to manaully override args --- Compiler/compilerLib.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 9e36e9f9..ae8ca9b2 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -12,11 +12,12 @@ from .program import Program, defaults class Compiler: - def __init__(self, usage=None): + def __init__(self, custom_args=None, usage=None): if usage: self.usage = usage else: self.usage = "usage: %prog [options] filename [args]" + self.custom_args = custom_args self.build_option_parser() self.VARS = {} @@ -205,7 +206,7 @@ class Compiler: self.parser = parser def parse_args(self): - self.options, self.args = self.parser.parse_args() + self.options, self.args = self.parser.parse_args(self.custom_args) if self.options.optimize_hard: print("Note that -O/--optimize-hard currently has no effect") @@ -358,7 +359,7 @@ class Compiler: return inner def compile_func(self): - if not hasattr(self, "compile_name") and hasattr(self, "compile_func"): + if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")): raise CompilerError( "No function to compile. " "Did you decorate a function with @register_fuction(name)?" From 7630fbc22bfea75004f80827fbc760e92ff9cc37 Mon Sep 17 00:00:00 2001 From: hernan232 Date: Tue, 16 Aug 2022 09:12:17 -0500 Subject: [PATCH 113/265] Correct documentation in BMR table. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 70db2a0c..5f257c4d 100644 --- a/README.md +++ b/README.md @@ -842,7 +842,7 @@ lists the available schemes. | Program | Protocol | Dishonest Maj. | Malicious | \# parties | Script | | --- | --- | --- | --- | --- | --- | | `real-bmr-party.x` | MASCOT | Y | Y | 2 or more | `real-bmr.sh` | -| `semi-bmr-party.x` | Semi | Y | Y | 2 or more | `semi-bmr.sh` | +| `semi-bmr-party.x` | Semi | Y | N | 2 or more | `semi-bmr.sh` | | `shamir-bmr-party.x` | Shamir | N | N | 3 or more | `shamir-bmr.sh` | | `mal-shamir-bmr-party.x` | Shamir | N | Y | 3 or more | `mal-shamir-bmr.sh` | | `rep-bmr-party.x` | Replicated | N | N | 3 | `rep-bmr.sh` | From e08a6adb63ea057338f5613645d9d498cb43f2a9 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 17 Aug 2022 13:22:04 +1000 Subject: [PATCH 114/265] Fix shuffling in emulation. --- Processor/Instruction.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 7763c837..fbab7aa9 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -509,9 +509,12 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) case USE_INP: if (r[0] >= N_DATA_FIELD_TYPE) throw invalid_program(); - if ((unsigned)r[1] >= usage.inputs.size()) - throw Processor_Error("Player number too high"); - usage.inputs[r[1]][r[0]] = n; + if (usage.inputs.size() != 1) + { + if ((unsigned) r[1] >= usage.inputs.size()) + throw Processor_Error("Player number too high"); + usage.inputs[r[1]][r[0]] = n; + } return int(n) >= 0; case USE_EDABIT: usage.edabits[{r[0], r[1]}] = n; From 6a424539c93f5489a6d09360f0092224552d94d8 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 25 Aug 2022 13:20:46 +1000 Subject: [PATCH 115/265] SoftSpokenOT. --- .gitignore | 4 + .gitmodules | 12 +- BMR/Party.cpp | 2 + BMR/RealGarbleWire.h | 2 - BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 4 +- BMR/Register.h | 22 +--- BMR/Register.hpp | 12 +- BMR/Register_inline.h | 6 +- CHANGELOG.md | 17 ++- CONFIG | 18 +-- Compiler/GC/instructions.py | 16 ++- Compiler/GC/types.py | 83 ++++++++++--- Compiler/comparison.py | 21 ++-- Compiler/compilerLib.py | 11 +- Compiler/floatingpoint.py | 32 ++--- Compiler/instructions.py | 69 ++++++----- Compiler/instructions_base.py | 10 +- Compiler/library.py | 29 +++-- Compiler/ml.py | 13 ++- Compiler/mpc_math.py | 4 +- Compiler/program.py | 18 ++- Compiler/types.py | 131 ++++++++++++++------- Compiler/util.py | 3 + ECDSA/P256Element.h | 2 +- ECDSA/fake-spdz-ecdsa-party.cpp | 8 +- ECDSA/ot-ecdsa-party.hpp | 9 +- FHE/Ciphertext.cpp | 2 +- FHE/NTL-Subs.cpp | 6 +- FHE/PPData.cpp | 3 +- FHEOffline/PairwiseMachine.cpp | 17 ++- FHEOffline/PairwiseMachine.h | 25 ++-- FHEOffline/SimpleGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 6 +- FHEOffline/SimpleMachine.cpp | 14 +++ FHEOffline/SimpleMachine.h | 21 +++- GC/AtlasShare.h | 5 - GC/CcdShare.h | 5 - GC/FakeSecret.cpp | 3 +- GC/MaliciousCcdShare.h | 5 - GC/Processor.h | 2 +- GC/Secret.hpp | 2 +- GC/Secret_inline.h | 8 +- GC/ShareParty.hpp | 22 ++-- GC/ShareSecret.hpp | 9 +- GC/ShareThread.h | 1 + GC/ShareThread.hpp | 6 + GC/ThreadMaster.hpp | 10 +- GC/TinierShare.h | 22 +++- GC/TinyMC.h | 10 ++ GC/TinyShare.h | 5 - HOSTS.example | 5 - Machines/OTMachine.cpp | 2 +- Machines/Tinier.cpp | 12 +- Machines/TripleMachine.cpp | 20 +++- Machines/mama-party.cpp | 32 +++-- Machines/spdz2k-party.cpp | 8 +- Machines/tinier-party.cpp | 4 +- Makefile | 94 ++++++++++----- Math/Square.h | 2 + Math/Square.hpp | 9 ++ Math/Zp_Data.cpp | 4 + Math/Zp_Data.h | 2 + Math/bigint.h | 2 +- Networking/CryptoPlayer.cpp | 30 ++++- Networking/CryptoPlayer.h | 5 + Networking/Player.cpp | 119 ++++++++----------- Networking/Player.h | 70 +++++------ Networking/PlayerBuffer.h | 23 ++++ Networking/PlayerCtSocket.h | 169 +++++++++++++++++++++++++++ Networking/Receiver.h | 5 + Networking/Sender.h | 5 + OT/BaseOT.cpp | 133 ++++++++------------- OT/BaseOT.h | 11 +- OT/BitMatrix.h | 3 + OT/BitMatrix.hpp | 7 +- OT/MamaRectangle.h | 5 + OT/NPartyTripleGenerator.h | 5 +- OT/NPartyTripleGenerator.hpp | 38 +++++- OT/OTExtension.cpp | 9 ++ OT/OTExtension.h | 2 +- OT/OTExtensionWithMatrix.cpp | 137 +++++++++++++++++++++- OT/OTExtensionWithMatrix.h | 28 ++++- OT/OTMultiplier.h | 16 +-- OT/OTMultiplier.hpp | 75 ++++++++++-- OT/OTTripleSetup.h | 25 +++- OT/Rectangle.h | 2 + OT/Rectangle.hpp | 7 ++ OT/TripleMachine.h | 7 +- Processor/BaseMachine.cpp | 17 ++- Processor/BaseMachine.h | 11 +- Processor/Instruction.hpp | 15 +-- Processor/Machine.hpp | 23 ++-- Processor/NoFilePrep.h | 2 +- Processor/OfflineMachine.hpp | 16 ++- Processor/Online-Thread.hpp | 3 + Processor/OnlineMachine.h | 5 + Processor/OnlineMachine.hpp | 10 -- Processor/OnlineOptions.cpp | 9 +- Processor/Processor.h | 4 +- Processor/Processor.hpp | 27 +++-- Programs/Source/mnist_full_C.mpc | 1 + Programs/Source/test_args.mpc | 2 +- Programs/Source/test_gc.mpc | 8 +- Protocols/ChaiGearPrep.h | 2 +- Protocols/ChaiGearPrep.hpp | 4 +- Protocols/ChaiGearShare.h | 1 + Protocols/CowGearShare.h | 1 + Protocols/FakeInput.h | 2 +- Protocols/LowGearKeyGen.hpp | 4 + Protocols/MAC_Check.h | 11 ++ Protocols/MAC_Check.hpp | 34 +++++- Protocols/MAC_Check_Base.h | 3 + Protocols/MaliciousRepPrep.hpp | 1 + Protocols/MamaPrep.h | 3 +- Protocols/MamaPrep.hpp | 4 +- Protocols/MamaShare.h | 8 +- Protocols/NoShare.h | 17 ++- Protocols/ProtocolSetup.h | 12 ++ Protocols/ReplicatedPrep.hpp | 8 +- Protocols/SecureShuffle.hpp | 10 +- Protocols/SemiPrep.h | 2 + Protocols/SemiPrep.hpp | 15 +++ Protocols/Share.h | 6 +- Protocols/ShareInterface.h | 2 +- Protocols/Spdz2kPrep.h | 2 +- Protocols/Spdz2kShare.h | 4 +- Protocols/fake-stuff.h | 2 + Protocols/fake-stuff.hpp | 19 ++- README.md | 164 ++++++++++++++------------ Scripts/build.sh | 12 +- Scripts/test_tutorial.sh | 2 +- Tools/Coordinator.cpp | 78 +++++++++++++ Tools/Coordinator.h | 43 +++++++ Tools/NetworkOptions.h | 2 +- Tools/PointerVector.h | 2 +- Tools/Subroutines.cpp | 4 +- Tools/Subroutines.h | 9 +- Tools/Waksman.h | 9 +- Tools/benchmarking.cpp | 6 +- Tools/benchmarking.h | 2 +- Tools/int.h | 5 + Tools/intrinsics.h | 1 + Tools/octetStream.h | 29 ++++- Tools/random.cpp | 35 +----- Tools/random.h | 66 +++++++++-- Tools/time-func.cpp | 3 +- Utils/Fake-Offline.cpp | 35 +++--- Utils/binary-example.cpp | 2 +- Utils/pairwise-offline.cpp | 2 +- Yao/YaoEvalWire.cpp | 2 +- Yao/YaoEvalWire.h | 2 - Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoGarbleWire.h | 2 - azure-pipelines.yml | 11 +- SimpleOT => deps/SimpleOT | 0 deps/SimplestOT_C | 1 + deps/libOTe | 1 + mpir => deps/mpir | 0 simde => deps/simde | 0 doc/Doxyfile | 2 +- doc/add-protocol.rst | 17 +++ doc/compilation.rst | 182 +++++++++++++++++++++++++++++ doc/conf.py | 10 +- doc/gen-readme.sh | 4 + doc/index.rst | 194 +------------------------------ doc/instructions.rst | 5 +- doc/low-level.rst | 18 ++- doc/machine-learning.rst | 4 +- doc/requirements.txt | 3 +- doc/troubleshooting.rst | 2 +- 171 files changed, 2181 insertions(+), 1025 deletions(-) delete mode 100644 HOSTS.example create mode 100644 Networking/PlayerBuffer.h create mode 100644 Networking/PlayerCtSocket.h create mode 100644 Tools/Coordinator.cpp create mode 100644 Tools/Coordinator.h rename SimpleOT => deps/SimpleOT (100%) create mode 160000 deps/SimplestOT_C create mode 160000 deps/libOTe rename mpir => deps/mpir (100%) rename simde => deps/simde (100%) create mode 100644 doc/compilation.rst create mode 100755 doc/gen-readme.sh diff --git a/.gitignore b/.gitignore index 9a4dd72e..0d770b1e 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,7 @@ _build/ # environment .env + +# temp doc files +doc/readme.md +doc/xml diff --git a/.gitmodules b/.gitmodules index 32dca28b..7dea81d3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,18 @@ [submodule "SimpleOT"] - path = SimpleOT + path = deps/SimpleOT url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] - path = mpir + path = deps/mpir url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion [submodule "simde"] - path = simde + path = deps/simde url = https://github.com/simd-everywhere/simde +[submodule "deps/libOTe"] + path = deps/libOTe + url = https://github.com/mkskeller/softspoken-implementation +[submodule "deps/SimplestOT_C"] + path = deps/SimplestOT_C + url = https://github.com/mkskeller/SimplestOT_C diff --git a/BMR/Party.cpp b/BMR/Party.cpp index beddd64c..0fe11a0f 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) : } cout << "Compiler: " << prev << endl; P = new PlainPlayer(N, 0); + Share::MAC_Check::setup(*P); if (argc > 4) threshold = atoi(argv[4]); cout << "Threshold for multi-threaded evaluation: " << threshold << endl; @@ -280,6 +281,7 @@ FakeProgramParty::~FakeProgramParty() cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() << " GB" << endl; #endif + Share::MAC_Check::teardown(); } void FakeProgramParty::_compute_prfs_outputs(Key* keys) diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h index 9fa2dc52..115d0bca 100644 --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -48,8 +48,6 @@ public: static void inputbvec(GC::Processor>& processor, ProcessorBase& input_processor, const vector& args); - RealGarbleWire(const Register& reg) : PRFRegister(reg) {} - void garble(PRFOutputs& prf_output, const RealGarbleWire& left, const RealGarbleWire& right); diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 760a20b8..c9e31fc6 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -110,7 +110,7 @@ void RealGarbleWire::inputbvec( { GarbleInputter inputter; processor.inputbvec(inputter, input_processor, args, - inputter.party.P->my_num()); + *inputter.party.P); } template diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 64efc550..70208ec5 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -97,8 +97,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : if (online_opts.live_prep) { mac_key.randomize(prng); - if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({*P, true}); prep = new typename T::LivePrep(0, usage); } else @@ -107,6 +105,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : prep = new Sub_Data_Files(N, prep_dir, usage); } + T::MAC_Check::setup(*P); MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); @@ -219,6 +218,7 @@ RealProgramParty::~RealProgramParty() delete garble_inputter; delete garble_protocol; cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + T::MAC_Check::teardown(); } template diff --git a/BMR/Register.h b/BMR/Register.h index f348f7b7..6a15a720 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -152,7 +152,7 @@ public: * for pipelining matters. */ - Register(int n_parties); + Register(); void init(int n_parties); void init(int rfd, int n_parties); @@ -278,10 +278,6 @@ public: static int threshold(int) { throw not_implemented(); } - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } @@ -306,8 +302,6 @@ public: void other_input(Input&, int) {} char get_output() { return 0; } - - ProgramRegister(const Register& reg) : Register(reg) {} }; class PRFRegister : public ProgramRegister @@ -319,8 +313,6 @@ public: static void load(vector >& accesses, const NoMemory& source); - PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -396,8 +388,6 @@ public: static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& proc); - EvalRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -427,8 +417,6 @@ public: static void load(vector >& accesses, const NoMemory& source); - GarbleRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char value = -1); @@ -452,8 +440,6 @@ public: static void load(vector >& accesses, const NoMemory& source); - RandomRegister(const Register& reg) : ProgramRegister(reg) {} - void randomize(); void op(const Register& left, const Register& right, Function func); @@ -469,12 +455,6 @@ public: }; -inline Register::Register(int n_parties) : - garbled_entry(n_parties), external(NO_SIGNAL), - mask(NO_SIGNAL), keys(n_parties) -{ -} - inline void KeyVector::operator=(const KeyVector& other) { resize(other.size()); diff --git a/BMR/Register.hpp b/BMR/Register.hpp index bd214a85..61790694 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, const vector& args) { NoOpInputter inputter; - int my_num = -1; - try - { - my_num = ProgramParty::s().P->my_num(); - } - catch (exception&) - { - } - processor.inputbvec(inputter, input_processor, args, my_num); + processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P); } template @@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, { EvalInputter inputter; processor.inputbvec(inputter, input_processor, args, - ProgramParty::s().P->my_num()); + *ProgramParty::s().P); } template diff --git a/BMR/Register_inline.h b/BMR/Register_inline.h index 6a275da6..7694c464 100644 --- a/BMR/Register_inline.h +++ b/BMR/Register_inline.h @@ -9,10 +9,10 @@ #include "CommonParty.h" #include "Party.h" - -inline Register ProgramRegister::new_reg() +inline Register::Register() : + garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL), + mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties()) { - return Register(CommonParty::s().get_n_parties()); } #endif /* BMR_REGISTER_INLINE_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index ac643580..e8e01534 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.3 (Aug 25, 2022) + +- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate +- Fix security bug in MAC check when using multithreading +- Fix security bug to prevent selective failure attack by checking earlier +- Fix security bug in Mama: insufficient sacrifice. +- Inverse permutation (@Quitlox) +- Easier direct compilation (@eriktaubeneck) +- Generally allow element-vector operations +- Increase maximum register size to 2^54 +- Client example in Python +- Uniform base OTs across platforms +- Multithreaded base OT computation +- Faster random bit generation in two-player Semi(2k) + ## 0.3.2 (May 27, 2022) - Secure shuffling @@ -7,7 +22,7 @@ The changelog explains changes pulled through from the private development repos - Documented BGV encryption interface - Optimized matrix multiplication in dealer protocol - Fixed security bug in homomorphic encryption parameter generation -- Fixed Security bug in Temi matrix multiplication +- Fixed security bug in Temi matrix multiplication ## 0.3.1 (Apr 19, 2022) diff --git a/CONFIG b/CONFIG index cef15e0b..fb9db200 100644 --- a/CONFIG +++ b/CONFIG @@ -31,24 +31,21 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx ARCH = -march=native MACHINE := $(shell uname -m) +ARM := $(shell uname -m | grep x86; echo $$?) OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) -# set this to 0 to avoid using AVX for OT ifeq ($(OS), Linux) -CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?) -ifeq ($(CHECK_AVX), 0) AVX_OT = 1 else AVX_OT = 0 endif else -AVX_OT = 1 -endif -else ARCH = AVX_OT = 0 endif +USE_KOS = 0 + # allow to set compiler in CONFIG.mine CXX = g++ @@ -87,7 +84,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) @@ -98,3 +95,10 @@ ifeq ($(USE_NTL),1) CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy endif endif + +ifeq ($(USE_KOS),1) +CFLAGS += -DUSE_KOS +else +CFLAGS += -std=c++17 +LDLIBS += -llibOTe -lcryptoTools +endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index e53b7187..2b5ec46a 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -342,7 +342,8 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMCB'] arg_format = ['cb','long'] -class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit memory cell with run-time address to secret bit register. @@ -351,8 +352,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] + direct = staticmethod(ldmsb) -class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit register to secret bit memory cell with run-time address. @@ -361,8 +364,10 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMSBI'] arg_format = ['sb','ci'] + direct = staticmethod(stmsb) -class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit memory cell with run-time address to clear bit register. @@ -371,8 +376,10 @@ class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMCBI'] arg_format = ['cbw','ci'] + direct = staticmethod(ldmcb) -class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit register to clear bit memory cell with run-time address. @@ -381,6 +388,7 @@ class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMCBI'] arg_format = ['cb','ci'] + direct = staticmethod(stmcb) class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5530432b..396769e0 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -198,6 +198,8 @@ class bits(Tape.Register, _structure, _bit): return 0 elif self.is_long_one(other): return self + elif isinstance(other, _vec): + return other & other.from_vec([self]) else: return self._and(other) @read_mem_value @@ -241,6 +243,13 @@ class bits(Tape.Register, _structure, _bit): return self * condition else: return self * cbit.conv(condition) + def expand(self, length): + if self.n in (length, None): + return self + elif self.n == 1: + return self.get_type(length).bit_compose([self] * length) + else: + raise CompilerError('cannot expand from %s to %s' % (self.n, length)) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -295,8 +304,15 @@ class cbits(bits): return op(self, cbits(other)) __add__ = lambda self, other: \ self.clear_op(other, inst.addcb, inst.addcbi, operator.add) - __sub__ = lambda self, other: \ - self.clear_op(-other, inst.addcb, inst.addcbi, operator.add) + def __sub__(self, other): + try: + return self + -other + except: + return type(self)(regint(self) - regint(other)) + def __rsub__(self, other): + return type(self)(other - regint(self)) + def __neg__(self): + return type(self)(-regint(self)) def _xor(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented @@ -589,7 +605,15 @@ class sbits(bits): rows = list(rows) if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() - n_columns = rows[0].n + for row in rows: + try: + n_columns = row.n + break + except: + pass + for i in range(len(rows)): + if util.is_zero(rows[i]): + rows[i] = cls.get_type(n_columns)(0) for row in rows: assert(row.n == n_columns) if n_columns == 1 and len(rows) <= cls.unit: @@ -613,7 +637,7 @@ class sbits(bits): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec): +class sbitvec(_vec, _bit): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -656,6 +680,7 @@ class sbitvec(_vec): [1, 0, 1] """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) + is_clear = False @classmethod def get_type(cls, n): """ Create type for fixed-length vector of registers of secret bits. @@ -691,10 +716,11 @@ class sbitvec(_vec): res.v = _complement_two_extend(list(vector), n)[:n] return res def __init__(self, other=None, size=None): - assert size in (None, 1) if other is not None: if util.is_constant(other): - self.v = [sbit((other >> i) & 1) for i in range(n)] + t = sbits.get_type(size or 1) + self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1)) + for i in range(n)] elif isinstance(other, _vec): self.v = self.bit_extend(other.v, n) elif isinstance(other, (list, tuple)): @@ -702,6 +728,7 @@ class sbitvec(_vec): else: self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n + assert size is None or size == self.v[0].n @classmethod def load_mem(cls, address, size=None): if size not in (None, 1): @@ -733,8 +760,9 @@ class sbitvec(_vec): def reveal(self): return util.untuplify([x.reveal() for x in self.elements()]) @classmethod - def two_power(cls, nn): - return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def two_power(cls, nn, size=1): + return cls.from_vec( + [0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1)) def coerce(self, other): if util.is_constant(other): return self.from_vec(util.bit_decompose(other, n)) @@ -818,16 +846,14 @@ class sbitvec(_vec): return other def __xor__(self, other): other = self.coerce(other) - return self.from_vec(x ^ y for x, y in zip(self.v, other)) + return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): - return self.from_vec(x & y for x, y in zip(self.v, other.v)) + return self.from_vec(x & y for x, y in zip(*self.expand(other))) + def __invert__(self): + return self.from_vec(~x for x in self.v) def if_else(self, x, y): assert(len(self.v) == 1) - try: - return self.from_vec(util.if_else(self.v[0], a, b) \ - for a, b in zip(x, y)) - except: - return util.if_else(self.v[0], x, y) + return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.v) def __len__(self): @@ -890,6 +916,24 @@ class sbitvec(_vec): elements = red.elements() elements += odd return self.from_vec(sbitvec(elements).v) + @classmethod + def comp_result(cls, x): + return cls.get_type(1).from_vec([x]) + def expand(self, other, expand=True): + m = 1 + for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): + try: + m = max(m, x.n) + except: + pass + res = [] + for y in self, other: + if isinstance(y, int): + res.append([x * sbits.get_type(m)().long_one() + for x in util.bit_decompose(y, len(self.v))]) + else: + res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) + return res class bit(object): n = 1 @@ -1139,7 +1183,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase): :param k: bit length of input """ return _sbitintbase.pow2(self, k) -class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): +class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ Vector of signed integers for parallel binary computation:: @@ -1176,7 +1220,8 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): return self other = self.coerce(other) assert(len(self.v) == len(other.v)) - v = sbitint.bit_adder(self.v, other.v) + a, b = self.expand(other) + v = sbitint.bit_adder(a, b) return self.from_vec(v) __radd__ = __add__ def __mul__(self, other): @@ -1184,7 +1229,7 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented - other_bits = util.bit_decompose(other) + _, other_bits = self.expand(other, False) m = float('inf') for x in itertools.chain(self.v, other_bits): try: @@ -1228,6 +1273,8 @@ class cbitfix(object): store_in_mem = lambda self, *args: self.v.store_in_mem(*args) @classmethod def _new(cls, value): + if isinstance(value, list): + return [cls._new(x) for x in value] res = cls() if cls.k < value.unit: bits = value.bit_decompose(cls.k) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 84bdd22b..1a139ef6 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -87,15 +87,14 @@ def LtzRing(a, k): carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] return sint.conv(msb) - return - elif program.options.ring: + else: from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) tmp = a - r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) @@ -190,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed): r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) - shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() + shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() BitLTL(u, masked, r_bits[:n_bits], 0) @@ -231,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed): shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift u = sint() BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime @@ -261,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: BitLTC1(u, c_prime, r, kappa) @@ -510,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a): movs(w[0], r[0]) movs(a_vec[0], a[0]) vmuls(k, t[0], w, a_vec) - vasm_open(k, m, t[0]) + vasm_open(k, True, m, t[0]) PreMulC_end(p, a, c, m, z) def PreMulC_with_inverses(p, a): @@ -538,7 +537,7 @@ def PreMulC_with_inverses(p, a): w[1][0] = r[0][0] for i in range(k): muls(t[0][i], w[1][i], a[i]) - asm_open(m[i], t[0][i]) + asm_open(True, m[i], t[0][i]) PreMulC_end(p, a, c, m, z) def PreMulC_without_inverses(p, a): @@ -563,7 +562,7 @@ def PreMulC_without_inverses(p, a): #adds(tt[0][i], t[0][i], a[i]) #subs(tt[1][i], tt[0][i], a[i]) #startopen(tt[1][i]) - asm_open(u[i], t[0][i]) + asm_open(True, u[i], t[0][i]) for i in range(k-1): muls(v[i], r[i+1], s[i]) w[0] = r[0] @@ -579,7 +578,7 @@ def PreMulC_without_inverses(p, a): mulm(z[i], s[i], u_inv[i]) for i in range(k): muls(t[1][i], w[i], a[i]) - asm_open(m[i], t[1][i]) + asm_open(True, m[i], t[1][i]) PreMulC_end(p, a, c, m, z) def PreMulC_end(p, a, c, m, z): @@ -646,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index eb800ba4..4a4706ff 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -181,7 +181,8 @@ class Compiler: action="store_true", dest="invperm", help="speedup inverse permutation (only use in two-party, " - "semi-honest environment)") + "semi-honest environment)" + ) parser.add_option( "-C", "--CISC", @@ -244,11 +245,9 @@ class Compiler: self.VARS[op.__name__] = op # add open and input separately due to name conflict - self.VARS["open"] = instructions.asm_open self.VARS["vopen"] = instructions.vasm_open self.VARS["gopen"] = instructions.gasm_open self.VARS["vgopen"] = instructions.vgasm_open - self.VARS["input"] = instructions.asm_input self.VARS["ginput"] = instructions.gasm_input self.VARS["comparison"] = comparison @@ -268,7 +267,6 @@ class Compiler: "sgf2nuint", "sgf2nuint32", "sgf2nfloat", - "sfloat", "cfloat", "squant", ]: @@ -276,6 +274,9 @@ class Compiler: def prep_compile(self, name=None): self.parse_args() + if len(self.args) < 1 and name is None: + self.parser.print_help() + exit(1) self.build_program(name=name) self.build_vars() @@ -372,7 +373,7 @@ class Compiler: ) self.prep_compile(self.compile_name) print( - f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" + "Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__) ) self.compile_function() self.finalize_compile() diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index d3d3f8c5..94a47f1b 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,7 +28,7 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_edabit: + if program.Program.prog.use_edabit(): r_prime, r = types.sint.get_edabit(k) elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) @@ -36,7 +36,7 @@ def maskRing(a, k): else: r = [types.sint.get_random_bit() for i in range(k)] r_prime = types.sint.bit_compose(r) - c = ((a + r_prime) << shift).reveal() >> shift + c = ((a + r_prime) << shift).reveal(False) >> shift return c, r def maskField(a, k, kappa): @@ -47,7 +47,7 @@ def maskField(a, k, kappa): comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) # always signed due to usage in equality testing a += two_power(k) - asm_open(c, a + two_power(k) * r_dprime + r_prime) + asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -233,7 +233,7 @@ def Inv(a): ldi(one, 1) inverse(t[0], t[1]) s = t[0]*a - asm_open(c[0], s) + asm_open(True, c[0], s) # avoid division by zero for benchmarking divc(c[1], one, c[0]) #divc(c[1], c[0], one) @@ -281,7 +281,7 @@ def BitDecRingRaw(a, k, m): else: r_bits = [types.sint.get_random_bit() for i in range(m)] r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() + shifted = ((a - r) << n_shift).reveal(False) masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits @@ -299,7 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r = [types.sint() for i in range(m)] comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) pow2 = two_power(k + kappa) - asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) instructions_base.reset_global_vector_size() return res @@ -341,10 +341,10 @@ def B2U_from_Pow2(pow2a, l, kappa): if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l assert n_shift > 0 - c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + + asm_open(True, c, pow2a + two_power(l) * t + sum(two_power(i) * r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(r_bits[0].bit_decompose_clear(c, l)) @@ -386,11 +386,11 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): r_dprime += t1 - t2 if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l - c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(rk, kappa) r_dprime += two_power(l) * rk - asm_open(c, a + r_dprime + r_prime) + asm_open(True, c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) @@ -416,7 +416,7 @@ def TruncInRing(to_shift, l, pow2m): rev *= pow2m r_bits = [types.sint.get_random_bit() for i in range(l)] r = types.sint.bit_compose(r_bits) - shifted = (rev - (r << n_shift)).reveal() + shifted = (rev - (r << n_shift)).reveal(False) masked = shifted >> n_shift bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) @@ -457,7 +457,7 @@ def Int2FL(a, gamma, l, kappa=None): v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t - p = (p + gamma - 1 - l) * (1 -z) + p = (p + gamma - 1 - l) * z.bit_not() return v, p, z, s def FLRound(x, mode): @@ -530,7 +530,7 @@ def TruncPrRing(a, k, m, signed=True): msb = r_bits[-1] n_shift = n_ring - (k + 1) tmp = a + r - masked = (tmp << n_shift).reveal() + masked = (tmp << n_shift).reveal(False) shifted = (masked << 1 >> (n_shift + m + 1)) overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ @@ -551,7 +551,7 @@ def TruncPrField(a, k, m, kappa=None): k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime - c = (b + r).reveal() + c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime d = (a - a_prime) / two_to_m @@ -667,14 +667,14 @@ def BitDecFull(a, n_bits=None, maybe_mixed=False): def _(): for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) - c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False)) done[j].link(c) return (sum(done) != a.size) for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) b = sint.bit_compose(bbits) - c = (a-b).reveal() + c = (a-b).reveal(False) cmodp = c t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 058b6ff4..91809ba4 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -387,6 +387,14 @@ class use(base.Instruction): code = base.opcodes['USE'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), + find_in_dict(data_types, args[1].i)): + args[2].i} + class use_inp(base.Instruction): """ Input usage. Necessary to avoid reusage while using preprocessing from files. @@ -398,6 +406,13 @@ class use_inp(base.Instruction): code = base.opcodes['USE_INP'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), 'input', args[1].i): + args[2].i} + class use_edabit(base.Instruction): """ edaBit usage. Necessary to avoid reusage while using preprocessing from files. Also used to multithreading for expensive @@ -410,6 +425,10 @@ class use_edabit(base.Instruction): code = base.opcodes['USE_EDABIT'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i} + class use_matmul(base.Instruction): """ Matrix multiplication usage. Used for multithreading of preprocessing. @@ -471,6 +490,11 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] + @classmethod + def get_usage(cls, args): + return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp', + args[0].str): args[1].i} + class nplayers(base.Instruction): """ Store number of players in clear integer register. @@ -783,30 +807,6 @@ class gbitcom(base.Instruction): return True -### -### Special GF(2) arithmetic instructions -### - -@base.vectorize -class gmulbitc(base.MulBase): - r""" Clear GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITC'] - arg_format = ['cgw','cg','cg'] - - def is_gf2n(self): - return True - -@base.vectorize -class gmulbitm(base.MulBase): - r""" Secret GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITM'] - arg_format = ['sgw','sg','cg'] - - def is_gf2n(self): - return True - ### ### Arithmetic with immediate values ### @@ -1707,6 +1707,7 @@ class writesockets(base.IOInstruction): from registers into a socket for a specified client id. If the protocol uses MACs, the client should be different for every party. + :param: number of arguments to follow :param: client id (regint) :param: message type (must be 0) :param: vector size (int) @@ -2162,14 +2163,19 @@ class gconvgf2n(base.Instruction): class asm_open(base.VarArgsInstruction): """ Reveal secret registers (vectors) to clear registers (vectors). - :param: number of argument to follow (multiple of two) + :param: number of argument to follow (odd number) + :param: check after opening (0/1) :param: destination (cint) :param: source (sint) :param: (repeat the last two)... """ __slots__ = [] code = base.opcodes['OPEN'] - arg_format = tools.cycle(['cw','s']) + arg_format = tools.chain(['int'], tools.cycle(['cw','s'])) + + def merge(self, other): + self.args[0] |= other.args[0] + self.args += other.args[1:] @base.gf2n @base.vectorize @@ -2415,12 +2421,17 @@ class shuffle_base(base.DataInstruction): def logn(n): return int(math.ceil(math.log(n, 2))) + @classmethod + def n_swaps(cls, n): + logn = cls.logn(n) + return logn * 2 ** logn - 2 ** logn + 1 + def add_gen_usage(self, req_node, n): # hack for unknown usage req_node.increment(('bit', 'inverse'), float('inf')) # minimal usage with two relevant parties logn = self.logn(n) - n_switches = logn * 2 ** logn + n_switches = self.n_swaps(n) for i in range(self.n_relevant_parties): req_node.increment((self.field_type, 'input', i), n_switches) # multiplications for bit check @@ -2430,7 +2441,7 @@ class shuffle_base(base.DataInstruction): def add_apply_usage(self, req_node, n, record_size): req_node.increment(('bit', 'inverse'), float('inf')) logn = self.logn(n) - n_switches = logn * 2 ** logn * self.n_relevant_parties + n_switches = self.n_swaps(n) * self.n_relevant_parties if n != 2 ** logn: record_size += 1 req_node.increment((self.field_type, 'triple'), @@ -2548,7 +2559,7 @@ class sqrs(base.CISC): c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) - asm_open(c[0], s[2]) + asm_open(False, c[0], s[2]) mulc(c[1], c[0], c[0]) mulm(s[3], self.args[1], c[0]) adds(s[4], s[3], s[3]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index f7aa48f9..fb60d908 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -542,7 +542,7 @@ def cisc(function): def get_bytes(self): assert len(self.kwargs) < 2 - res = int_to_bytes(opcodes['CISC']) + res = LongArgFormat.encode(opcodes['CISC']) res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1) name = self.function.__name__ String.check(name) @@ -720,7 +720,7 @@ class IntArgFormat(ArgFormat): class LongArgFormat(IntArgFormat): @classmethod def encode(cls, arg): - return struct.pack('>Q', arg) + return list(struct.pack('>Q', arg)) def __init__(self, f): self.i = struct.unpack('>Q', f.read(8))[0] @@ -741,6 +741,8 @@ class ImmediateGF2NAF(IntArgFormat): class PlayerNoAF(IntArgFormat): @classmethod def check(cls, arg): + if not util.is_constant(arg): + raise CompilerError('Player number must be known at compile time') super(PlayerNoAF, cls).check(arg) if arg > 256: raise ArgumentError(arg, 'Player number > 256') @@ -823,7 +825,7 @@ class Instruction(object): return (prefix << self.code_length) + self.code def get_encoding(self): - enc = int_to_bytes(self.get_code()) + enc = LongArgFormat.encode(self.get_code()) # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) @@ -958,7 +960,7 @@ class ParsedInstruction: except AttributeError: pass read = lambda: struct.unpack('>I', f.read(4))[0] - full_code = read() + full_code = struct.unpack('>Q', f.read(8))[0] code = full_code % (1 << Instruction.code_length) self.size = full_code >> Instruction.code_length self.type = cls.reverse_opcodes[code] diff --git a/Compiler/library.py b/Compiler/library.py index 799f85d2..1da50e9c 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -243,6 +243,10 @@ def store_in_mem(value, address): try: value.store_in_mem(address) except AttributeError: + if isinstance(value, (list, tuple)): + for i, x in enumerate(value): + store_in_mem(x, address + i) + return # legacy if value.is_clear: if isinstance(address, cint): @@ -261,11 +265,13 @@ def reveal(secret): try: return secret.reveal() except AttributeError: + if secret.is_clear: + return secret if secret.is_gf2n: res = cgf2n() else: res = cint() - instructions.asm_open(res, secret) + instructions.asm_open(True, res, secret) return res @vectorize @@ -883,10 +889,10 @@ def range_loop(loop_body, start, stop=None, step=None): def for_range(start, stop=None, step=None): """ Decorator to execute loop bodies consecutively. Arguments work as - in Python :py:func:`range`, but they can by any public + in Python :py:func:`range`, but they can be any public integer. Information has to be passed out via container types such - as :py:class:`~Compiler.types.Array` or declaring registers as - :py:obj:`global`. Note that changing Python data structures such + as :py:class:`~Compiler.types.Array` or using :py:func:`update`. + Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. @@ -901,13 +907,11 @@ def for_range(start, stop=None, step=None): @for_range(n) def _(i): a[i] = i - global x - x += 1 + x.update(x + 1) Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop even when using - :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` - instead. + :py:class:`~Compiler.types.Array` in a loop. Use + :py:func:`~Compiler.types.Array.assign` instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -1518,6 +1522,11 @@ def if_then(condition): state = State() if callable(condition): condition = condition() + try: + if not condition.is_clear: + raise CompilerError('cannot branch on secret values') + except AttributeError: + pass state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ @@ -1889,7 +1898,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): theta = int(ceil(log(k/3.5) / log(2))) base.set_global_vector_size(b.size) - alpha = b.get_type(2 * k).two_power(2*f) + alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() diff --git a/Compiler/ml.py b/Compiler/ml.py index 02f0f04e..173c2eac 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -148,7 +148,7 @@ def argmax(x): """ Compute index of maximum element. :param x: iterable - :returns: sint + :returns: sint or 0 if :py:obj:`x` has length 1 """ def op(a, b): comp = (a[1] > b[1]) @@ -164,7 +164,7 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) + m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() @@ -2384,6 +2384,11 @@ class Optimizer: for layer in self.layers: layer.output_weights() + def summary(self): + sizes = [var.total_size() for var in self.thetas] + print(sizes) + print('Trainable params:', sum(sizes)) + class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -2653,9 +2658,7 @@ class keras: return list(self.opt.thetas) def summary(self): - sizes = [var.total_size() for var in self.trainable_variables] - print(sizes) - print('Trainable params:', sum(sizes)) + self.opt.summary() def build(self, input_shape, batch_size=128): data_input_shape = input_shape diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 47253dc4..abdaf123 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -295,7 +295,6 @@ def exp2_fx(a, zero_output=False, as19=False): intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k if types.program.use_split(): - assert not zero_output from Compiler.GC.types import sbitvec if types.program.use_split() == 3: x = a.v.split_to_two_summands(a.k) @@ -327,6 +326,7 @@ def exp2_fx(a, zero_output=False, as19=False): s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] + bits_to_check = bits[n_bits:-1] else: if types.program.use_edabit(): l = sint.get_edabit(a.f, True) @@ -338,7 +338,7 @@ def exp2_fx(a, zero_output=False, as19=False): r_bits = [sint.get_random_bit() for i in range(a.k)] r = sint.bit_compose(r_bits) lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() + shifted = ((a.v - r) << n_shift).reveal(False) masked_bits = (shifted >> n_shift).bit_decompose(a.k) lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) diff --git a/Compiler/program.py b/Compiler/program.py index d7b57db9..f92ab497 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -545,7 +545,7 @@ class Program(object): """ if change is None: if not self._invperm: - self.relevant_opts.add('invperm') + self.relevant_opts.add("invperm") return self._invperm else: self._invperm = change @@ -1276,7 +1276,7 @@ class Tape: "can_eliminate", "duplicates", ] - maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 + maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): """Creates a new register. @@ -1382,6 +1382,20 @@ class Tape: for dup in self.duplicates: dup.duplicates = self.duplicates + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + if self.program != other.program: + raise CompilerError( + 'cannot update register with one from another thread') + self.link(other) + @property def is_gf2n(self): return ( diff --git a/Compiler/types.py b/Compiler/types.py index d63295c8..6a150bee 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,8 +127,14 @@ def vectorize(operation): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise VectorMismatch('Different vector sizes of operands: %d/%d' - % (self.size, args[0].size)) + if min(args[0].size, self.size) == 1: + size = max(args[0].size, self.size) + self = self.expand_to_vector(size) + args = list(args) + args[0] = args[0].expand_to_vector(size) + else: + raise VectorMismatch('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) try: res = operation(self, *args, **kwargs) @@ -249,8 +255,11 @@ class _number(Tape._no_truth): try: return self.mul(other) except VectorMismatch: - # try reverse multiplication - return NotImplemented + if type(self) != type(other) and 1 in (self.size, other.size): + # try reverse multiplication + return NotImplemented + else: + raise __radd__ = __add__ __rmul__ = __mul__ @@ -1658,6 +1667,8 @@ class regint(_register, _int): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') intoutput(player, self) class localint(Tape._no_truth): @@ -2081,12 +2092,15 @@ class _secret(_register, _secret_structure): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - @vectorize def __truediv__(self, other): """ Secret field division. :param other: any compatible type """ - return self * (self.clear_type(1) / other) + try: + one = self.clear_type(1, size=other.size) + except AttributeError: + one = self.clear_type(1) + return self * (one / other) @vectorize def __rtruediv__(self, other): @@ -2113,12 +2127,12 @@ class _secret(_register, _secret_structure): @set_instruction_type @vectorize - def reveal(self): + def reveal(self, check=True): """ Reveal secret value publicly. :rtype: relevant clear type """ res = self.clear_type() - asm_open(res, self) + asm_open(check, res, self) return res @set_instruction_type @@ -2166,9 +2180,7 @@ class sint(_secret, _int): signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and exponentation (``**``). Modulo only works if the right-hand - operator is a compile-time power of two, and exponentiation only - works if the base is two or if the exponent is a compile-time - integer. + operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -2672,7 +2684,7 @@ class sint(_secret, _int): return comparison.TruncZeros(self, bit_length, n_zeros, signed) @staticmethod - def two_power(n): + def two_power(n, size=None): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): @@ -2690,7 +2702,6 @@ class sint(_secret, _int): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -2698,13 +2709,14 @@ class sint(_secret, _int): :returns: :py:class:`personal` """ if not util.is_constant(player): - secret_mask = sint() - player_mask = cint() - inputmaskreg(secret_mask, player_mask, regint.conv(player)) + secret_mask = sint(size=self.size) + player_mask = cint(size=self.size) + inputmaskreg(secret_mask, player_mask, + regint.conv(player).expand_to_vector(self.size)) return personal(player, - (self + secret_mask).reveal() - player_mask) + (self + secret_mask).reveal(False) - player_mask) else: - res = personal(player, self.clear_type()) + res = personal(player, self.clear_type(size=self.size)) privateoutput(self.size, player, res._v, self) return res @@ -2856,6 +2868,10 @@ class sintbit(sint): else: return super(sintbit, self).__rsub__(other) + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A @@ -2873,6 +2889,7 @@ class sgf2n(_secret, _gf2n): instruction_type = 'gf2n' clear_type = cgf2n reg_type = 'sg' + long_one = staticmethod(lambda: 1) @classmethod def get_type(cls, length): @@ -3022,6 +3039,7 @@ class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False + comp_result = staticmethod(lambda x: x) @staticmethod def half_adder(a, b): @@ -3241,12 +3259,16 @@ class _bitint(Tape._no_truth): del carries[-1] return sums, carries + def expand(self, other): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return a, b + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) from util import bit_not, bit_and, bit_xor + a, b = self.expand(other) n = 1 for x in (a + b): try: @@ -3293,8 +3315,7 @@ class _bitint(Tape._no_truth): a[-1], b[-1] = b[-1], a[-1] def comparison(self, other, const_rounds=False, index=None): - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) + a, b = self.expand(other) self.prep_comparison(a, b) if const_rounds: return self.get_highest_different_bits(a, b, index) @@ -3304,30 +3325,33 @@ class _bitint(Tape._no_truth): def __lt__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 0) + res = util.if_else(not_equal, x, 0) else: - return self.comparison(other, True, 1) + res = self.comparison(other, True, 1) + return self.comp_result(res) def __le__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 1) + res = util.if_else(not_equal, x, x.long_one()) else: - return 1 - self.comparison(other, True, 0) + res = self.comparison(other, True, 0).bit_not() + return self.comp_result(res) def __ge__(self, other): - return 1 - (self < other) + return (self < other).bit_not() def __gt__(self, other): - return 1 - (self <= other) + return (self <= other).bit_not() def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] - return floatingpoint.KMul(diff_bits) + diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] + return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), + diff_bits)) def __ne__(self, other): - return 1 - (self == other) + return (self == other).bit_not() equal = __eq__ @@ -3881,7 +3905,6 @@ class cfix(_number, _structure): def output_if(self, cond): cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) - @vectorize def binary_output(self, player=None): """ Write double-precision floating-point number to ``Player-Data/Binary-Output-P-``. @@ -3890,7 +3913,11 @@ class cfix(_number, _structure): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') + set_global_vector_size(self.size) floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + reset_global_vector_size() class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ @@ -4124,6 +4151,7 @@ class _single(_number, _secret_structure): class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] + is_clear = False def set_precision(cls, f, k = None): cls.f = f @@ -4349,6 +4377,18 @@ class _fix(_single): """ Bit decomposition. """ return self.v.bit_decompose(n_bits or self.k) + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + assert self.f == other.f + self.v.update(other.v) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4737,6 +4777,8 @@ class sfloat(_number, _secret_structure): returning :py:class:`sint`. The other operand can be any of sint/cfix/regint/cint/int/float. + This data type only works with arithmetic computation. + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) """ __slots__ = ['v', 'p', 'z', 's', 'size'] @@ -4835,6 +4877,9 @@ class sfloat(_number, _secret_structure): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): + if program.options.binary: + raise CompilerError( + 'floating-point operations not supported with binary circuits') self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -5227,7 +5272,13 @@ class Array(_vectorizable): def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to - that. """ + that. + + :param l: Python iterable or register vector + :returns: :py:class:`Array` of appropriate type containing the contents + of :py:obj:`l` + + """ if isinstance(l, cls): return l if isinstance(l, _number): @@ -6099,12 +6150,12 @@ class SubMultiArray(_vectorizable): try: res_matrix[i] = self.value_type.row_matrix_mul( self[i], other, res_params) - except AttributeError: + except (AttributeError, CompilerError): # fallback for binary circuits - @library.for_range(other.sizes[1]) + @library.for_range_opt(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[1]) + @library.for_range_opt(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix @@ -6223,13 +6274,7 @@ class SubMultiArray(_vectorizable): res[i] = self.direct_mul_trans(other, indices=indices) def direct_mul_to_matrix(self, other): - """ Matrix multiplication in the virtual machine. - - :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :returns: :py:obj:`Matrix` - - """ + # Obsolete. Use dot(). res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res diff --git a/Compiler/util.py b/Compiler/util.py index 9d84df22..c1bedc27 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -238,6 +238,9 @@ def mem_size(x): except AttributeError: return 1 +def find_in_dict(d, v): + return list(d.keys())[list(d.values()).index(v)] + class set_by_id(object): def __init__(self, init=[]): self.content = {} diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 4657b5d8..27ea7f75 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -22,7 +22,7 @@ private: EC_POINT* point; public: - typedef void next; + typedef P256Element next; typedef void Square; static const true_type invertible; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index f0e3257c..5bef730d 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -45,12 +45,13 @@ int main(int argc, const char** argv) string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::MAC_Check::setup(P); + Share::MAC_Check::setup(P); + DataPositions usage; Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); - BaseMachine machine; - machine.ot_setups.push_back({P, false}); SubProcessor proc(_, MCp, prep, P); pShare sk, __; @@ -60,4 +61,7 @@ int main(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts); + + pShare::MAC_Check::teardown(); + Share::MAC_Check::teardown(); } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 569aa791..ebf0aea9 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -92,9 +92,6 @@ void run(int argc, const char** argv) P256Element::init(); P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); - BaseMachine machine; - machine.ot_setups.push_back({P, true}); - P256Element::Scalar keyp; SeededPRNG G; keyp.randomize(G); @@ -102,6 +99,9 @@ void run(int argc, const char** argv) typedef T pShare; DataPositions usage; + pShare::MAC_Check::setup(P); + T::MAC_Check::setup(P); + OnlineOptions::singleton.batch_size = 1; typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); @@ -137,4 +137,7 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + + pShare::MAC_Check::teardown(); + T::MAC_Check::teardown(); } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 00e05131..62cbd528 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -130,7 +130,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk) assert(p != 0); for (auto& x : r) { - G.get(x, params->p0().numBits() - p.numBits() - 1); + G.get(x, params->p0().numBits() - p.numBits() - 1); x *= p; } tmp.from(r, 0); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 794e7431..f3973026 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -368,7 +368,8 @@ ZZX Cyclotomic(int N) int phi_N(int N) { if (((N - 1) & N) != 0) - throw runtime_error("compile with NTL support"); + throw runtime_error( + "compile with NTL support (USE_NTL=1 in CONFIG.mine)"); else if (N == 1) return 1; else @@ -418,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly) for (int i=0; i& elem) const */ } -void PPData::from_eval(vector& elem) const +void PPData::from_eval(vector&) const { // avoid warning - elem.empty(); throw not_implemented(); /* diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index b19dd62c..dd3f8968 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -17,15 +17,13 @@ PairwiseMachine::PairwiseMachine(Player& P) : { } -PairwiseMachine::PairwiseMachine(int argc, const char** argv) : - MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")), - other_pks(N.num_players(), {setup_p.params, 0}), - pk(other_pks[N.my_num()]), sk(pk) +RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) : + MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise")) { init(); } -void PairwiseMachine::init() +void RealPairwiseMachine::init() { if (use_gf2n) { @@ -63,7 +61,7 @@ PairwiseSetup& PairwiseMachine::setup() } template -void PairwiseMachine::setup_keys() +void RealPairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); @@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys() if (i != N.my_num()) other_pks[i].unpack(os[i]); set_mac_key(s.alphai); + Share::MAC_Check::setup(P); } template -void PairwiseMachine::set_mac_key(T alphai) +void RealPairwiseMachine::set_mac_key(T alphai) { typedef typename T::FD FD; auto& N = P; @@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } -template void PairwiseMachine::setup_keys(); -template void PairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h index c2283443..a8a0c649 100644 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -10,7 +10,7 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/PairwiseSetup.h" -class PairwiseMachine : public MachineBase +class PairwiseMachine : public virtual MachineBase { public: PairwiseSetup setup_p; @@ -23,15 +23,6 @@ public: vector enc_alphas; PairwiseMachine(Player& P); - PairwiseMachine(int argc, const char** argv); - - void init(); - - template - void setup_keys(); - - template - void set_mac_key(T alphai); template PairwiseSetup& setup(); @@ -42,4 +33,18 @@ public: void check(Player& P) const; }; +class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine +{ +public: + RealPairwiseMachine(int argc, const char** argv); + + void init(); + + template + void setup_keys(); + + template + void set_mac_key(T alphai); +}; + #endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */ diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index b2701b2c..be5ee2c1 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,7 +12,7 @@ template