diff --git a/.gitignore b/.gitignore index f2ea5d86..45ac6761 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ keys/* ############################## CONFIG.mine config_mine.py +HOSTS # Temporary files # ################### diff --git a/Auth/MAC_Check.cpp b/Auth/MAC_Check.cpp index a14c24b6..86c47b3e 100644 --- a/Auth/MAC_Check.cpp +++ b/Auth/MAC_Check.cpp @@ -43,14 +43,20 @@ MAC_Check::~MAC_Check() { } +template +void MAC_Check::PrepareSending(vector& values, const vector >& S) +{ + values.resize(S.size()); + for (unsigned int i=0; i void MAC_Check::POpen_Begin(vector& values,const vector >& S,const Player& P) { AddToMacs(S); - values.resize(S.size()); - for (unsigned int i=0; istart(values, P); @@ -115,9 +121,9 @@ void MAC_Check::CheckIfNeeded(const Player& P) template -void MAC_Check::AddToCheck(const T& mac, const T& value, const Player& P) +void MAC_Check::AddToCheck(const Share& share, const T& value, const Player& P) { - macs.push_back(mac); + macs.push_back(share.get_mac()); vals.push_back(value); popen_cnt++; CheckIfNeeded(P); @@ -179,6 +185,115 @@ int mc_base_id(int function_id, int thread_num) return (function_id << 28) + ((T::field_type() + 1) << 24) + (thread_num << 16); } +template +MAC_Check_Z2k::MAC_Check_Z2k(const T& ai, const Share& dummy_element, int opening_sum, int max_broadcast, int send_player) : + MAC_Check(ai, opening_sum, max_broadcast, send_player), + dummy_element(dummy_element) +{ +} + +template +void MAC_Check_Z2k::AddToCheck(const Share& share, const T& value, const Player& P) +{ + shares.push_back(share.get_share()); + MAC_Check::AddToCheck(share, value, P); +} + +template +void MAC_Check_Z2k::AddToMacs(const vector >& shares) +{ + for (auto& share : shares) + this->shares.push_back(share.get_share()); + MAC_Check::AddToMacs(shares); +} + +template +void MAC_Check_Z2k::PrepareSending(vector& values, + const vector >& S) +{ + values.clear(); + values.reserve(S.size()); + for (auto& share : S) + values.push_back(V(share.get_share())); +} + +template +Share MAC_Check_Z2k::get_random_element() { + return dummy_element; +} + +template +void MAC_Check_Z2k::set_random_element(const Share& random_element) { + this->dummy_element = random_element; +} + +template +void MAC_Check_Z2k::Check(const Player& P) +{ + if (this->WaitingForCheck() == 0) + return; + + int k = V::N_BITS; + octet seed[SEED_SIZE]; + Create_Random_Seed(seed,P,SEED_SIZE); + PRNG G; + G.SetSeed(seed); + + T y, mj; + y.assign_zero(); + mj.assign_zero(); + vector chi; + for (int i = 0; i < this->popen_cnt; ++i) + { + U temp_chi; + temp_chi.randomize(G); + T xi = this->vals[i]; + y += xi * temp_chi; + T mji = this->macs[i]; + mj += temp_chi * mji; + chi.push_back(temp_chi); + } + + Share r = get_random_element(); + T lj = r.get_mac(); + U pj; + pj.assign_zero(); + for (int i = 0; i < this->popen_cnt; ++i) + { + T xji = shares[i]; + V xbarji = xji; + U pji = U((xji - xbarji) >> k); + pj += chi[i] * pji; + } + pj += U(r.get_share()); + + U pbar(pj); + vector pj_stream(P.num_players()); + pj.pack(pj_stream[P.my_num()]); + P.Broadcast_Receive(pj_stream, true); + for (int j=0; jalphai * y) - (((this->alphai * pbar)) << k) + (lj << k); + vector zjs(P.num_players()); + zjs[P.my_num()] = zj; + Commit_And_Open(zjs, P); + + T zj_sum; + zj_sum.assign_zero(); + for (int i = 0; i < P.num_players(); ++i) + zj_sum += zjs[i]; + + this->vals.erase(this->vals.begin(), this->vals.begin() + this->popen_cnt); + this->macs.erase(this->macs.begin(), this->macs.begin() + this->popen_cnt); + this->shares.erase(this->shares.begin(), this->shares.begin() + this->popen_cnt); + this->popen_cnt=0; + if (!zj_sum.is_zero()) { throw mac_fail(); } +} + template Separate_MAC_Check::Separate_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum, int max_broadcast, int send_player) : @@ -381,4 +496,13 @@ template class Direct_MAC_Check; template class Parallel_MAC_Check; #endif +template class MAC_Check_Z2k, Z2<32>, Z2<32> >; +template class MAC_Check_Z2k, Z2<64>, Z2<64> >; +template class MAC_Check_Z2k, Z2<32>, Z2<64> >; +template class MAC_Check_Z2k, Z2<96>, Z2<64> >; +template class MAC_Check_Z2k, Z2<64>, Z2<128> >; +template class MAC_Check_Z2k, Z2<96>, Z2<160> >; +template class MAC_Check >; template class MAC_Check >; +template class MAC_Check >; +template class MAC_Check >; diff --git a/Auth/MAC_Check.h b/Auth/MAC_Check.h index 59e85ce9..8ecf8057 100644 --- a/Auth/MAC_Check.h +++ b/Auth/MAC_Check.h @@ -76,7 +76,8 @@ class MAC_Check : public TreeSum /* MAC Share */ T alphai; - void AddToMacs(const vector< Share >& shares); + virtual void AddToMacs(const vector< Share >& shares); + virtual void PrepareSending(vector& values,const vector >& S); void AddToValues(vector& values); void GetValues(vector& values); void CheckIfNeeded(const Player& P); @@ -99,7 +100,7 @@ class MAC_Check : public TreeSum */ virtual void POpen_Begin(vector& values,const vector >& S,const Player& P); virtual void POpen_End(vector& values,const vector >& S,const Player& P); - void AddToCheck(const T& mac, const T& value, const Player& P); + virtual void AddToCheck(const Share& share, const T& value, const Player& P); virtual void Check(const Player& P); int number() const { return values_opened; } @@ -107,6 +108,25 @@ class MAC_Check : public TreeSum const T& get_alphai() const { return alphai; } }; +template +class MAC_Check_Z2k : public MAC_Check +{ +protected: + vector shares; + Share dummy_element; + Share get_random_element(); + + void AddToMacs(const vector< Share >& shares); + void PrepareSending(vector& values,const vector >& S); + +public: + void AddToCheck(const Share& share, const T& value, const Player& P); + MAC_Check_Z2k(const T& ai, const Share& dummy_element = {}, int opening_sum=10, int max_broadcast=10, int send_player=0); + virtual void Check(const Player& P); + void set_random_element(const Share& random_element); + virtual ~MAC_Check_Z2k() {}; +}; + template using MAC_Check_ = MAC_Check; diff --git a/Auth/Subroutines.cpp b/Auth/Subroutines.cpp index 042699ce..8a0a42b3 100644 --- a/Auth/Subroutines.cpp +++ b/Auth/Subroutines.cpp @@ -252,4 +252,9 @@ template void Create_Random(gf2n_short& ans,const Player& P); template void Commit_And_Open(vector& data,const Player& P); template void Create_Random(gfp& ans,const Player& P); +template void Commit_And_Open(vector >& data,const Player& P); +template void Commit_And_Open(vector >& data,const Player& P); +template void Commit_And_Open(vector >& data,const Player& P); template void Commit_And_Open(vector >& data,const Player& P); +template void Commit_And_Open(vector >& data,const Player& P); +template void Commit_And_Open(vector >& data,const Player& P); diff --git a/Auth/fake-stuff.cpp b/Auth/fake-stuff.cpp index c585ae90..82967b0c 100644 --- a/Auth/fake-stuff.cpp +++ b/Auth/fake-stuff.cpp @@ -55,6 +55,34 @@ void check_share(vector >& Sa,T& value,T& mac,int N,const T& key) } } +template +void check_share(vector >& Sa, + V& value, + T& mac, + int N, + const T& key) +{ + value.assign(0); + mac.assign(0); + + for (int i=0; i >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G); template void make_share(vector >& Sa,const gfp& a,int N,const gfp& key,PRNG& G); @@ -66,7 +94,26 @@ template void make_share(vector >& Sa,const gf2n_short& a,int template void check_share(vector >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key); #endif -template void check_share(vector > >& Sa,Z2<64>& value,Z2<64>& mac,int N,const Z2<64>& key); +template void check_share( + vector > >& Sa, + Z2<64>& value, + Z2<160>& mac, + int N, + const Z2<160>& key); + +template void check_share( + vector > >& Sa, + Z2<64>& value, + Z2<128>& mac, + int N, + const Z2<128>& key); + +template void check_share( + vector > >& Sa, + Z2<32>& value, + Z2<64>& mac, + int N, + const Z2<64>& key); // Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) void expand_byte(gf2n_short& a,int b) diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h index d4f4c144..7b2eef33 100644 --- a/Auth/fake-stuff.h +++ b/Auth/fake-stuff.h @@ -6,6 +6,7 @@ #include "Math/gf2n.h" #include "Math/gfp.h" +#include "Math/Z2k.h" #include "Math/Share.h" #include @@ -17,6 +18,13 @@ void make_share(vector >& Sa,const T& a,int N,const T& key,PRNG& G); template void check_share(vector >& Sa,T& value,T& mac,int N,const T& key); +template +void check_share(vector >& Sa, + V& value, + T& mac, + int N, + const T& key); + void expand_byte(gf2n_short& a,int b); void collapse_byte(int& b,const gf2n_short& a); diff --git a/CONFIG b/CONFIG index 607c26a7..e609e449 100644 --- a/CONFIG +++ b/CONFIG @@ -13,17 +13,29 @@ PREP_DIR = '-DPREP_DIR="Player-Data/"' # set for 128-bit GF(2^n) and/or OT preprocessing USE_GF2N_LONG = 0 +# set SPDZ_2K bit length parameters K and S +SPDZ2K_K = -DSPDZ2K_K=64 +SPDZ2K_S = -DSPDZ2K_S=64 + +# use additional optimizations in vole protocol +USE_OPT_VOLE = 1 +NUM_VOLE_CHALLENGES = -DNUM_VOLE_CHALLENGES=3 + # set to -march= for optimization # AVX2 support (Haswell or later) changes the bit matrix transpose -ARCH = -mtune=native -mavx +ARCH = -mtune=native -mavx -march=native -#use CONFIG.mine to overwrite DIR settings +# use CONFIG.mine to overwrite DIR settings -include CONFIG.mine ifeq ($(USE_GF2N_LONG),1) GF2N_LONG = -DUSE_GF2N_LONG endif +ifeq ($(USE_OPT_VOLE),1) +OPT_VOLE = -DUSE_OPT_VOLE +endif + # MAX_MOD_SZ must be at least ceil(len(p)/len(word)) # Default is 2, which suffices for 128-bit p # MOD = -DMAX_MOD_SZ=2 @@ -40,7 +52,7 @@ LDLIBS += -lrt endif CXX = g++ -CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror +CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MOD) $(SPDZ2K_K) $(SPDZ2K_S) $(OPT_VOLE) $(NUM_VOLE_CHALLENGES) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 -mbmi2 --std=c++11 -Werror -no-pie CPPFLAGS = $(CFLAGS) LD = g++ diff --git a/Check-Offline-Z2k.cpp b/Check-Offline-Z2k.cpp new file mode 100644 index 00000000..4bc1f56a --- /dev/null +++ b/Check-Offline-Z2k.cpp @@ -0,0 +1,89 @@ +// (C) 2018 University of Bristol. See License.txt + +#include "Math/Z2k.h" +#include "Math/Share.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" + +#include +#include +#include + +template +void check_triples_Z2k(int n_players, string type_char = "") +{ + T keyp; keyp.assign_zero(); + U pp; + ifstream inpf; + for (int i= 0; i < n_players; i++) + { + stringstream ss; + ss << get_prep_dir(n_players, 128, 128) << "Player-MAC-Key-"; + if (type_char.size()) + ss << type_char; + else + ss << U::type_char(); + ss << "-P" << i; + cout << "Opening file " << ss.str() << endl; + inpf.open(ss.str().c_str()); + if (inpf.fail()) { throw file_error(ss.str()); } + pp.input(inpf,false); + cout << " Key " << i << "\t p: " << pp << endl; + keyp.add(pp); + inpf.close(); + } + cout << "--------------\n"; + cout << "Final Keys :\t p: " << keyp << endl; + + ifstream* inputFiles = new ifstream[n_players]; + for (int i = 0; i < n_players; i++) + { + stringstream ss; + ss << get_prep_dir(n_players, 128, 128) << "Triples-"; + if (type_char.size()) + ss << type_char; + else + ss << T::type_char(); + ss << "-P" << i; + inputFiles[i].open(ss.str().c_str()); + cout << "Opening file " << ss.str() << endl; + } + + int j = 0; + while (inputFiles[0].peek() != EOF) + { + V a,b,c,prod; + T mac; + vector> as(n_players), bs(n_players), cs(n_players); + + for (int i = 0; i < n_players; i++) + { + as[i].input(inputFiles[i], false); + bs[i].input(inputFiles[i], false); + cs[i].input(inputFiles[i], false); + } + + check_share(as, a, mac, n_players, keyp); + check_share(bs, b, mac, n_players, keyp); + check_share(cs, c, mac, n_players, keyp); + + prod.mul(a, b); + if (prod != c) + { + cout << j << ": " << c << " != " << a << " * " << b << endl; + throw bad_value(); + } + j++; + } + + cout << dec << j << " correct triples of type " << T::type_string() << endl; + delete[] inputFiles; +} + +int main(int argc, char** argv) +{ + int n_players = 2; + if (argc > 1) + n_players = atoi(argv[1]); + check_triples_Z2k, Z2, Z2>(n_players); +} diff --git a/Check-Offline.cpp b/Check-Offline.cpp index e7ad7c92..4f945ad2 100644 --- a/Check-Offline.cpp +++ b/Check-Offline.cpp @@ -297,11 +297,11 @@ int main(int argc, const char** argv) check_tuples(key2, N, dataF, DATA_INVERSE); check_tuples(keyp, N, dataF, DATA_INVERSE); - Z2<64> keyz2k; - for (int i = 0; i < N; i++) - keyz2k += read_mac_key >(PREP_DATA_PREFIX, i); + // Z2<160> keyz2k; + // for (int i = 0; i < N; i++) + // keyz2k += read_mac_key >(PREP_DATA_PREFIX, i); - check_mult_triples(keyz2k, N, dataF, DATA_Z2K); + // check_mult_triples(keyz2k, N, dataF, DATA_Z2K); for (int i = 0; i < N; i++) delete dataF[i]; diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 7f4c9777..776d01a2 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -94,6 +94,10 @@ class mac_fail: public exception { virtual const char* what() const throw() { return "MacCheck Failure"; } }; +class consistency_check_fail: public exception + { virtual const char* what() const throw() + { return "OT consistency check failed"; } + }; class invalid_program: public exception { virtual const char* what() const throw() { return "Invalid Program"; } diff --git a/FHEOffline/SimpleMachine.cpp b/FHEOffline/SimpleMachine.cpp index 06253b0f..56b5b29b 100644 --- a/FHEOffline/SimpleMachine.cpp +++ b/FHEOffline/SimpleMachine.cpp @@ -291,7 +291,6 @@ void MachineBase::run() << " kbit per " << item_type().substr(0, item_type().length() - 1) << endl; cout << "Produced " << total << " " << item_type() << " in " << timer.elapsed() << " seconds" << endl; - cout << "Throughput: " << total / timer.elapsed() << tradeoff() << endl; cout << "CPU time: " << cpu_timer.elapsed() << endl; extern unsigned long long sent_amount, sent_counter; @@ -300,6 +299,8 @@ void MachineBase::run() cout << sent_amount / sent_counter / N.num_players() << " bytes per call" << endl; + cout << "Time: " << timer.elapsed() << endl; + cout << "Throughput: " << total / timer.elapsed() << endl; mult_performance(); } diff --git a/Makefile b/Makefile index c6abbe4b..3226a3cb 100644 --- a/Makefile +++ b/Makefile @@ -27,14 +27,14 @@ COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(OT) LIB = libSPDZ.a -LIBSIMPLEOT = SimpleOT/libsimpleot.a +LIBSIMPLEOT = -no-pie SimpleOT/libsimpleot.a # used for dependency generation OBJS = $(COMPLETE) DEPS := $(OBJS:.o=.d) -all: gen_input online offline externalIO +all: gen_input online offline externalIO check-passive.x Check-Offline-Z2k.x ifeq ($(USE_NTL),1) all: overdrive she-offline @@ -63,6 +63,9 @@ Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) $(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS) +Check-Offline-Z2k.x: Check-Offline-Z2k.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) Check-Offline-Z2k.cpp -o Check-Offline-Z2k.x $(COMMON) $(PROCESSOR) $(LDLIBS) + Server.x: Server.cpp $(COMMON) $(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS) diff --git a/Math/Share.cpp b/Math/Share.cpp index 569a661c..5c1f5c24 100644 --- a/Math/Share.cpp +++ b/Math/Share.cpp @@ -2,9 +2,9 @@ #include "Share.h" -//#include "Tools/random.h" #include "Math/gfp.h" #include "Math/gf2n.h" +#include "Math/Z2k.h" #include "Math/operators.h" @@ -141,3 +141,8 @@ template class Share; template gf2n_short combine(const vector< Share >& S); template bool check_macs(const vector< Share >& S,const gf2n_short& key); #endif + +template class Share >; +template class Share >; +template class Share >; +template class Share >; diff --git a/Math/Z2k.cpp b/Math/Z2k.cpp index 256f6fe4..a60443e5 100644 --- a/Math/Z2k.cpp +++ b/Math/Z2k.cpp @@ -9,7 +9,7 @@ template Z2::Z2(const bigint& x) : Z2() { auto mp = x.get_mpz_t(); - memcpy(a, mp->_mp_d, sizeof(mp_limb_t) * min(N_WORDS, abs(mp->_mp_size))); + memcpy(a, mp->_mp_d, min((size_t)N_BYTES, sizeof(mp_limb_t) * abs(mp->_mp_size))); if (mp->_mp_size < 0) *this = Z2() - *this; } @@ -20,31 +20,6 @@ bool Z2::get_bit(int i) const return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); } -template -Z2 Z2::operator+(const Z2& other) const -{ - Z2 res; - mpn_add(res.a, a, N_WORDS, other.a, N_WORDS); - res.a[N_WORDS - 1] &= UPPER_MASK; - return res; -} - -template -Z2 Z2::operator-(const Z2& other) const -{ - Z2 res; - mpn_sub(res.a, a, N_WORDS, other.a, N_WORDS); - res.a[N_WORDS - 1] &= UPPER_MASK; - return res; -} - -template -Z2& Z2::operator+=(const Z2& other) -{ - *this = *this + other; - return *this; -} - template Z2 Z2::operator<<(int i) const { @@ -52,7 +27,23 @@ Z2 Z2::operator<<(int i) const int n_limb_shift = i / N_LIMB_BITS; for (int j = n_limb_shift; j < N_WORDS; j++) res.a[j] = a[j - n_limb_shift]; - mpn_lshift(res.a, res.a, N_WORDS, i % N_LIMB_BITS); + int n_inside_shift = i % N_LIMB_BITS; + if (n_inside_shift > 0) + mpn_lshift(res.a, res.a, N_WORDS, n_inside_shift); + res.a[N_WORDS - 1] &= UPPER_MASK; + return res; +} + +template +Z2 Z2::operator>>(int i) const +{ + Z2 res; + int n_limb_shift = i / N_LIMB_BITS; + for (int j = 0; j < N_WORDS - n_limb_shift; j++) + res.a[j] = a[j + n_limb_shift]; + int n_inside_shift = i % N_LIMB_BITS; + if (n_inside_shift > 0) + mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift); res.a[N_WORDS - 1] &= UPPER_MASK; return res; } @@ -60,13 +51,17 @@ Z2 Z2::operator<<(int i) const template bool Z2::operator==(const Z2& other) const { +#ifdef DEBUG_MPN + for (int i = 0; i < N_WORDS; i++) + cout << "cmp " << hex << a[i] << " " << other.a[i] << endl; +#endif return mpn_cmp(a, other.a, N_WORDS) == 0; } template void Z2::randomize(PRNG& G) { - G.get_octets((octet*)a, N_BYTES); + G.get_octets_((octet*)a); } template @@ -98,6 +93,26 @@ void Z2::output(ostream& s, bool human) const s.write((char*)a, N_BYTES); } -#define NS X(32) X(64) X(96) X(128) X(160) X(192) X(224) X(256) X(320) X(672) -#define X(N) template class Z2; -NS +template +ostream& operator<<(ostream& o, const Z2& x) +{ + bool printing = false; + o << "0x" << noshowbase; + o.width(0); + for (int i = x.N_WORDS - 1; i >= 0; i--) + if (x.a[i] or printing or i == 0) + { + o << hex << x.a[i]; + printing = true; + o.width(16); + o.fill('0'); + } + o.width(0); + return o; +} + +#define X(N) \ + template class Z2; \ + template ostream& operator<<(ostream& o, const Z2& x); + +X(32) X(64) X(96) X(128) X(160) X(192) X(224) X(256) X(288) X(320) X(352) X(384) X(416) X(448) X(512) X(672) diff --git a/Math/Z2k.h b/Math/Z2k.h index f2dd3ad9..3ed05626 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -13,6 +13,7 @@ using namespace std; #include "Tools/avx_memcpy.h" #include "bigint.h" #include "field_types.h" +#include "mpn_fixed.h" template class Z2 @@ -28,6 +29,7 @@ class Z2 mp_limb_t a[N_WORDS]; + public: static const int N_BITS = K; static const int N_BYTES = (K + 7) / 8; @@ -40,12 +42,15 @@ public: static DataFieldType field_type() { return DATA_Z2K; } + template + static Z2 Mul(const Z2& x, const Z2& y); + typedef Z2 value_type; Z2() { assign_zero(); } Z2(uint64_t x) : Z2() { a[0] = x; } Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); } - Z2(int x) : Z2(x < 0 ? bigint(x) : uint64_t(x)) {} + Z2(int x) : Z2() { if (x < 0) *this = bigint(x); else *this = uint64_t(x); } Z2(const bigint& x); Z2(const void* buffer) : Z2() { assign(buffer); } template @@ -61,17 +66,23 @@ public: const void* get_ptr() const { return a; } + void negate() { + throw not_implemented(); + } + Z2 operator+(const Z2& other) const; Z2 operator-(const Z2& other) const; template Z2 operator*(const Z2& other) const; - Z2 operator*(bool other) const { return other ? *this : Z2(0); } + Z2 operator*(bool other) const { return other ? *this : Z2(); } Z2& operator+=(const Z2& other); + Z2& operator-=(const Z2& other); Z2 operator<<(int i) const; + Z2 operator>>(int i) const; bool operator==(const Z2& other) const; bool operator!=(const Z2& other) const { return not (*this == other); } @@ -79,8 +90,8 @@ public: void add(const Z2& a, const Z2& b) { *this = a + b; } void add(const Z2& a) { *this += a; } void sub(const Z2& a, const Z2& b) { *this = a - b; } - template - void mul(const Z2& a, const Z2& b) { *this = a * b; } + template + void mul(const Z2& a, const Z2& b) { *this = Z2::Mul(a, b); } template void add(octetStream& os) { add(os.consume(size())); } @@ -100,26 +111,54 @@ public: friend ostream& operator<<(ostream& o, const Z2& x); }; +template +inline Z2 Z2::operator+(const Z2& other) const +{ + Z2 res; + mpn_add_fixed_n(res.a, a, other.a); + res.a[N_WORDS - 1] &= UPPER_MASK; + return res; +} + +template +Z2 Z2::operator-(const Z2& other) const +{ + Z2 res; + mpn_sub_fixed_n(res.a, a, other.a); + res.a[N_WORDS - 1] &= UPPER_MASK; + return res; +} + +template +inline Z2& Z2::operator+=(const Z2& other) +{ + mpn_add_fixed_n(a, other.a, a); + a[N_WORDS - 1] &= UPPER_MASK; + return *this; +} + +template +Z2& Z2::operator-=(const Z2& other) +{ + *this = *this - other; + return *this; +} + +template +template +inline Z2 Z2::Mul(const Z2& x, const Z2& y) +{ + Z2 res; + mpn_mul_fixed_(res.a, x.a, y.a); + res.a[N_WORDS - 1] &= UPPER_MASK; + return res; +} + template template inline Z2 Z2::operator*(const Z2& other) const { - mp_limb_t product[N_WORDS + other.N_WORDS]; - if (K < L) - mpn_mul(product, other.a, other.N_WORDS, a, N_WORDS); - else - mpn_mul(product, a, N_WORDS, other.a, other.N_WORDS); - Z2 res; - avx_memcpy(res.a, product, res.N_BYTES); - return res; -} - -template -inline ostream& operator<<(ostream& o, const Z2& x) -{ - for (int i = 0; i < x.N_WORDS; i++) - o << hex << x.a[i] << " "; - return o; + return Z2::Mul(*this, other); } #endif /* MATH_Z2K_H_ */ diff --git a/Math/modp.cpp b/Math/modp.cpp index a3840742..e86e307f 100644 --- a/Math/modp.cpp +++ b/Math/modp.cpp @@ -14,7 +14,7 @@ bool modp::rewind = false; void modp::randomize(PRNG& G, const Zp_Data& ZpD) { bigint x=G.randomBnd(ZpD.pr); - memcpy(this->x, x.get_mpz_t()->_mp_d, ZpD.get_t()); + memcpy(this->x, x.get_mpz_t()->_mp_d, ZpD.get_t() * sizeof(mp_limb_t)); } void modp::pack(octetStream& o,const Zp_Data& ZpD) const diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h new file mode 100644 index 00000000..67828c36 --- /dev/null +++ b/Math/mpn_fixed.h @@ -0,0 +1,219 @@ +/* + * mpn_fixed.h + * + */ + +#ifndef MATH_MPN_FIXED_H_ +#define MATH_MPN_FIXED_H_ + +#include +#include +#include + +inline void debug_print(const char* name, const mp_limb_t* x, int n) +{ + (void)name, (void)x, (void)n; +#ifdef DEBUG_MPN + cout << name << " "; + for (int i = 0; i < n; i++) + cout << hex << x[n-i-1] << " "; + cout << endl; +#endif +} + +template +inline void mpn_add_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + mpn_add(res, x, N, y, N); +} + +template <> +inline void mpn_add_fixed_n<1>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + *res = *x + *y; +} + +template <> +inline void mpn_add_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, y, 2 * sizeof(mp_limb_t)); + debug_print("x", x, 2); + debug_print("y", y, 2); + debug_print("res", res, 2); + __asm__ ( + "add %2, %0 \n" + "adc %3, %1 \n" + : "+&r"(res[0]), "+r"(res[1]) + : "rm"(x[0]), "rm"(x[1]) + : "cc" + ); + debug_print("res", res, 2); +} + +template <> +inline void mpn_add_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, y, 3 * sizeof(mp_limb_t)); + debug_print("x", x, 3); + debug_print("y", y, 3); + debug_print("res", res, 3); + __asm__ ( + "add %3, %0 \n" + "adc %4, %1 \n" + "adc %5, %2 \n" + : "+&r"(res[0]), "+&r"(res[1]), "+r"(res[2]) + : "rm"(x[0]), "rm"(x[1]), "rm"(x[2]) + : "cc" + ); + debug_print("res", res, 3); +} + +template <> +inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, y, 4 * sizeof(mp_limb_t)); + __asm__ ( + "add %4, %0 \n" + "adc %5, %1 \n" + "adc %6, %2 \n" + "adc %7, %3 \n" + : "+&r"(res[0]), "+&r"(res[1]), "+&r"(res[2]), "+r"(res[3]) + : "rm"(x[0]), "rm"(x[1]), "rm"(x[2]), "rm"(x[3]) + : "cc" + ); +} + +template +inline void mpn_sub_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + mpn_add(res, x, N, y, N); +} + +template <> +inline void mpn_sub_fixed_n<1>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + *res = *x - *y; +} + +template <> +inline void mpn_sub_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, x, 2 * sizeof(mp_limb_t)); + __asm__ ( + "sub %2, %0 \n" + "sbb %3, %1 \n" + : "+r"(res[0]), "+r"(res[1]) + : "rm"(y[0]), "rm"(y[1]) + : "cc" + ); +} + +template <> +inline void mpn_sub_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, x, 3 * sizeof(mp_limb_t)); + __asm__ volatile ( + "sub %3, %0 \n" + "sbb %4, %1 \n" + "sbb %5, %2 \n" + : "+r"(res[0]), "+r"(res[1]), "+r"(res[2]) + : "rm"(y[0]), "rm"(y[1]), "rm"(y[2]) + : "cc" + ); +} + +template <> +inline void mpn_sub_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + memcpy(res, x, 4 * sizeof(mp_limb_t)); + __asm__ volatile ( + "sub %4, %0 \n" + "sbb %5, %1 \n" + "sbb %6, %2 \n" + "sbb %7, %3 \n" + : "+r"(res[0]), "+r"(res[1]), "+r"(res[2]), "+r"(res[3]) + : "rm"(y[0]), "rm"(y[1]), "rm"(y[2]), "rm"(y[3]) + : "cc" + ); +} + +inline void mpn_add_n_use_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y, mp_size_t n) +{ + switch (n) + { +#define CASE(N) \ + case N: \ + mpn_add_fixed_n(res, x, y); \ + break; + CASE(1); + CASE(2); + CASE(3); + CASE(4); + default: + mpn_add_n(res, x, y, n); + break; + } +} + +template +inline void mpn_addmul_1_fixed_(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x) +{ + mp_limb_t lower[L], higher[L]; + lower[L - 1] = 0; + higher[L - 1] = 0; + for (int j = 0; j < M; j++) + lower[j] = _mulx_u64(x, y[j], (long long unsigned*)higher + j); + debug_print("lower", lower, L); + debug_print("higher", higher, L); + debug_print("before add", res, L + 1); + mpn_add_fixed_n(res, lower, res); + debug_print("first add", res, L + 1); + mpn_add_fixed_n(res + 1, higher, res + 1); + debug_print("second add", res, L + 1); +} + +template +inline void mpn_addmul_1_fixed(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x) +{ + mpn_addmul_1_fixed_(res, y, x); +} + +template +inline void mpn_mul_fixed_(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + assert(L <= N + M + 2); + mp_limb_t tmp[N + M + 2]; + avx_memzero(tmp, sizeof(tmp)); + for (int i = 0; i < N; i++) + mpn_addmul_1_fixed(tmp + i, y, x[i]); + inline_mpn_copyi(res, tmp, L); +} + +template <> +inline void mpn_mul_fixed_<3,3,3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + inline_mpn_zero(res, 3); + mp_limb_t* tmp = res; + mpn_addmul_1_fixed_<3,3>(tmp, y, x[0]); + mpn_addmul_1_fixed_<2,2>(tmp + 1, y, x[1]); + mpn_addmul_1_fixed_<1,1>(tmp + 2, y, x[2]); +} + +template <> +inline void mpn_mul_fixed_<4,4,2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + inline_mpn_zero(res, 4); + mp_limb_t* tmp = res; + mpn_addmul_1_fixed_<3,2>(tmp, y, x[0]); + mpn_addmul_1_fixed_<3,2>(tmp + 1, y, x[1]); + mpn_addmul_1_fixed_<2,2>(tmp + 2, y, x[2]); + mpn_addmul_1_fixed_<1,1>(tmp + 3, y, x[3]); +} + +template +inline void mpn_mul_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + mpn_mul_fixed_(res, x, y); +} + +#endif /* MATH_MPN_FIXED_H_ */ diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 32792d35..f5f1649f 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -192,8 +192,7 @@ void square128::transpose() } \ } #ifdef __AVX2__ -#define SIXTEENTOSIXTYFOUR Y(16) Y(32) Y(64) -#define Y(I) { \ +#define Z(I) { \ const int J = I / 8; \ for (int i = 0; i < 16 / J; i++) \ { \ @@ -216,7 +215,7 @@ void square128::transpose() base += 16; X(8) base = k * 16; - SIXTEENTOSIXTYFOUR + Z(16) Z(32) Z(64) for (int i = 0; i < 8; i++) { int a = base + i; @@ -335,7 +334,10 @@ void square128::to(gfp& result) for (int i = 0; i < 128; i++) { memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i])); - mpn_lshift(product, tmp[i/64], 4, i % 64); + if (i % 64 == 0) + memcpy(product, tmp[i/64], sizeof(product)); + else + mpn_lshift(product, tmp[i/64], 4, i % 64); mpn_add_n(sum, product, sum, 4); } mp_limb_t q[4], ans[4]; @@ -567,7 +569,7 @@ Matrix& Matrix::operator=(const Matrix& other) avx_memcpy(seed, &source, min(SEED_SIZE, (int)sizeof(source))); prng.SetSeed(seed); dest = 0; - prng.get_octets((octet*)dest.get_ptr(), U::N_ROW_BYTES); + prng.get_octets_((octet*)dest.get_ptr()); } } return *this; @@ -718,7 +720,7 @@ Slice& Slice::rsub(Slice& other) if (bm.squares.size() < other.end) throw invalid_length(); for (size_t i = other.start; i < other.end; i++) - bm.squares[i].rsub(other.bm.squares[i]); + bm.squares[i].template rsub(other.bm.squares[i]); return *this; } @@ -727,7 +729,7 @@ template Slice& Slice::sub(BitVector& other, int repeat) { if (end * U::PartType::N_COLUMNS > other.size() * repeat) - throw invalid_length(); + throw invalid_length(to_string(U::PartType::N_COLUMNS)); for (size_t i = start; i < end; i++) { bm.squares[i].template sub(other.get_ptr_to_byte(i / repeat, @@ -741,7 +743,7 @@ template void Slice::randomize(int row, PRNG& G) { for (size_t i = start; i < end; i++) - bm.squares[i].randomize(row, G); + bm.squares[i].template randomize(row, G); } template @@ -749,7 +751,7 @@ template void Slice::conditional_add(BitVector& conditions, U& other, bool useOffset) { for (size_t i = start; i < end; i++) - bm.squares[i].conditional_add(conditions, other.squares[i], useOffset * i); + bm.squares[i].template conditional_add(conditions, other.squares[i], useOffset * i); } template <> @@ -774,6 +776,7 @@ void Slice::print() template void Slice::pack(octetStream& os) const { + os.reserve(U::PartType::size() * (end - start)); for (size_t i = start; i < end; i++) bm.squares[i].pack(os); } @@ -801,9 +804,11 @@ template void Slice, Z2 > > >::conditional_add< \ Z2 >(BitVector& conditions, \ Matrix, Z2 > >& other, bool useOffset); \ -X(96, 160) +//X(96, 160) Y(64, 96) +Y(64, 64) +Y(32, 32) template class Matrix; diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 8dbaabe0..e6212394 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -20,10 +20,14 @@ using namespace std; union square128 { + typedef int128 RowType; + const static int N_ROWS = 128; const static int N_COLUMNS = 128; const static int N_ROW_BYTES = 128 / 8; + static size_t size() { return N_ROWS * sizeof(__m128i); } + #ifdef __AVX2__ __m256i doublerows[64]; #endif diff --git a/OT/BitVector.h b/OT/BitVector.h index 34199fde..fba71718 100644 --- a/OT/BitVector.h +++ b/OT/BitVector.h @@ -10,6 +10,7 @@ using namespace std; #include #include +#include #include "Exceptions/Exceptions.h" #include "Networking/data.h" @@ -161,6 +162,7 @@ class BitVector bool get_bit(int i) const { + assert(i < (int)nbits); return (bytes[i/8] >> (i % 8)) & 1; } void set_bit(int i,unsigned int a) @@ -216,7 +218,7 @@ class BitVector void pack(octetStream& o) const; void unpack(octetStream& o); - string str(size_t end = SIZE_MAX) + string str(size_t end = SIZE_MAX) const { stringstream ss; ss << hex; diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.cpp index 25f77608..6acf5d5f 100644 --- a/OT/NPartyTripleGenerator.cpp +++ b/OT/NPartyTripleGenerator.cpp @@ -68,15 +68,19 @@ public: template void amplify(BitVector& a, T& b, Rectangle, T>& c, PRNG& G) { - avx_memzero(this, sizeof(*this)); assert(a.size() == M); this->b = b; for (int i = 0; i < N; i++) { - T r; - r.randomize(G); - this->a[i] += r * a.get_bit(i); - this->c[i] += r * c.rows[i]; + this->a[i] = 0; + this->c[i] = 0; + for (int j = 0; j < M; j++) + { + T r; + r.randomize(G); + this->a[i] += r * a.get_bit(j); + this->c[i] += r * c.rows[j]; + } } } @@ -106,8 +110,45 @@ public: } }; -template -class ShareTriple : public Triple, N> +// T is Z2, U is Z2 +template +class PlainTriple_ : public PlainTriple +{ +public: + + template + void amplify(BitVector& a, U& b, Rectangle, U>& c, PRNG& G) + { + assert(a.size() == M); + this->b = b; + for (int i = 0; i < N; i++) + { + U aa = 0, cc = 0; + for (int j = 0; j < M; j++) + { + U r; + r.randomize(G); + if (a.get_bit(j)) + aa += r; + cc += U::Mul(r, c.rows[j]); + } + this->a[i] = aa; + this->c[i] = cc; + } + } + + void to(vector& valueBits, int i) + { + for (int j = 0; j < N; j++) + { + valueBits[0].set_portion(i * N + j, this->a[j]); + valueBits[2].set_portion(i * N + j, this->c[j]); + } + } +}; + +template +class ShareTriple_ : public Triple, N> { public: void from(PlainTriple& triple, vector& ot_multipliers, @@ -119,9 +160,10 @@ public: for (int j = 0; j < repeat; j++) { T value = triple.byIndex(l,j); - T mac = value * generator.machine.get_mac_key(); + T mac; + mac.mul(value, generator.machine.get_mac_key()); for (int i = 0; i < generator.nparties-1; i++) - mac += ((MascotMultiplier*)ot_multipliers[i])->macs[l][iTriple * repeat + j]; + mac += ((OTMultiplierMac*)ot_multipliers[i])->macs.at(l).at(iTriple * repeat + j); Share& share = this->byIndex(l,j); share.set_share(value); share.set_mac(mac); @@ -129,9 +171,59 @@ public: } } - T computeCheckMAC(const T& maskedA) + Share get_check_value(PRNG& G) { - return this->c[0].get_mac() - maskedA * this->b.get_mac(); + Share res; + res += G.get() * this->b; + for (int i = 0; i < N; i++) + { + res += G.get() * this->a[i]; + res += G.get() * this->c[i]; + } + return res; + } + + template + Triple, 1> reduce() { + Triple, 1> triple; + + Share _a; + _a.set_share(V(this->a[0].get_share())); + _a.set_mac(V(this->a[0].get_mac())); + triple.a[0] = _a; + + Share _b; + _b.set_share(V(this->b.get_share())); + _b.set_mac(V(this->b.get_mac())); + triple.b = _b; + + Share _c; + _c.set_share(V(this->c[0].get_share())); + _c.set_mac(V(this->c[0].get_mac())); + triple.c[0] = _c; + + return triple; + } + +}; + +template +class TripleToSacrifice : public Triple, 1> +{ +public: + template + void prepare_sacrifice(const ShareTriple_& uncheckedTriple, PRNG& G) + { + this->b = uncheckedTriple.b; + U t; + t.randomize(G); + this->a[0] = uncheckedTriple.a[0] * t - uncheckedTriple.a[1]; + this->c[0] = uncheckedTriple.c[0] * t - uncheckedTriple.c[1]; + } + + Share computeCheckShare(const T& maskedA) + { + return this->c[0] - maskedA * this->b; } }; @@ -221,11 +313,23 @@ OTMultiplierBase* NPartyTripleGenerator::new_multiplier(int i) } template<> -OTMultiplierBase* NPartyTripleGenerator::new_multiplier >(int i) +OTMultiplierBase* NPartyTripleGenerator::new_multiplier >(int i) { return new Spdz2kMultiplier<64, 96>(*this, i); } +template<> +OTMultiplierBase* NPartyTripleGenerator::new_multiplier >(int i) +{ + return new Spdz2kMultiplier<64, 64>(*this, i); +} + +template<> +OTMultiplierBase* NPartyTripleGenerator::new_multiplier >(int i) +{ + return new Spdz2kMultiplier<32, 32>(*this, i); +} + template void NPartyTripleGenerator::generate() { @@ -300,11 +404,9 @@ void NPartyTripleGenerator::generateBits(vector< OTMultiplierBase* >& ot_m valueBits[0].randomize_blocks(share_prg); - for (int i = 0; i < nparties-1; i++) - pthread_cond_signal(&ot_multipliers[i]->ready); + signal_multipliers(); timers["Authentication OTs"].start(); - for (int i = 0; i < nparties-1; i++) - pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + wait_for_multipliers(); timers["Authentication OTs"].stop(); octet seed[SEED_SIZE]; @@ -334,8 +436,7 @@ void NPartyTripleGenerator::generateBits(vector< OTMultiplierBase* >& ot_m for (int j = 0; j < nTriplesPerLoop; j++) bits[j].output(outputFile, false); - for (int i = 0; i < nparties-1; i++) - pthread_cond_signal(&ot_multipliers[i]->ready); + signal_multipliers(); } } @@ -355,58 +456,107 @@ void NPartyTripleGenerator::generateBits(vector< OTMultiplierBase* >& ot_multipl throw not_implemented(); } +template +void NPartyTripleGenerator::generateTriplesZ2k(vector< OTMultiplierBase* >& ot_multipliers, + ofstream& outputFile) +{ + (void) outputFile; + const int TAU = Spdz2kMultiplier::TAU; + valueBits.resize(3); + for (int i = 0; i < 2; i++) + valueBits[2*i].resize(TAU * nTriplesPerLoop); + valueBits[1].resize((K + S) * (nTriplesPerLoop + 1)); + b_padded_bits.resize((K + 2 * S) * (nTriplesPerLoop + 1)); + vector< PlainTriple_, Z2, 2> > amplifiedTriples(nTriplesPerLoop); + vector< ShareTriple_, Z2, 2> > uncheckedTriples(nTriplesPerLoop); + MAC_Check_Z2k, Z2, Z2 > MC(machine.get_mac_key >()); + + start_progress(); + + for (int k = 0; k < nloops; k++) + { + print_progress(k); + + for (int j = 0; j < 2; j++) + valueBits[j].randomize_blocks(share_prg); + + for (int j = 0; j < nTriplesPerLoop + 1; j++) + { + Z2 b(valueBits[1].get_ptr_to_bit(j, K + S)); + b_padded_bits.set_portion(j, Z2(b)); + } + + timers["OTs"].start(); + wait_for_multipliers(); + timers["OTs"].stop(); + + octet seed[SEED_SIZE]; + Create_Random_Seed(seed, globalPlayer, SEED_SIZE); + PRNG G; + G.SetSeed(seed); + + for (int j = 0; j < nTriplesPerLoop; j++) + { + BitVector a(valueBits[0].get_ptr_to_bit(j, TAU), TAU); + Z2 b(valueBits[1].get_ptr_to_bit(j, K + S)); + Z2kRectangle c; + c.mul(a, b); + timers["Triple computation"].start(); + for (int i = 0; i < nparties-1; i++) + { + c += ((Spdz2kMultiplier*)ot_multipliers[i])->c_output[j]; + } + timers["Triple computation"].stop(); + amplifiedTriples[j].amplify(a, b, c, G); + amplifiedTriples[j].to(valueBits, j); + } + + signal_multipliers(); + wait_for_multipliers(); + + for (int j = 0; j < nTriplesPerLoop; j++) + { + uncheckedTriples[j].from(amplifiedTriples[j], ot_multipliers, j, *this); + } + + // we can skip the consistency check since we're doing a mac-check next + // get piggy-backed random value + Z2 r_share = b_padded_bits.get_ptr_to_bit(nTriplesPerLoop, K + 2 * S); + Z2 r_mac; + r_mac.mul(r_share, this->machine.template get_mac_key>()); + for (int i = 0; i < this->nparties-1; i++) + r_mac += ((OTMultiplierMac>*)ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop); + Share> r; + r.set_share(r_share); + r.set_mac(r_mac); + + MC.set_random_element(r); + sacrifice, Z2, Z2>(uncheckedTriples, MC, G); + + signal_multipliers(); + } +} + template<> void NPartyTripleGenerator::generateTriples >(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile) { - const int K = 64; - const int S = 96; - const int TAU = Spdz2kMultiplier::TAU; - valueBits.resize(3); - for (int i = 0; i < 2; i++) - valueBits[2*i].resize(TAU * nTriplesPerLoop); - valueBits[1].resize((K + S) * nTriplesPerLoop); - vector< PlainTriple, 2> > amplifiedTriples(nTriplesPerLoop); + this->template generateTriplesZ2k<32, 32>(ot_multipliers, outputFile); +} - start_progress(); +template<> +void NPartyTripleGenerator::generateTriples >(vector< OTMultiplierBase* >& ot_multipliers, + ofstream& outputFile) +{ + this->template generateTriplesZ2k<64, 64>(ot_multipliers, outputFile); +} - for (int k = 0; k < nloops; k++) - { - print_progress(k); - for (int j = 0; j < 2; j++) - valueBits[j].randomize_blocks(share_prg); - - signal_multipliers(); - timers["OTs"].start(); - wait_for_multipliers(); - timers["OTs"].stop(); - - octet seed[SEED_SIZE]; - Create_Random_Seed(seed, globalPlayer, SEED_SIZE); - PRNG G; - G.SetSeed(seed); - - for (int j = 0; j < nTriplesPerLoop; j++) - { - BitVector a(valueBits[0].get_ptr_to_bit(j, TAU), TAU); - Z2 b(valueBits[1].get_ptr_to_bit(j, K + S)); - Z2kRectangle c; - c.mul(a, b); - timers["Triple computation"].start(); - for (int i = 0; i < nparties-1; i++) - { - c += ((Spdz2kMultiplier*)ot_multipliers[i])->c_output[j]; - } - timers["Triple computation"].stop(); - PlainTriple, 2> amplifiedTriple; - amplifiedTriple.amplify(a, b, c, G); - if (machine.output) - amplifiedTriple.output(outputFile); - } - - signal_multipliers(); - } +template<> +void NPartyTripleGenerator::generateTriples >(vector< OTMultiplierBase* >& ot_multipliers, + ofstream& outputFile) +{ + this->template generateTriplesZ2k<64, 96>(ot_multipliers, outputFile); } template @@ -440,8 +590,7 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult valueBits[j].randomize_blocks(share_prg); timers["OTs"].start(); - for (int i = 0; i < nparties-1; i++) - pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + wait_for_multipliers(); timers["OTs"].stop(); for (int j = 0; j < nPreampTriplesPerLoop; j++) @@ -497,11 +646,9 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) amplifiedTriples[iTriple].to(valueBits, iTriple); - for (int i = 0; i < nparties-1; i++) - pthread_cond_signal(&ot_multipliers[i]->ready); + signal_multipliers(); timers["Authentication OTs"].start(); - for (int i = 0; i < nparties-1; i++) - pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + wait_for_multipliers(); timers["Authentication OTs"].stop(); for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) @@ -518,38 +665,88 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult if (machine.check) { - vector< Share > maskedAs(nTriplesPerLoop); - vector< ShareTriple > maskedTriples(nTriplesPerLoop); - for (int j = 0; j < nTriplesPerLoop; j++) - { - maskedTriples[j].amplify(uncheckedTriples[j], G); - maskedAs[j] = maskedTriples[j].a[0]; - } - - vector openedAs(nTriplesPerLoop); - MC.POpen_Begin(openedAs, maskedAs, globalPlayer); - MC.POpen_End(openedAs, maskedAs, globalPlayer); - - for (int j = 0; j < nTriplesPerLoop; j++) - MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), 0, globalPlayer); - - MC.Check(globalPlayer); - - if (machine.generateBits) - generateBitsFromTriples(uncheckedTriples, MC, outputFile); - else - if (machine.output) - for (int j = 0; j < nTriplesPerLoop; j++) - uncheckedTriples[j].output(outputFile, 1); + sacrifice(uncheckedTriples, MC, G); } } } - for (int i = 0; i < nparties-1; i++) - pthread_cond_signal(&ot_multipliers[i]->ready); + signal_multipliers(); } } +template +void NPartyTripleGenerator::sacrifice( + vector > uncheckedTriples, MAC_Check& MC, PRNG& G) +{ + vector< Share > maskedAs(nTriplesPerLoop); + vector > maskedTriples(nTriplesPerLoop); + for (int j = 0; j < nTriplesPerLoop; j++) + { + maskedTriples[j].prepare_sacrifice(uncheckedTriples[j], G); + maskedAs[j] = maskedTriples[j].a[0]; + } + + vector openedAs(nTriplesPerLoop); + MC.POpen_Begin(openedAs, maskedAs, globalPlayer); + MC.POpen_End(openedAs, maskedAs, globalPlayer); + + for (int j = 0; j < nTriplesPerLoop; j++) { + MC.AddToCheck(maskedTriples[j].computeCheckShare(openedAs[j]), 0, + globalPlayer); + } + + MC.Check(globalPlayer); + + if (machine.generateBits) + generateBitsFromTriples(uncheckedTriples, MC, outputFile); + else + if (machine.output) + for (int j = 0; j < nTriplesPerLoop; j++) + uncheckedTriples[j].output(outputFile, 1); +} + +template +void NPartyTripleGenerator::sacrifice( + vector > uncheckedTriples, MAC_Check_Z2k& MC, PRNG& G) +{ + vector< Share > maskedAs(nTriplesPerLoop); + vector > maskedTriples(nTriplesPerLoop); + for (int j = 0; j < nTriplesPerLoop; j++) + { + // compute [p] = t * [a] - [ahat] + // and first part of [sigma], i.e., t * [c] - [chat] + maskedTriples[j].prepare_sacrifice(uncheckedTriples[j], G); + maskedAs[j] = maskedTriples[j].a[0]; + } + + vector openedAs(nTriplesPerLoop); + MC.POpen_Begin(openedAs, maskedAs, globalPlayer); + MC.POpen_End(openedAs, maskedAs, globalPlayer); + + vector> sigmas; + for (int j = 0; j < nTriplesPerLoop; j++) { + // compute t * [c] - [chat] - [b] * p + sigmas.push_back(maskedTriples[j].computeCheckShare(V(openedAs[j]))); + } + vector open_sigmas; + + MC.POpen_Begin(open_sigmas, sigmas, globalPlayer); + MC.POpen_End(open_sigmas, sigmas, globalPlayer); + MC.Check(globalPlayer); + + for (int j = 0; j < nTriplesPerLoop; j++) { + if (V(open_sigmas[j]) != 0) + throw mac_fail(); + } + + if (machine.generateBits) + generateBitsFromTriples(uncheckedTriples, MC, outputFile); + else + if (machine.output) + for (int j = 0; j < nTriplesPerLoop; j++) + uncheckedTriples[j].template reduce().output(outputFile, 1); +} + template<> void NPartyTripleGenerator::generateBitsFromTriples( vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) @@ -576,9 +773,9 @@ void NPartyTripleGenerator::generateBitsFromTriples( } } -template<> +template void NPartyTripleGenerator::generateBitsFromTriples( - vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) + vector< ShareTriple_ >& triples, MAC_Check& MC, ofstream& outputFile) { throw how_would_that_work(); // warning gymnastics @@ -589,14 +786,12 @@ void NPartyTripleGenerator::generateBitsFromTriples( void NPartyTripleGenerator::start_progress() { - for (int i = 0; i < nparties-1; i++) - pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); + wait_for_multipliers(); lock(); signal(); wait(); gettimeofday(&last_lap, 0); - for (int i = 0; i < nparties-1; i++) - pthread_cond_signal(&ot_multipliers[i]->ready); + signal_multipliers(); } void NPartyTripleGenerator::print_progress(int k) @@ -655,3 +850,5 @@ template void NPartyTripleGenerator::generate(); template void NPartyTripleGenerator::generate(); template void NPartyTripleGenerator::generate >(); +template void NPartyTripleGenerator::generate >(); +template void NPartyTripleGenerator::generate >(); diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 0618b672..d99f8eac 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -19,8 +19,11 @@ #define N_AMPLIFY 3 +template +class ShareTriple_; + template -class ShareTriple; +using ShareTriple = ShareTriple_; class NPartyTripleGenerator { @@ -28,7 +31,6 @@ class NPartyTripleGenerator Player globalPlayer; int thread_num; - int my_num; int nbase; struct timeval last_lap; @@ -40,14 +42,25 @@ class NPartyTripleGenerator ofstream outputFile; PRNG share_prg; + template + void generateTriplesZ2k(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile); + template void generateTriples(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile); template void generateBits(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile); - template - void generateBitsFromTriples(vector >& triples, + template + void generateBitsFromTriples(vector >& triples, MAC_Check& MC, ofstream& outputFile); + template + void sacrifice(vector > uncheckedTriples, + MAC_Check& MC, PRNG& G); + + template + void sacrifice(vector > uncheckedTriples, + MAC_Check_Z2k& MC, PRNG& G); + void start_progress(); void print_progress(int k); @@ -65,6 +78,10 @@ public: vector< vector< vector > > baseSenderInputs; vector< vector > baseReceiverOutputs; vector valueBits; + BitVector b_padded_bits; + + int my_num; + int nTriples; int nTriplesPerLoop; diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index 264d72a8..6c268425 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -299,7 +299,6 @@ void OTExtension::transfer(int nOTs, // randomize last 128 + 128 bits that will be discarded for (int i = 0; i < 4; i++) newReceiverInput.set_word(nOTs/64 - i, G.get_word()); - // expand with PRG and create correlation if (ot_role & RECEIVER) { diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index acef3c52..c6260c05 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -76,16 +76,15 @@ void OTCorrelator::resize(int nOTs) receiverOutputMatrix.resize_vertical(nOTs); } -template<> -void OTExtensionWithMatrix::extend >(int nOTs_requested, - BitVector& newReceiverInput) -{ - extend(nOTs_requested, newReceiverInput); -} - // the template is used to denote the field of the hash output template void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput) +{ + extend_correlated(nOTs_requested, newReceiverInput); + hash_outputs(nOTs_requested); +} + +void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& newReceiverInput) { // if (nOTs % nbaseOTs != 0) // throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n"); @@ -133,8 +132,6 @@ void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInp #endif } - hash_outputs(nOTs); - receiverOutputMatrix.resize(nOTs_requested_rounded); senderOutputMatrices[0].resize(nOTs_requested_rounded); senderOutputMatrices[1].resize(nOTs_requested_rounded); @@ -190,8 +187,8 @@ void OTExtensionWithMatrix::expand_transposed() template void OTCorrelator::setup_for_correlation(BitVector& baseReceiverInput, - vector& baseSenderOutputs, - BitMatrix& baseReceiverOutput) + vector& baseSenderOutputs, + U& baseReceiverOutput) { this->baseReceiverInput = baseReceiverInput; receiverOutputMatrix = baseSenderOutputs[0]; @@ -282,28 +279,55 @@ void OTExtensionWithMatrix::transpose(int start, int slice) */ template void OTExtensionWithMatrix::hash_outputs(int nOTs) +{ + hash_outputs(nOTs, senderOutputMatrices, receiverOutputMatrix); +} + +template +void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput) { //cout << "Hashing... " << flush; octetStream os, h_os(HASH_SIZE); - square128 tmp; MMO mmo; #ifdef OTEXT_TIMER timeval startv, endv; gettimeofday(&startv, NULL); #endif - for (int i = 0; i < nOTs / 128; i++) + for (int i = 0; i < 2; i++) + senderOutput[i].resize_vertical(nOTs); + receiverOutput.resize_vertical(nOTs); + + int n_rows = V::PartType::N_ROWS; + if (V::PartType::N_COLUMNS != T::size() * 8) + throw runtime_error("length mismatch for MMO hash"); + if (nOTs % 8 != 0) + throw runtime_error("number of OTs must be divisible by 8"); + + for (int i = 0; i < nOTs; i += 8) { + int i_outer_input = i / 128; + int i_inner_input = i % 128; + int i_outer_output = i / n_rows; + int i_inner_output = i % n_rows; if (ot_role & SENDER) { - tmp = senderOutputMatrices[0].squares[i]; - tmp ^= baseReceiverInput; - senderOutputMatrices[0].squares[i].hash_row_wise(mmo, senderOutputMatrices[0].squares[i]); - senderOutputMatrices[1].squares[i].hash_row_wise(mmo, tmp); + int128 tmp[2][8]; + for (int j = 0; j < 8; j++) + { + tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j]; + tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0); + } + for (int j = 0; j < 2; j++) + mmo.hashBlocks( + &senderOutput[j].squares[i_outer_output].rows[i_inner_output], + &tmp[j]); } if (ot_role & RECEIVER) { - receiverOutputMatrix.squares[i].hash_row_wise(mmo, receiverOutputMatrix.squares[i]); + mmo.hashBlocks( + &receiverOutput.squares[i_outer_output].rows[i_inner_output], + &receiverOutputMatrix.squares[i_outer_input].rows[i_inner_input]); } } //cout << "done.\n"; @@ -324,10 +348,7 @@ void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output) output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { - T c1, c2; - receiverOutputMatrix.squares[j].to(c1); - senderOutputMatrices[0].squares[j].to(c2); - output[j] = c1 - c2; + receiverOutputMatrix.squares[j].template sub(senderOutputMatrices[0].squares[j]).to(output[j]); } } @@ -467,10 +488,17 @@ void OTExtensionWithMatrix::print_pre_expand() template class OTCorrelator; template class OTCorrelator >; -template void OTCorrelator >::correlate(int start, int slice, - BitVector& newReceiverInput, bool useConstantBase, int repeat); -template void OTCorrelator >::correlate(int start, int slice, - BitVector& newReceiverInput, bool useConstantBase, int repeat); +#define Z(BM,GF) \ +template void OTCorrelator::correlate(int start, int slice, \ + BitVector& newReceiverInput, bool useConstantBase, int repeat); \ +template void OTCorrelator::expand(int start, int slice); \ +template void OTCorrelator::reduce_squares(unsigned int nTriples, \ + vector& output); +#define ZZ(BM) Z(BM, gfp) Z(BM, gf2n) + +ZZ(BitMatrix) +ZZ(Matrix ) + template void OTExtensionWithMatrix::print_post_correlate( BitVector& newReceiverInput, int j, int offset, int sender); template void OTExtensionWithMatrix::print_post_correlate( @@ -479,27 +507,35 @@ template void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput); template void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput); -template void OTCorrelator >::expand(int start, int slice); -template void OTCorrelator >::expand(int start, int slice); template void OTExtensionWithMatrix::expand_transposed(); template void OTExtensionWithMatrix::expand_transposed(); -template void OTCorrelator >::reduce_squares(unsigned int nTriples, - vector& output); -template void OTCorrelator >::reduce_squares(unsigned int nTriples, - vector& output); +#define ZZZ(GF, M) \ +template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&); +#define MM Matrix, Z2<160> > > + +ZZZ(gfp, Matrix) +ZZZ(gf2n, Matrix) +ZZZ(Z2<160>, MM) + +#undef X #define X(N,L) \ template class OTCorrelator, Z2 > > >; \ template void OTCorrelator, Z2 > > >::correlate >(int start, int slice, \ BitVector& newReceiverInput, bool useConstantBase, int repeat); \ template void OTCorrelator, Z2 > > >::expand >(int start, int slice); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ + vector >& output); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ + vector >& output); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ vector >& output); \ -X(96, 160) +//X(96, 160) Y(64, 96) +Y(64, 64) +Y(32, 32) + +template void OTExtensionWithMatrix::hash_outputs, Matrix, Z2<128> > > >(int, std::vector, Z2<128> > >, std::allocator, Z2<128> > > > >&, Matrix, Z2<128> > >&); +template void OTExtensionWithMatrix::hash_outputs, Matrix, Z2<64> > > >(int, std::vector, Z2<64> > >, std::allocator, Z2<64> > > > >&, Matrix, Z2<64> > >&); diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index 9a942cf0..266623ee 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -17,8 +17,10 @@ class OTCorrelator : public OTExtension { public: vector senderOutputMatrices; - U receiverOutputMatrix; - U t1, u; + vector matrices; + U& receiverOutputMatrix; + U& t1; + U u; OTCorrelator(int nbaseOTs, int baseLength, int nloops, int nsubloops, @@ -29,19 +31,20 @@ public: OT_ROLE role=BOTH, bool passive=false) : OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, - baseSenderInput, baseReceiverOutput, role, passive) {} + baseSenderInput, baseReceiverOutput, role, passive), + senderOutputMatrices(2), matrices(2), + receiverOutputMatrix(matrices[0]), t1(matrices[1]) {} void resize(int nOTs); template void expand(int start, int slice); void setup_for_correlation(BitVector& baseReceiverInput, - vector& baseSenderOutputs, - BitMatrix& baseReceiverOutput); + vector& baseSenderOutputs, + U& baseReceiverOutput); template void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); template void reduce_squares(unsigned int nTriples, vector& output); - }; class OTExtensionWithMatrix : public OTCorrelator @@ -52,9 +55,9 @@ public: OTExtensionWithMatrix(int nbaseOTs, int baseLength, int nloops, int nsubloops, TwoPartyPlayer* player, - BitVector& baseReceiverInput, - vector< vector >& baseSenderInput, - vector& baseReceiverOutput, + const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput, OT_ROLE role=BOTH, bool passive=false) : OTCorrelator(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, @@ -67,9 +70,12 @@ public: void transfer(int nOTs, const BitVector& receiverInput); template void extend(int nOTs, BitVector& newReceiverInput); + void extend_correlated(int nOTs, BitVector& newReceiverInput); template void expand_transposed(); void transpose(int start, int slice); + template + void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput); void print(BitVector& newReceiverInput, int i = 0); template diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.cpp index b74b19bc..954f19af 100644 --- a/OT/OTMultiplier.cpp +++ b/OT/OTMultiplier.cpp @@ -12,42 +12,59 @@ #include -template -OTMultiplier::OTMultiplier(NPartyTripleGenerator& generator, +//#define OTCORR_TIMER + +template +OTMultiplier::OTMultiplier(NPartyTripleGenerator& generator, int thread_num) : generator(generator), thread_num(thread_num), rot_ext(128, 128, 0, 1, generator.players[thread_num], generator.baseReceiverInput, generator.baseSenderInputs[thread_num], generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check), - auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true), otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true) { - pthread_mutex_init(&mutex, 0); - pthread_cond_init(&ready, 0); - thread = 0; + pthread_mutex_init(&this->mutex, 0); + pthread_cond_init(&this->ready, 0); + this->thread = 0; } template MascotMultiplier::MascotMultiplier(NPartyTripleGenerator& generator, int thread_num) : - OTMultiplier(generator, thread_num) + OTMultiplier(generator, thread_num), + auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true) { c_output.resize(generator.nTriplesPerLoop); } -template -OTMultiplier::~OTMultiplier() +template +Spdz2kMultiplier::Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num) : + OTMultiplier, // bit length used when computing shares + Z2, // bit length of key share + Z2, // bit length used when computing mac shares + Z2kRectangle > // mult-rectangle + (generator, thread_num) { - pthread_mutex_destroy(&mutex); - pthread_cond_destroy(&ready); +#ifdef USE_OPT_VOLE + mac_vole = new OTVole, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); +#else + mac_vole = new OTVoleBase, Z2>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false); +#endif } -template -void OTMultiplier::multiply() +template +OTMultiplier::~OTMultiplier() +{ + pthread_mutex_destroy(&this->mutex); + pthread_cond_destroy(&this->ready); +} + +template +void OTMultiplier::multiply() { keyBits.set(generator.machine.get_mac_key()); - rot_ext.extend(keyBits.size(), keyBits); + rot_ext.extend(keyBits.size(), keyBits); senderOutput.resize(keyBits.size()); for (size_t j = 0; j < keyBits.size(); j++) { @@ -60,19 +77,17 @@ void OTMultiplier::multiply() } rot_ext.receiverOutputMatrix.to(receiverOutput); receiverOutput.resize(keyBits.size()); - auth_ot_ext.init(keyBits, senderOutput, receiverOutput); + init_authenticator(keyBits, senderOutput, receiverOutput); if (generator.machine.generateBits) - multiplyForBits(); + multiplyForBits(); else - multiplyForTriples(); + multiplyForTriples(); } -template -void OTMultiplier::multiplyForTriples() +template +void OTMultiplier::multiplyForTriples() { - auth_ot_ext.resize(generator.nPreampTriplesPerLoop * W::N_COLUMNS); - // dummy input for OT correlator vector _; vector< vector > __; @@ -82,20 +97,24 @@ void OTMultiplier::multiplyForTriples() rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128); - pthread_mutex_lock(&mutex); - pthread_cond_signal(&ready); - pthread_cond_wait(&ready, &mutex); + vector >& baseSenderOutputs = otCorrelator.matrices; + Matrix& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; + + pthread_mutex_lock(&this->mutex); + this->signal_generator(); + this->wait_for_generator(); for (int i = 0; i < generator.nloops; i++) { BitVector aBits = generator.valueBits[0]; //timers["Extension"].start(); - rot_ext.extend(X::N_ROWS * generator.nPreampTriplesPerLoop, aBits); + rot_ext.extend_correlated(X::N_ROWS * generator.nPreampTriplesPerLoop, aBits); + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); //timers["Extension"].stop(); //timers["Correlation"].start(); - otCorrelator.setup_for_correlation(aBits, rot_ext.senderOutputMatrices, - rot_ext.receiverOutputMatrix); + otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, + baseReceiverOutput); otCorrelator.template correlate(0, generator.nPreampTriplesPerLoop, generator.valueBits[1], false, generator.nAmplify); //timers["Correlation"].stop(); @@ -104,16 +123,32 @@ void OTMultiplier::multiplyForTriples() this->after_correlation(); - pthread_cond_signal(&this->ready); - pthread_cond_wait(&this->ready, &this->mutex); + this->signal_generator(); + this->wait_for_generator(); } - pthread_mutex_unlock(&mutex); + pthread_mutex_unlock(&this->mutex); +} + +template +void MascotMultiplier::init_authenticator(const BitVector& keyBits, + const vector< vector >& senderOutput, + const vector& receiverOutput) { + this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput); +} + +template +void Spdz2kMultiplier::init_authenticator(const BitVector& keyBits, + const vector< vector >& senderOutput, + const vector& receiverOutput) { + this->mac_vole->init(keyBits, senderOutput, receiverOutput); } template void MascotMultiplier::after_correlation() { + this->auth_ot_ext.resize(this->generator.nPreampTriplesPerLoop * square128::N_COLUMNS); + this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop, this->c_output); @@ -141,16 +176,46 @@ void Spdz2kMultiplier::after_correlation() { this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop, this->c_output); + + this->signal_generator(); + this->wait_for_generator(); + + this->macs.resize(3); +#ifdef OTCORR_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + for (int j = 0; j < 3; j++) + { + int nValues = this->generator.nTriplesPerLoop; + BitVector* bits; + if (//this->generator.machine.check && + (j % 2 == 0)){ + nValues *= 2; + bits = &(this->generator.valueBits[j]); + } + else { + // piggy-backing mask after the b's + nValues++; + bits = &(this->generator.b_padded_bits); + } + this->mac_vole->evaluate(this->macs[j], nValues, *bits); + } +#ifdef OTCORR_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tCorrelated OT time: " << elapsed/1000000 << endl << flush; +#endif } template<> -void OTMultiplier::multiplyForBits() +void OTMultiplier::multiplyForBits() { - multiplyForTriples(); + multiplyForTriples(); } template<> -void OTMultiplier::multiplyForBits() +void OTMultiplier::multiplyForBits() { int nBits = generator.nTriplesPerLoop + generator.field_size; int nBlocks = ceil(1.0 * nBits / generator.field_size); @@ -162,8 +227,8 @@ void OTMultiplier::multiplyForBits() macs[0].resize(nBits); pthread_mutex_lock(&mutex); - pthread_cond_signal(&ready); - pthread_cond_wait(&ready, &mutex); + signal_generator(); + wait_for_generator(); for (int i = 0; i < generator.nloops; i++) { @@ -178,25 +243,27 @@ void OTMultiplier::multiplyForBits() macs[0][j] = r ^ s; } - pthread_cond_signal(&ready); - pthread_cond_wait(&ready, &mutex); + signal_generator(); + wait_for_generator(); } pthread_mutex_unlock(&mutex); } -template -void OTMultiplier::multiplyForBits() +template +void OTMultiplier::multiplyForBits() { throw runtime_error("bit generation not implemented in this case"); } -template class OTMultiplier; -template class OTMultiplier; +template class OTMultiplier; +template class OTMultiplier; template class MascotMultiplier; template class MascotMultiplier; #define X(K, S) \ template class Spdz2kMultiplier; \ - template class OTMultiplier, Z2, Z2, Z2kRectangle, Z2kRectangle >; + template class OTMultiplier, Z2, Z2, Z2kRectangle >; X(64, 96) +X(64, 64) +X(32, 32) diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 222601e1..1350500d 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -12,6 +12,7 @@ using namespace std; #include "OT/OTExtensionWithMatrix.h" +#include "OT/OTVole.h" #include "OT/Rectangle.h" #include "Tools/random.h" @@ -26,11 +27,22 @@ public: virtual ~OTMultiplierBase() {} virtual void multiply() = 0; + + void signal_generator() { pthread_cond_signal(&ready); } + void wait_for_generator() { pthread_cond_wait(&ready, &mutex); } }; -template -class OTMultiplier : public OTMultiplierBase +template +class OTMultiplierMac : public OTMultiplierBase { +public: + vector< vector > macs; +}; + +template +class OTMultiplier : public OTMultiplierMac +{ +protected: BitVector keyBits; vector< vector > senderOutput; vector receiverOutput; @@ -39,14 +51,15 @@ class OTMultiplier : public OTMultiplierBase void multiplyForBits(); virtual void after_correlation() = 0; + virtual void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput) = 0; public: NPartyTripleGenerator& generator; int thread_num; OTExtensionWithMatrix rot_ext; - OTCorrelator > auth_ot_ext; OTCorrelator > otCorrelator; - vector< vector > macs; OTMultiplier(NPartyTripleGenerator& generator, int thread_num); virtual ~OTMultiplier(); @@ -54,9 +67,13 @@ public: }; template -class MascotMultiplier : public OTMultiplier +class MascotMultiplier : public OTMultiplier { + OTCorrelator > auth_ot_ext; void after_correlation(); + void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput); public: vector c_output; @@ -64,24 +81,25 @@ public: MascotMultiplier(NPartyTripleGenerator& generator, int thread_num); }; +// values, key, mac, mult-rectangle template -class Spdz2kMultiplier: public OTMultiplier, Z2, Z2, - Z2kRectangle, Z2kRectangle > +class Spdz2kMultiplier: public OTMultiplier, Z2, Z2, + Z2kRectangle > { void after_correlation(); + void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput); public: static const int TAU = TAU(K, S); - static const int MAC_BITS = K + S; + static const int PASSIVE_MULT_BITS = K + S; + static const int MAC_BITS = K + 2 * S; - vector > c_output; + vector > c_output; + OTVoleBase, Z2>* mac_vole; - Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num) : - OTMultiplier, Z2, Z2, - Z2kRectangle, - Z2kRectangle >(generator, thread_num) - { - } + Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num); }; #endif /* OT_OTMULTIPLIER_H_ */ diff --git a/OT/OTVole.cpp b/OT/OTVole.cpp new file mode 100644 index 00000000..611f1a31 --- /dev/null +++ b/OT/OTVole.cpp @@ -0,0 +1,358 @@ +// (C) 2018 University of Bristol. See License.txt + +#include "OTVole.h" +#include "Tools/oct.h" +#include "Auth/Subroutines.h" + +//#define OTVOLE_TIMER + +template +void OTVoleBase::evaluate(vector& output, const vector& newReceiverInput) { + const int N1 = newReceiverInput.size() + 1; + output.resize(newReceiverInput.size()); + vector os(2); + + if (this->ot_role & SENDER) { + T extra; + extra.randomize(local_prng); + vector _corr(newReceiverInput); + _corr.push_back(extra); + corr_prime = Row(_corr); + for (int i = 0; i < S; ++i) + { + t0[i] = Row(N1); + t0[i].randomize(this->G_sender[i][0]); + t1[i] = Row(N1); + t1[i].randomize(this->G_sender[i][1]); + Row u = corr_prime + t1[i] + t0[i]; + u.pack(os[0]); + } + } + send_if_ot_sender(this->player, os, this->ot_role); + if (this->ot_role & RECEIVER) { + for (int i = 0; i < S; ++i) + { + t[i] = Row(N1); + t[i].randomize(this->G_receiver[i]); + + int choice_bit = this->baseReceiverInput.get_bit(i); + if (choice_bit == 1) { + u[i].unpack(os[1]); + a[i] = u[i] - t[i]; + } else { + a[i] = t[i]; + u[i].unpack(os[1]); + } + } + } + os[0].reset_write_head(); + os[1].reset_write_head(); + + if (!this->passive_only) { + this->consistency_check(os); + } + + Row res(N1); + if (this->ot_role & RECEIVER) { + for (int i = 0; i < S; ++i) + res += (a[i] << i); + } + if (this->ot_role & SENDER) { + for (int i = 0; i < S; ++i) + res -= (t0[i] << i); + } + for (int i = 0; i < N1 - 1; ++i) + output[i] = res.rows[i]; +} + +template +void OTVoleBase::evaluate(vector& output, int nValues, const BitVector& newReceiverInput) { + if (newReceiverInput.size() != (size_t) nValues * T::N_BITS) + throw invalid_length(); + vector values(nValues); + for (int i = 0; i < nValues; ++i) + values[i] = T(newReceiverInput.get_ptr_to_bit(i, T::N_BITS)); + evaluate(output, values); +} + +template +void OTVoleBase::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) const { + avx_memzero(coefficients, num_blocks); + for (int i = 0; i < num_blocks; ++i) + coefficients[i] = G.get_doubleword(); +} + +template +void OTVoleBase::hash_row(octetStream& os, const Row& row, const __m128i* coefficients) { + octet hash[VOLE_HASH_SIZE] = {0}; + this->hash_row(hash, row, coefficients); + os.append(hash, VOLE_HASH_SIZE); +} + +template +void OTVoleBase::hash_row(octet* hash, const Row& row, const __m128i* coefficients) { + const __m128i* blocks = (const __m128i*) row.get_ptr(); + + int num_blocks = (row.size() * T::size()) / 16; + __m128i prods[2]; + avx_memzero(prods, sizeof(prods)); + __m128i res[2]; + avx_memzero(res, sizeof(res)); + + for (int i = 0; i < num_blocks; ++i) { + mul128(blocks[i], coefficients[i], &prods[0], &prods[1]); + res[0] ^= prods[0]; + res[1] ^= prods[1]; + } + + int total_bytes = row.size() * T::size(); + if (total_bytes % 16 != 0) { + // need to handle "tail" bytes that did not evenly fit into block + int overflow_size = total_bytes % 16; + const octet* overflow_bytes = ((octet*) row.get_ptr()) + total_bytes - overflow_size; + octet bytes[16] = {0}; + memcpy(bytes, overflow_bytes, overflow_size); + const __m128i* extra = (const __m128i*) bytes; + mul128(*extra, coefficients[num_blocks], &prods[0], &prods[1]); + res[0] ^= prods[0]; + res[1] ^= prods[1]; + } + + crypto_generichash(hash, crypto_generichash_BYTES, + (octet*) res, crypto_generichash_BYTES, NULL, 0); +} + +template +void OTVoleBase::consistency_check(vector& os) { + PRNG coef_prng_sender; + PRNG coef_prng_receiver; + + if (this->ot_role & RECEIVER) { + coef_prng_receiver.ReSeed(); + os[0].append(coef_prng_receiver.get_seed(), SEED_SIZE); + } + send_if_ot_receiver(this->player, os, this->ot_role); + if (this->ot_role & SENDER) { + octet seed[SEED_SIZE]; + os[1].consume(seed, SEED_SIZE); + coef_prng_sender.SetSeed(seed); + } + os[0].reset_write_head(); + os[1].reset_write_head(); + + if (this->ot_role & SENDER) { +#ifdef OTVOLE_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + int total_bytes = t0[0].size() * T::size(); + int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); + __m128i coefficients[num_blocks]; + this->set_coeffs(coefficients, coef_prng_sender, num_blocks); + + Row t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size()); + for (int alpha = 0; alpha < S; ++alpha) + { + for (int beta = 0; beta < S; ++beta) + { + t00 = t0[alpha] - t0[beta]; + t01 = t0[alpha] - t1[beta]; + t10 = t1[alpha] - t0[beta]; + t11 = t1[alpha] - t1[beta]; + + this->hash_row(os[0], t00, coefficients); + this->hash_row(os[0], t01, coefficients); + this->hash_row(os[0], t10, coefficients); + this->hash_row(os[0], t11, coefficients); + } + } +#ifdef OTVOLE_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tCheck time sender: " << elapsed/1000000 << endl << flush; +#endif + } + + send_if_ot_sender(this->player, os, this->ot_role); + if (this->ot_role & RECEIVER) { +#ifdef OTVOLE_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + int total_bytes = t[0].size() * T::size(); + int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); + __m128i coefficients[num_blocks]; + this->set_coeffs(coefficients, coef_prng_receiver, num_blocks); + + octet h00[VOLE_HASH_SIZE] = {0}; + octet h01[VOLE_HASH_SIZE] = {0}; + octet h10[VOLE_HASH_SIZE] = {0}; + octet h11[VOLE_HASH_SIZE] = {0}; + vector> hashes(2); + hashes[0] = {h00, h01}; + hashes[1] = {h10, h11}; + + for (int alpha = 0; alpha < S; ++alpha) + { + for (int beta = 0; beta < S; ++beta) + { + os[1].consume(hashes[0][0], VOLE_HASH_SIZE); + os[1].consume(hashes[0][1], VOLE_HASH_SIZE); + os[1].consume(hashes[1][0], VOLE_HASH_SIZE); + os[1].consume(hashes[1][1], VOLE_HASH_SIZE); + + int choice_alpha = this->baseReceiverInput.get_bit(alpha); + int choice_beta = this->baseReceiverInput.get_bit(beta); + + Row tmp = t[alpha] - t[beta]; + octet* choice_hash = hashes[choice_alpha][choice_beta]; + octet diff_t[VOLE_HASH_SIZE] = {0}; + this->hash_row(diff_t, tmp, coefficients); + + octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta]; + octet other_diff[VOLE_HASH_SIZE] = {0}; + tmp = u[alpha] - u[beta] - t[alpha] + t[beta]; + this->hash_row(other_diff, tmp, coefficients); + + if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) { + throw consistency_check_fail(); + } + if (!OCTETS_EQUAL(not_choice_hash, other_diff, VOLE_HASH_SIZE)) { + throw consistency_check_fail(); + } + if (alpha != beta && u[alpha] == u[beta]) { + throw consistency_check_fail(); + } + } + } +#ifdef OTVOLE_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tCheck receiver: " << elapsed/1000000 << endl << flush; +#endif + } +} + +template +void OTVole::consistency_check(vector& os) { + PRNG sender_prg; + PRNG receiver_prg; + + if (this->ot_role & RECEIVER) { + receiver_prg.ReSeed(); + os[0].append(receiver_prg.get_seed(), SEED_SIZE); + } + send_if_ot_receiver(this->player, os, this->ot_role); + if (this->ot_role & SENDER) { + octet seed[SEED_SIZE]; + os[1].consume(seed, SEED_SIZE); + sender_prg.SetSeed(seed); + } + os[0].reset_write_head(); + os[1].reset_write_head(); + + if (this->ot_role & SENDER) { +#ifdef OTVOLE_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + int total_bytes = this->t0[0].size() * T::size(); + int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); + __m128i coefficients[num_blocks]; + this->set_coeffs(coefficients, sender_prg, num_blocks); + + Row t00(this->t0.size()), t01(this->t0.size()), t10(this->t0.size()), t11(this->t0.size()); + for (int alpha = 0; alpha < U::N_BITS; ++alpha) + { + for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i) + { + int beta = sender_prg.get_uint(U::N_BITS); + + t00 = this->t0[alpha] - this->t0[beta]; + t01 = this->t0[alpha] - this->t1[beta]; + t10 = this->t1[alpha] - this->t0[beta]; + t11 = this->t1[alpha] - this->t1[beta]; + + this->hash_row(os[0], t00, coefficients); + this->hash_row(os[0], t01, coefficients); + this->hash_row(os[0], t10, coefficients); + this->hash_row(os[0], t11, coefficients); + } + } +#ifdef OTVOLE_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tCheck time sender: " << elapsed/1000000 << endl << flush; +#endif + } + + send_if_ot_sender(this->player, os, this->ot_role); + if (this->ot_role & RECEIVER) { +#ifdef OTVOLE_TIMER + timeval totalstartv, totalendv; + gettimeofday(&totalstartv, NULL); +#endif + int total_bytes = this->t[0].size() * T::size(); + int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0); + __m128i coefficients[num_blocks]; + this->set_coeffs(coefficients, receiver_prg, num_blocks); + + octet h00[VOLE_HASH_SIZE] = {0}; + octet h01[VOLE_HASH_SIZE] = {0}; + octet h10[VOLE_HASH_SIZE] = {0}; + octet h11[VOLE_HASH_SIZE] = {0}; + vector> hashes(2); + hashes[0] = {h00, h01}; + hashes[1] = {h10, h11}; + + for (int alpha = 0; alpha < U::N_BITS; ++alpha) + { + for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i) + { + int beta = receiver_prg.get_uint(U::N_BITS); + + os[1].consume(hashes[0][0], VOLE_HASH_SIZE); + os[1].consume(hashes[0][1], VOLE_HASH_SIZE); + os[1].consume(hashes[1][0], VOLE_HASH_SIZE); + os[1].consume(hashes[1][1], VOLE_HASH_SIZE); + + int choice_alpha = this->baseReceiverInput.get_bit(alpha); + int choice_beta = this->baseReceiverInput.get_bit(beta); + + Row tmp = this->t[alpha] - this->t[beta]; + octet* choice_hash = hashes[choice_alpha][choice_beta]; + octet diff_t[VOLE_HASH_SIZE] = {0}; + this->hash_row(diff_t, tmp, coefficients); + + octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta]; + octet other_diff[VOLE_HASH_SIZE] = {0}; + tmp = this->u[alpha] - this->u[beta] - this->t[alpha] + this->t[beta]; + this->hash_row(other_diff, tmp, coefficients); + + if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) { + throw consistency_check_fail(); + } + if (!OCTETS_EQUAL(not_choice_hash, other_diff, VOLE_HASH_SIZE)) { + throw consistency_check_fail(); + } + if (alpha != beta && this->u[alpha] == this->u[beta]) { + throw consistency_check_fail(); + } + } + } +#ifdef OTVOLE_TIMER + gettimeofday(&totalendv, NULL); + double elapsed = timeval_diff(&totalstartv, &totalendv); + cout << "\t\tCheck receiver: " << elapsed/1000000 << endl << flush; +#endif + } +} + +template class OTVoleBase, Z2<96>>; +template class OTVoleBase, Z2<64>>; +template class OTVoleBase, Z2<32>>; + +template class OTVole, Z2<96>>; +template class OTVole, Z2<64>>; +template class OTVole, Z2<32>>; + diff --git a/OT/OTVole.h b/OT/OTVole.h new file mode 100644 index 00000000..35b822f4 --- /dev/null +++ b/OT/OTVole.h @@ -0,0 +1,86 @@ +// (C) 2018 University of Bristol. See License.txt + +#ifndef _OTVOLE +#define _OTVOLE + +#include "Math/Z2k.h" +#include "OTExtension.h" +#include "Row.h" + +using namespace std; + +template +class OTVoleBase : public OTExtension +{ +public: + static const int S = U::N_BITS; + + OTVoleBase(int nbaseOTs, int baseLength, + int nloops, int nsubloops, + TwoPartyPlayer* player, + const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput, + OT_ROLE role=BOTH, + bool passive=false) + : OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, + baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive), + corr_prime(), + t0(U::N_BITS), + t1(U::N_BITS), + u(U::N_BITS), + t(U::N_BITS), + a(U::N_BITS) { + // need to flip roles for OT extension init, reset to original role here + this->ot_role = role; + local_prng.ReSeed(); + } + + void evaluate(vector& output, const vector& newReceiverInput); + + void evaluate(vector& output, int nValues, const BitVector& newReceiverInput); + + protected: + + // Sender fields + Row corr_prime; + vector> t0, t1; + // Receiver fields + vector> u, t, a; + // Both + PRNG local_prng; + + virtual void consistency_check (vector& os); + + void set_coeffs(__m128i* coefficients, PRNG& G, int num_elements) const; + + void hash_row(octetStream& os, const Row& row, const __m128i* coefficients); + + void hash_row(octet* hash, const Row& row, const __m128i* coefficients); + +}; + +template +class OTVole : public OTVoleBase +{ + +public: + OTVole(int nbaseOTs, int baseLength, + int nloops, int nsubloops, + TwoPartyPlayer* player, + const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput, + OT_ROLE role=BOTH, + bool passive=false) + : OTVoleBase(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput, + baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive) { + } + +protected: + + void consistency_check(vector& os); + +}; + +#endif diff --git a/OT/Rectangle.cpp b/OT/Rectangle.cpp index 91f546ff..003b0c86 100644 --- a/OT/Rectangle.cpp +++ b/OT/Rectangle.cpp @@ -102,8 +102,11 @@ void Rectangle::unpack(octetStream& o) #define X(N,L) \ template class Rectangle, Z2 > ; \ template void Rectangle, Z2 >::to(Z2& result); \ + template void Rectangle, Z2 >::to(Z2& result); \ template void Rectangle, Z2 >::to(Z2& result); \ + template void Rectangle, Z2 >::to(Z2& result); \ Y(64, 96) -X(96, 160) +Y(64, 64) +Y(32, 32) diff --git a/OT/Rectangle.h b/OT/Rectangle.h index 4c1a1c8d..e9a803b2 100644 --- a/OT/Rectangle.h +++ b/OT/Rectangle.h @@ -11,18 +11,22 @@ #include "Math/Z2k.h" #define TAU(K, S) 2 * K + 4 * S -#define Y(K, S) X(TAU(K, S), K + S) X(K + S, K + S) +#define Y(K, S) X(TAU(K, S), K + S) X(S, K + 2 * S) template class Rectangle { public: + typedef V RowType; + static const int N_ROWS = U::N_BITS; static const int N_COLUMNS = V::N_BITS; static const int N_ROW_BYTES = V::N_BYTES; V rows[N_ROWS]; + static size_t size() { return N_ROWS * RowType::size(); } + bool operator==(const Rectangle& other) const; bool operator!=(const Rectangle& other) const { return not (*this == other); } @@ -32,6 +36,8 @@ public: Rectangle& operator+=(const Rectangle& other); Rectangle operator-(const Rectangle & other); + template + Rectangle& sub(Rectangle& other) { return other.rsub_(*this); } template Rectangle& rsub(Rectangle& other) { return rsub_(other); } Rectangle& rsub_(Rectangle& other); diff --git a/OT/Row.cpp b/OT/Row.cpp new file mode 100644 index 00000000..04f22df2 --- /dev/null +++ b/OT/Row.cpp @@ -0,0 +1,113 @@ +#include "OT/Row.h" +#include "Exceptions/Exceptions.h" + +template +bool Row::operator ==(const Row& other) const +{ + return rows == other.rows; +} + +template +Row& Row::operator +=(const Row& other) +{ + if (rows.size() != other.size()) { + throw invalid_length(); + } + for (size_t i = 0; i < this->size(); i++) + rows[i] += other.rows[i]; + return *this; +} + +template +Row& Row::operator -=(const Row& other) +{ + if (rows.size() != other.size()) { + throw invalid_length(); + } + for (size_t i = 0; i < this->size(); i++) + rows[i] -= other.rows[i]; + return *this; +} + +template +Row& Row::operator *=(const T& other) +{ + for (size_t i = 0; i < this->size(); i++) + rows[i] = rows[i] * other; + return *this; +} + +template +Row Row::operator *(const T& other) +{ + Row res = *this; + res *= other; + return res; +} + +template +Row Row::operator +(const Row& other) +{ + Row res = other; + res += *this; + return res; +} + +template +Row Row::operator -(const Row& other) +{ + Row res = *this; + res-=other; + return res; +} + +template +void Row::randomize(PRNG& G) +{ + for (size_t i = 0; i < this->size(); i++) + rows[i].randomize(G); +} + +template +Row Row::operator<<(int i) const { + if (i >= T::size() * 8) { + throw invalid_params(); + } + Row res = *this; + for (size_t j = 0; j < this->size(); j++) + res.rows[j] = res.rows[j] << i; + return res; +} + +template +void Row::pack(octetStream& o) const +{ + o.store(this->size()); + for (size_t i = 0; i < this->size(); i++) + rows[i].pack(o); +} + +template +void Row::unpack(octetStream& o) +{ + size_t size; + o.get(size); + this->rows.resize(size); + for (size_t i = 0; i < this->size(); i++) + rows[i].unpack(o); +} + +template +ostream& operator<<(ostream& o, const Row& x) +{ + for (size_t i = 0; i < x.size(); ++i) + o << x.rows[i] << " | "; + return o; +} + +template class Row>; +template ostream& operator<<(ostream& o, const Row>& x); +template class Row>; +template ostream& operator<<(ostream& o, const Row>& x); +template class Row>; +template ostream& operator<<(ostream& o, const Row>& x); diff --git a/OT/Row.h b/OT/Row.h new file mode 100644 index 00000000..f2961e6b --- /dev/null +++ b/OT/Row.h @@ -0,0 +1,52 @@ +#ifndef OT_ROW_H_ +#define OT_ROW_H_ + +#include "Math/Z2k.h" +#include "Math/gf2nlong.h" +#define VOLE_HASH_SIZE crypto_generichash_BYTES + +template +class Row +{ +public: + + vector rows; + + Row(int size) : rows(size) {} + + Row() : rows() {} + + Row(const vector& _rows) : rows(_rows) {} + + bool operator==(const Row& other) const; + bool operator!=(const Row& other) const { return not (*this == other); } + + Row& operator+=(const Row& other); + Row& operator-=(const Row& other); + + Row& operator*=(const T& other); + + Row operator*(const T& other); + Row operator+(const Row & other); + Row operator-(const Row & other); + + Row operator<<(int i) const; + + // fine, since elements in vector are allocated contiguously + const void* get_ptr() const { return rows[0].get_ptr(); } + + void randomize(PRNG& G); + + void pack(octetStream& o) const; + void unpack(octetStream& o); + + size_t size() const { return rows.size(); } + + template + friend ostream& operator<<(ostream& o, const Row& x); +}; + +template +using Z2kRow = Row>; + +#endif /* OT_ROW_H_ */ diff --git a/OT/TripleMachine.cpp b/OT/TripleMachine.cpp index 54bd2967..3a328bbc 100644 --- a/OT/TripleMachine.cpp +++ b/OT/TripleMachine.cpp @@ -112,6 +112,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) : primeField = opt.get("-P")->isSet; bonding = opt.get("-b")->isSet; z2k = opt.get("-Z")->isSet; + check |= z2k; bigint p; if (output) @@ -134,7 +135,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) : G.ReSeed(); mac_key2.randomize(G); mac_keyp.randomize(G); - mac_key296.randomize(G); + mac_keyz.randomize(G); } void TripleMachine::run() @@ -168,8 +169,8 @@ void TripleMachine::run() generators[i]->lock(); if (z2k) { - pthread_create(&threads[i], 0, run_ngenerator_thread >, - generators[i]); + pthread_create(&threads[i], 0, run_ngenerator_thread >, + generators[i]); continue; } if (primeField) @@ -222,8 +223,9 @@ void TripleMachine::run() void TripleMachine::output_mac_keys() { - if (z2k) - write_mac_key(prep_data_dir, my_num, mac_key296); + if (z2k) { + write_mac_key(prep_data_dir, my_num, mac_keyz); + } else write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2); } @@ -240,5 +242,15 @@ template<> gfp TripleMachine::get_mac_key() template<> Z2<96> TripleMachine::get_mac_key() { - return mac_key296; + return mac_keyz; +} + +template<> Z2<64> TripleMachine::get_mac_key() +{ + return mac_keyz; +} + +template<> Z2<32> TripleMachine::get_mac_key() +{ + return mac_keyz; } diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index c67322be..e2508f29 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -17,7 +17,7 @@ class TripleMachine : public OfflineMachineBase { gf2n mac_key2; gfp mac_keyp; - Z2<96> mac_key296; + Z2 mac_keyz; public: int nloops; diff --git a/Processor/Buffer.cpp b/Processor/Buffer.cpp index 1c05d8b6..d057f263 100644 --- a/Processor/Buffer.cpp +++ b/Processor/Buffer.cpp @@ -193,10 +193,10 @@ void BufferHelper::purge() template class Buffer< Share, Share >; template class Buffer< Share, Share >; -template class Buffer< Share >, Share > >; +template class Buffer< Share >, Share > >; template class Buffer< InputTuple, RefInputTuple >; template class Buffer< InputTuple, RefInputTuple >; -template class Buffer< InputTuple >, RefInputTuple > >; +template class Buffer< InputTuple >, RefInputTuple > >; template class Buffer< gfp, gfp >; template class Buffer< gf2n, gf2n >; diff --git a/Processor/Buffer.h b/Processor/Buffer.h index b1bd2237..a1e21620 100644 --- a/Processor/Buffer.h +++ b/Processor/Buffer.h @@ -69,13 +69,13 @@ class BufferHelper public: Buffer< U, V > bufferp; Buffer< U, V > buffer2; - Buffer< U >, V > > bufferz2k; + Buffer< U >, V > > bufferz2k; ifstream* files[N_DATA_FIELD_TYPE]; BufferHelper() { memset(files, 0, sizeof(files)); } void input(V& a) { bufferp.input(a); } void input(V& a) { buffer2.input(a); } - void input(V >& a) { bufferz2k.input(a); } + void input(V >& a) { bufferz2k.input(a); } BufferBase& get_buffer(DataFieldType field_type); void setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type = 0); void close(); diff --git a/Processor/Data_Files.cpp b/Processor/Data_Files.cpp index 41bb7859..21c2665b 100644 --- a/Processor/Data_Files.cpp +++ b/Processor/Data_Files.cpp @@ -6,8 +6,8 @@ #include -const char* Data_Files::field_names[] = { "p", "2", "Z2^64" }; -const char* Data_Files::long_field_names[] = { "gfp", "gf2n", "Z2^64" }; +const char* Data_Files::field_names[] = { "p", "2", "Z2^160" }; +const char* Data_Files::long_field_names[] = { "gfp", "gf2n", "Z2^160" }; const bool Data_Files::implemented[N_DATA_FIELD_TYPE][N_DTYPE] = { { true, true, true, true, false, false }, { true, true, true, true, true, true }, diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp index 22742da9..68016dab 100644 --- a/Processor/Machine.cpp +++ b/Processor/Machine.cpp @@ -31,7 +31,7 @@ Machine::Machine(int my_number, Names& playerNames, prep_dir_prefix = get_prep_dir(N.num_players(), lgp, lg2); read_setup(prep_dir_prefix); - char filename[1024]; + char filename[2048]; int nn; sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number); diff --git a/Tools/Config.cpp b/Tools/Config.cpp index f773c4c4..4990d4b7 100644 --- a/Tools/Config.cpp +++ b/Tools/Config.cpp @@ -122,7 +122,7 @@ namespace Config { pubkeys[i].resize(crypto_sign_PUBLICKEYBYTES); infile.read((char*)&pubkeys[i][0],pubkeys[i].size()); } - } catch (ConfigError e) { + } catch (ConfigError& e) { pubkeys.resize(0); } diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index 42160805..c89f83d7 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -10,20 +10,27 @@ #include "Math/gf2n.h" #include "Math/gfp.h" #include "Math/bigint.h" +#include "Math/Z2k.h" #include void MMO::zeroIV() { - octet key[AES_BLK_SIZE]; - memset(key,0,AES_BLK_SIZE*sizeof(octet)); - setIV(key); + if (N_KEYS > (1 << 8)) + throw not_implemented(); + for (int i = 0; i < N_KEYS; i++) + { + octet key[AES_BLK_SIZE]; + memset(key, 0, AES_BLK_SIZE * sizeof(octet)); + key[i] = i; + setIV(i, key); + } } -void MMO::setIV(octet key[AES_BLK_SIZE]) +void MMO::setIV(int i, octet key[AES_BLK_SIZE]) { - aes_schedule(IV,key); + aes_schedule(IV[i],key); } @@ -50,26 +57,91 @@ void MMO::encrypt_and_xor(void* output, const void* input, const octet* key, _mm_storeu_si128(((__m128i*)output) + indices[i], out[i]); } -template <> -void MMO::hashOneBlock(octet* output, octet* input) +template +void MMO::hashBlocks(void* output, const void* input) { - encrypt_and_xor<1>(output, input, IV); + int n_blocks = DIV_CEIL(T::size(), 16); + if (n_blocks > N_KEYS) + throw runtime_error("not enough MMO keys"); + __m128i tmp[N]; + int block_size = sizeof(tmp[0]); + for (int i = 0; i < n_blocks; i++) + { + encrypt_and_xor(tmp, input, IV[i]); + for (int j = 0; j < N; j++) + memcpy((char*)output + j * sizeof(T) + i * block_size, &tmp[j], + min(T::size() - i * block_size, block_size)); + } } - template <> -void MMO::hashOneBlock(octet* output, octet* input) +void MMO::hashBlocks(void* output, const void* input) { - encrypt_and_xor<1>(output, input, IV); + if (gfp::get_ZpD().get_t() != 2) + throw not_implemented(); + encrypt_and_xor<1>(output, input, IV[0]); while (mpn_cmp((mp_limb_t*)output, gfp::get_ZpD().get_prA(), gfp::t()) >= 0) - encrypt_and_xor<1>(output, output, IV); + encrypt_and_xor<1>(output, output, IV[0]); } template <> void MMO::hashBlockWise(octet* output, octet* input) { for (int i = 0; i < 16; i++) - encrypt_and_xor<8>(&((__m128i*)output)[i*8], &((__m128i*)input)[i*8], IV); + encrypt_and_xor<8>(&((__m128i*)output)[i*8], &((__m128i*)input)[i*8], IV[0]); +} + +template <> +void MMO::hashBlocks(void* output, const void* input) +{ + if (gfp::get_ZpD().get_t() != 2) + throw not_implemented(); + __m128i* in = (__m128i*)input; + __m128i* out = (__m128i*)output; + encrypt_and_xor<8>(out, in, IV[0]); + int left = 8; + int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + while (left) + { + int now_left = 0; + for (int j = 0; j < left; j++) + if (mpn_cmp((mp_limb_t*)&out[indices[j]], gfp::get_ZpD().get_prA(), gfp::t()) >= 0) + { + indices[now_left] = indices[j]; + now_left++; + } + left = now_left; + + // and now my favorite hack + switch (left) { + case 8: + ecb_aes_128_encrypt<8>(out, out, IV[0], indices); + break; + case 7: + ecb_aes_128_encrypt<7>(out, out, IV[0], indices); + break; + case 6: + ecb_aes_128_encrypt<6>(out, out, IV[0], indices); + break; + case 5: + ecb_aes_128_encrypt<5>(out, out, IV[0], indices); + break; + case 4: + ecb_aes_128_encrypt<4>(out, out, IV[0], indices); + break; + case 3: + ecb_aes_128_encrypt<3>(out, out, IV[0], indices); + break; + case 2: + ecb_aes_128_encrypt<2>(out, out, IV[0], indices); + break; + case 1: + ecb_aes_128_encrypt<1>(out, out, IV[0], indices); + break; + default: + break; + } + } } template <> @@ -79,49 +151,11 @@ void MMO::hashBlockWise(octet* output, octet* input) { __m128i* in = &((__m128i*)input)[i*8]; __m128i* out = &((__m128i*)output)[i*8]; - encrypt_and_xor<8>(out, in, IV); - int left = 8; - int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - while (left) - { - int now_left = 0; - for (int j = 0; j < left; j++) - if (mpn_cmp((mp_limb_t*)&out[indices[j]], gfp::get_ZpD().get_prA(), gfp::t()) >= 0) - { - indices[now_left] = indices[j]; - now_left++; - } - left = now_left; - - // and now my favorite hack - switch (left) { - case 8: - ecb_aes_128_encrypt<8>(out, out, IV, indices); - break; - case 7: - ecb_aes_128_encrypt<7>(out, out, IV, indices); - break; - case 6: - ecb_aes_128_encrypt<6>(out, out, IV, indices); - break; - case 5: - ecb_aes_128_encrypt<5>(out, out, IV, indices); - break; - case 4: - ecb_aes_128_encrypt<4>(out, out, IV, indices); - break; - case 3: - ecb_aes_128_encrypt<3>(out, out, IV, indices); - break; - case 2: - ecb_aes_128_encrypt<2>(out, out, IV, indices); - break; - case 1: - ecb_aes_128_encrypt<1>(out, out, IV, indices); - break; - default: - break; - } - } + hashBlocks(out, in); } } + +#define ZZ(F,N) \ + template void MMO::hashBlocks(void*, const void*); +#define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8) +Z(gf2n) Z(Z2<64>) Z(Z2<128>) Z(Z2<160>) Z(Z2<256>) diff --git a/Tools/MMO.h b/Tools/MMO.h index dd331902..aba6ba93 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -13,7 +13,8 @@ // Matyas-Meyer-Oseas hashing class MMO { - octet IV[176] __attribute__((aligned (16))); + static const int N_KEYS = 2; + octet IV[N_KEYS][176] __attribute__((aligned (16))); template static void encrypt_and_xor(void* output, const void* input, @@ -25,9 +26,11 @@ class MMO public: MMO() { zeroIV(); } void zeroIV(); - void setIV(octet key[AES_BLK_SIZE]); + void setIV(int i, octet key[AES_BLK_SIZE]); template - void hashOneBlock(octet* output, octet* input); + void hashOneBlock(void* output, const void* input) { hashBlocks((T*)output, input); } + template + void hashBlocks(void* output, const void* input); template void hashBlockWise(octet* output, octet* input); template diff --git a/Tools/int.h b/Tools/int.h index 00d43d5c..8114819b 100644 --- a/Tools/int.h +++ b/Tools/int.h @@ -18,7 +18,6 @@ typedef unsigned char octet; typedef unsigned long word; #endif - inline int CEIL_LOG2(int x) { int result = 0; diff --git a/Tools/oct.h b/Tools/oct.h new file mode 100644 index 00000000..e0fdec0b --- /dev/null +++ b/Tools/oct.h @@ -0,0 +1,19 @@ +#ifndef TOOLS_OCT_H_ +#define TOOLS_OCT_H_ + +typedef unsigned char octet; + +inline void PRINT_OCTET(const octet* bytes, size_t size) { + for (size_t i = 0; i < size; ++i) + cout << hex << (int) bytes[i]; + cout << flush << endl; +} + +inline bool OCTETS_EQUAL(const octet* left, const octet* right, int size) { + for (int i = 0; i < size; ++i) + if (left[i] != right[i]) + return false; + return true; +} + +#endif diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index 2842dbcb..adaac897 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -36,16 +36,6 @@ void octetStream::assign(const octetStream& os) } -void octetStream::swap(octetStream& os) -{ - const size_t size = sizeof(octetStream); - char tmp[size]; - memcpy(tmp, this, size); - memcpy(this, &os, size); - memcpy(&os, tmp, size); -} - - octetStream::octetStream(size_t maxlen) { mxlen=maxlen; len=0; ptr=0; diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 65b904d6..69ec105a 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -42,10 +42,10 @@ class octetStream void resize(size_t l); void resize_precise(size_t l); + void reserve(size_t l); void clear(); void assign(const octetStream& os); - void swap(octetStream& os); octetStream() : len(0), mxlen(0), ptr(0), data(0) {} octetStream(size_t maxlen); @@ -177,6 +177,11 @@ inline void octetStream::resize_precise(size_t l) mxlen=l; } +inline void octetStream::reserve(size_t l) +{ + if (len + l > mxlen) + resize_precise(len + l); +} inline void octetStream::append(const octet* x, const size_t l) { diff --git a/Tools/random.cpp b/Tools/random.cpp index 586b2330..f67d6574 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -3,6 +3,7 @@ #include "Tools/random.h" #include "Math/bigint.h" +#include "Auth/Subroutines.h" #include #include @@ -38,6 +39,12 @@ void PRNG::SetSeed(PRNG& G) SetSeed(tmp); } +void PRNG::SecureSeed(Player& player) +{ + Create_Random_Seed(seed, player, SEED_SIZE); + InitSeed(); +} + void PRNG::InitSeed() { #ifdef USE_AES @@ -136,7 +143,25 @@ unsigned int PRNG::get_uint() return ans; } - +unsigned int PRNG::get_uint(int upper) +{ + // adopting Java 7 implementation of bounded nextInt here + if (upper <= 0) + throw invalid_argument("Must be positive"); + // power of 2 case + if ((upper & (upper - 1)) == 0) { + unsigned int r = (upper < 255) ? get_uchar() : get_uint(); + // zero out higher order bits + return r % upper; + } + // not power of 2 + int r, reduced; + do { + r = (upper < 255) ? get_uchar() : get_uint(); + reduced = r % upper; + } while (r - reduced + (upper - 1) < 0); + return reduced; +} unsigned char PRNG::get_uchar() { @@ -147,7 +172,6 @@ unsigned char PRNG::get_uchar() return ans; } - __m128i PRNG::get_doubleword() { if (cnt > RAND_SIZE - 16) diff --git a/Tools/random.h b/Tools/random.h index 359e2334..2cbb74cf 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -6,6 +6,7 @@ #include "Tools/octetStream.h" #include "Tools/sha1.h" #include "Tools/aes.h" +#include "Tools/avx_memcpy.h" #define USE_AES @@ -19,6 +20,7 @@ #define RAND_SIZE (PIPELINES * AES_BLK_SIZE) #endif +class Player; /* This basically defines a randomness expander, if using * as a real PRG on an input stream you should first collapse @@ -62,12 +64,14 @@ class PRNG // Set seed from array void SetSeed(const unsigned char*); void SetSeed(PRNG& G); + void SecureSeed(Player& player); void InitSeed(); double get_double(); bool get_bit() { return get_uchar() & 1; } unsigned char get_uchar(); unsigned int get_uint(); + unsigned int get_uint(int upper); void get_bigint(bigint& res, int n_bits, bool positive = true); void get(bigint& res, int n_bits, bool positive = true); void get(int& res, int n_bits, bool positive = true); @@ -82,9 +86,34 @@ class PRNG __m128i get_doubleword(); void get_octetStream(octetStream& ans,int len); void get_octets(octet* ans, int len); + template + void get_octets_(octet* ans); + + template + T get(); const octet* get_seed() const { return seed; } }; +template +T PRNG::get() +{ + T res; + res.randomize(*this); + return res; +} + +template +inline void PRNG::get_octets_(octet* ans) +{ + if (L < RAND_SIZE - cnt) + { + avx_memcpy(ans, random + cnt, L); + cnt += L; + } + else + get_octets(ans, L); +} + #endif diff --git a/check-passive.cpp b/check-passive.cpp index 0e31b90a..90ea056d 100644 --- a/check-passive.cpp +++ b/check-passive.cpp @@ -70,6 +70,75 @@ void check_triples(int n_players, string type_char = "") delete[] inputFiles; } +template +void check_triples_Z2k(int n_players, string type_char = "", bool macs = true) +{ + ifstream* inputFiles = new ifstream[n_players]; + for (int i = 0; i < n_players; i++) + { + stringstream ss; + ss << get_prep_dir(n_players, 128, 128) << "Triples-"; + if (type_char.size()) + ss << type_char; + else + ss << T::type_char(); + ss << "-P" << i; + inputFiles[i].open(ss.str().c_str()); + cout << "Opening file " << ss.str() << endl; + } + + int j = 0; + while (inputFiles[0].peek() != EOF) + { + U a,b,c,cc,tmp,prod; + T dummy; + vector as(n_players), bs(n_players), cs(n_players); + for (int i = 0; i < n_players; i++) + { + as[i].input(inputFiles[i], false); + if (macs) + dummy.input(inputFiles[i], false); + bs[i].input(inputFiles[i], false); + if (macs) + dummy.input(inputFiles[i], false); + cs[i].input(inputFiles[i], false); + if (macs) + dummy.input(inputFiles[i], false); + } + + a = accumulate(as.begin(), as.end(), U()); + b = accumulate(bs.begin(), bs.end(), U()); + c = accumulate(cs.begin(), cs.end(), U()); + + prod = a * b; + if (prod != c) + { + cout << T::type_string() << ": Error in " << j << endl; + cout << "a " << a << " " << as[0] << " " << as[1] << endl; + cout << "b " << b << " " << bs[0] << " " << bs[1] << endl; + cout << "c " << c << " " << cs[0] << " " << cs[1] << endl; + for (int i = 0; i < 2; i++) + for (int j = 0; j < 2; j++) + { + tmp = as[i] * bs[j]; + cc += tmp; + cout << "a" << i << " * b" << j << " " << tmp << endl; + } + cout << "c " << c << endl; + cout << "cc " << cc << endl; + cout << "a*b " << prod << endl; + cout << "DID YOU INDICATE THE CORRECT NUMBER OF PLAYERS?" << endl; + + return; + } + + j++; + } + + cout << dec << j << " correct triples of type " << T::type_string() << endl; + delete[] inputFiles; +} + int main(int argc, char** argv) { int n_players = 2; @@ -79,5 +148,5 @@ int main(int argc, char** argv) gfp::init_field(gfp::pr(), false); check_triples(n_players); check_triples(n_players); - check_triples >(n_players, "Z2^64"); + check_triples_Z2k, Z2<64>>(n_players); }