#include "NPartyTripleGenerator.h"

#include "OT/OTExtensionWithMatrix.h"
#include "OT/OTMultiplier.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include "Math/operators.h"
#include "Auth/Subroutines.h"
#include "Auth/MAC_Check.h"

// NOTE(review): the targets of the three system includes below are missing in
// this copy (the <...> contents appear to have been stripped). This file uses
// stringstream / ofstream (generate) and ceil (generateBits), so presumably
// <sstream>, <fstream> and <math.h> -- restore from upstream.
#include
#include
#include

/*
 * A group of N entries (a[i], b, c[i]) that all share the same b.
 * Components are addressed by an index l: 0 = a, 1 = b, 2 = c.
 *
 * NOTE(review): template parameter lists and many template arguments in this
 * file appear to have been stripped (e.g. this class uses T and N, and
 * amplify() uses M, but none of them is declared). Restore from upstream
 * before building.
 */
template class Triple
{
public:
    T a[N];
    T b;
    T c[N];

    // Number of entries in component l: N for a (l = 0) and c (l = 2),
    // 1 for b (l = 1). Throws bad_value for any other l.
    int repeat(int l)
    {
        switch (l)
        {
        case 0:
        case 2:
            return N;
        case 1:
            return 1;
        default:
            throw bad_value();
        }
    }

    // Reference to entry j of component l (0 = a, 1 = b, 2 = c); j is ignored
    // for b and is not bounds-checked. Throws bad_value for any other l.
    T& byIndex(int l, int j)
    {
        switch (l)
        {
        case 0:
            return a[j];
        case 1:
            return b;
        case 2:
            return c[j];
        default:
            throw bad_value();
        }
    }

    // Combines the M entries of uncheckedTriple into the N entries of this
    // triple. b is copied, and each a[i] / c[i] accumulates a random linear
    // combination of the source a's / c's, with coefficients drawn from G.
    // The same coefficient r multiplies a[j] and c[j]. So if every source entry
    // satisfies c[j] = a[j] * b, every result entry satisfies c[i] = a[i] * b.
    // Accumulates with +=, so this relies on a and c starting at zero
    // (presumably the default-constructed value of T -- confirm).
    template void amplify(const Triple& uncheckedTriple, PRNG& G)
    {
        b = uncheckedTriple.b;
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++)
            {
                typename T::value_type r;
                r.randomize(G);
                a[i] += r * uncheckedTriple.a[j];
                c[i] += r * uncheckedTriple.c[j];
            }
    }

    // Writes the first n entries as consecutive (a[i], b, c[i]) records; b is
    // repeated in every record. `human` is forwarded to each element's output.
    void output(ostream& outputStream, int n = N, bool human = false)
    {
        for (int i = 0; i < n; i++)
        {
            a[i].output(outputStream, human);
            b.output(outputStream, human);
            c[i].output(outputStream, human);
        }
    }
};

/*
 * Triple of plain (not yet authenticated) values that can be written back
 * into the generator's bit vectors.
 */
template class PlainTriple : public Triple
{
public:
    // Stores a[j] and c[j] of this triple (triple number i within the loop) as
    // 128-bit blocks at position i * N + j of valueBits[0] and valueBits[2].
    // this assumes that valueBits[1] is still set to the bits of b
    void to(vector& valueBits, int i)
    {
        for (int j = 0; j < N; j++)
        {
            valueBits[0].set_int128(i * N + j, this->a[j].to_m128i());
            valueBits[2].set_int128(i * N + j, this->c[j].to_m128i());
        }
    }
};

/*
 * Triple of authenticated shares (share value + MAC), built from a
 * PlainTriple and the MAC outputs of the OT threads.
 */
template class ShareTriple : public Triple, N>
{
public:
    // Authenticates every entry of `triple`. For component l and entry j:
    //   share = plain value,
    //   MAC   = value * mac_key
    //           + sum over the other parties i of
    //             ot_multipliers[i]->macs[l][iTriple * repeat(l) + j],
    // where iTriple is the index of this triple within the current loop.
    void from(PlainTriple& triple, vector*>& ot_multipliers,
            int iTriple, const NPartyTripleGenerator& generator)
    {
        for (int l = 0; l < 3; l++)
        {
            int repeat = this->repeat(l);
            for (int j = 0; j < repeat; j++)
            {
                T value = triple.byIndex(l,j);
                T mac = value * generator.machine.get_mac_key();
                for (int i = 0; i < generator.nparties-1; i++)
                    mac += ot_multipliers[i]->macs[l][iTriple * repeat + j];
                Share& share = this->byIndex(l,j);
                share.set_share(value);
                share.set_mac(mac);
            }
        }
    }

    // Local MAC share of c[0] - maskedA * b, where maskedA is a public
    // (opened) value. generateTriples passes this to MAC_Check::AddToCheck
    // with value 0, i.e. it checks that c[0] = maskedA * b.
    T computeCheckMAC(const T& maskedA)
    {
        return this->c[0].get_mac() - maskedA * this->b.get_mac();
    }
};

/*
 * Copies the relevant base OTs from setup.
 * N.B. setup must not be stored, as it will be used by other threads.
 *
 * _nTriples is rounded up so that every one of the nloops loops produces the
 * same number (nTriplesPerLoop) of outputs; nTriples holds the rounded total.
 * One TwoPartyPlayer is created per other party, with an id that is unique
 * per (thread, unordered pair of parties).
 */
NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup,
        const Names& names, int thread_num, int _nTriples, int nloops,
        TripleMachine& machine) :
        // NOTE(review): the non-positive offset presumably keeps this
        // player's ids apart from the positive pairwise ids below -- confirm.
        globalPlayer(names, - thread_num * machine.nplayers * machine.nplayers),
        thread_num(thread_num),
        my_num(setup.get_my_num()),
        nloops(nloops),
        nparties(setup.get_nparties()),
        machine(machine)
{
    nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
    nTriples = nTriplesPerLoop * nloops;
    field_size = 128;
    // With amplification, N_AMPLIFY preliminary triples are consumed per
    // output triple.
    nAmplify = machine.amplify ? N_AMPLIFY : 1;
    nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;

    int n = nparties;
    //baseReceiverInput = machines[0]->baseReceiverInput;
    //baseSenderInputs.resize(n-1);
    //baseReceiverOutputs.resize(n-1);
    nbase = setup.get_nbase();
    baseReceiverInput.resize(nbase);
    baseReceiverOutputs.resize(n - 1);
    baseSenderInputs.resize(n - 1);
    players.resize(n-1);

    gf2n_long::init_field(128);

    for (int i = 0; i < n-1; i++)
    {
        // i for indexing, other_player is actual number
        // (party numbers skip my_num).
        int other_player, id;
        if (i >= my_num)
            other_player = i + 1;
        else
            other_player = i;

        // copy base OT inputs + outputs
        // NOTE(review): this bit copy does not depend on i, so it is repeated
        // for every peer. It also uses the constant 128 instead of nbase.
        // Presumably nbase == 128 -- confirm.
        for (int j = 0; j < 128; j++)
        {
            baseReceiverInput.set_bit(j, (unsigned int)setup.get_base_receiver_input(j));
        }
        baseReceiverOutputs[i] = setup.baseOTs[i]->receiver_outputs;
        baseSenderInputs[i] = setup.baseOTs[i]->sender_inputs;

        // new TwoPartyPlayer with unique id for each thread + pair of players
        // (the smaller party number always comes first, so both ends agree).
        if (my_num < other_player)
            id = (thread_num+1)*n*n + my_num*n + other_player;
        else
            id = (thread_num+1)*n*n + other_player*n + my_num;
        players[i] = new TwoPartyPlayer(names, other_player, id);

        cout << "Set up with player " << other_player << " in thread " << thread_num << " with id " << id << endl;
    }

    pthread_mutex_init(&mutex, 0);
    pthread_cond_init(&ready, 0);
}

// Frees the pairwise players created by the constructor and destroys this
// generator's mutex and condition variable.
NPartyTripleGenerator::~NPartyTripleGenerator()
{
    for (size_t i = 0; i < players.size(); i++)
        delete players[i];
    //delete nplayer;
    pthread_mutex_destroy(&mutex);
    pthread_cond_destroy(&ready);
}
template void* run_ot_thread(void* ptr) { ((OTMultiplier*)ptr)->multiply(); return NULL; } template void NPartyTripleGenerator::generate() { vector< OTMultiplier* > ot_multipliers(nparties-1); timers["Generator thread"].start(); for (int i = 0; i < nparties-1; i++) { ot_multipliers[i] = new OTMultiplier(*this, i); pthread_mutex_lock(&ot_multipliers[i]->mutex); pthread_create(&(ot_multipliers[i]->thread), 0, run_ot_thread, ot_multipliers[i]); } // add up the shares from each thread and write to file stringstream ss; ss << machine.prep_data_dir; if (machine.generateBits) ss << "Bits-"; else ss << "Triples-"; ss << T::type_char() << "-P" << my_num; if (thread_num != 0) ss << "-" << thread_num; ofstream outputFile; if (machine.output) outputFile.open(ss.str().c_str()); if (machine.generateBits) generateBits(ot_multipliers, outputFile); else generateTriples(ot_multipliers, outputFile); timers["Generator thread"].stop(); if (machine.output) cout << "Written " << nTriples << " outputs to " << ss.str() << endl; else cout << "Generated " << nTriples << " outputs" << endl; // wait for threads to finish for (int i = 0; i < nparties-1; i++) { pthread_mutex_unlock(&ot_multipliers[i]->mutex); pthread_join(ot_multipliers[i]->thread, NULL); cout << "OT thread " << i << " finished\n" << flush; } cout << "OT threads finished\n"; for (size_t i = 0; i < ot_multipliers.size(); i++) delete ot_multipliers[i]; } template<> void NPartyTripleGenerator::generateBits(vector< OTMultiplier* >& ot_multipliers, ofstream& outputFile) { PRNG share_prg; share_prg.ReSeed(); int nBitsToCheck = nTriplesPerLoop + field_size; valueBits.resize(1); valueBits[0].resize(ceil(1.0 * nBitsToCheck / field_size) * field_size); MAC_Check MC(machine.get_mac_key()); vector< Share > bits(nBitsToCheck); vector< Share > to_open(1); vector opened(1); start_progress(ot_multipliers); for (int k = 0; k < nloops; k++) { print_progress(k); valueBits[0].randomize_blocks(share_prg); for (int i = 0; i < nparties-1; i++) 
pthread_cond_signal(&ot_multipliers[i]->ready); timers["Authentication OTs"].start(); for (int i = 0; i < nparties-1; i++) pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); timers["Authentication OTs"].stop(); octet seed[SEED_SIZE]; Create_Random_Seed(seed, globalPlayer, SEED_SIZE); PRNG G; G.SetSeed(seed); Share check_sum; gf2n r; for (int j = 0; j < nBitsToCheck; j++) { gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key(); for (int i = 0; i < nparties-1; i++) mac_sum += ot_multipliers[i]->macs[0][j]; bits[j].set_share(valueBits[0].get_bit(j)); bits[j].set_mac(mac_sum); r.randomize(G); check_sum += r * bits[j]; } to_open[0] = check_sum; MC.POpen_Begin(opened, to_open, globalPlayer); MC.POpen_End(opened, to_open, globalPlayer); MC.Check(globalPlayer); if (machine.output) for (int j = 0; j < nTriplesPerLoop; j++) bits[j].output(outputFile, false); for (int i = 0; i < nparties-1; i++) pthread_cond_signal(&ot_multipliers[i]->ready); } } template<> void NPartyTripleGenerator::generateBits(vector< OTMultiplier* >& ot_multipliers, ofstream& outputFile) { generateTriples(ot_multipliers, outputFile); } template void NPartyTripleGenerator::generateTriples(vector< OTMultiplier* >& ot_multipliers, ofstream& outputFile) { PRNG share_prg; share_prg.ReSeed(); valueBits.resize(3); for (int i = 0; i < 2; i++) valueBits[2*i].resize(field_size * nPreampTriplesPerLoop); valueBits[1].resize(field_size * nTriplesPerLoop); vector< PlainTriple > preampTriples; vector< PlainTriple > amplifiedTriples; vector< ShareTriple > uncheckedTriples; MAC_Check MC(machine.get_mac_key()); if (machine.amplify) preampTriples.resize(nTriplesPerLoop); if (machine.generateMACs) { amplifiedTriples.resize(nTriplesPerLoop); uncheckedTriples.resize(nTriplesPerLoop); } start_progress(ot_multipliers); for (int k = 0; k < nloops; k++) { print_progress(k); for (int j = 0; j < 2; j++) valueBits[j].randomize_blocks(share_prg); timers["OTs"].start(); for (int i = 0; i < 
nparties-1; i++) pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); timers["OTs"].stop(); for (int j = 0; j < nPreampTriplesPerLoop; j++) { T a(valueBits[0].get_int128(j)); T b(valueBits[1].get_int128(j / nAmplify)); T c = a * b; timers["Triple computation"].start(); for (int i = 0; i < nparties-1; i++) { c += ot_multipliers[i]->c_output[j]; } timers["Triple computation"].stop(); if (machine.amplify) { preampTriples[j/nAmplify].a[j%nAmplify] = a; preampTriples[j/nAmplify].b = b; preampTriples[j/nAmplify].c[j%nAmplify] = c; } else if (machine.output) { timers["Writing"].start(); a.output(outputFile, false); b.output(outputFile, false); c.output(outputFile, false); timers["Writing"].stop(); } } if (machine.amplify) { octet seed[SEED_SIZE]; Create_Random_Seed(seed, globalPlayer, SEED_SIZE); PRNG G; G.SetSeed(seed); for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) { PlainTriple triple; triple.amplify(preampTriples[iTriple], G); if (machine.generateMACs) amplifiedTriples[iTriple] = triple; else if (machine.output) { timers["Writing"].start(); triple.output(outputFile); timers["Writing"].stop(); } } if (machine.generateMACs) { for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) amplifiedTriples[iTriple].to(valueBits, iTriple); for (int i = 0; i < nparties-1; i++) pthread_cond_signal(&ot_multipliers[i]->ready); timers["Authentication OTs"].start(); for (int i = 0; i < nparties-1; i++) pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex); timers["Authentication OTs"].stop(); for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) { uncheckedTriples[iTriple].from(amplifiedTriples[iTriple], ot_multipliers, iTriple, *this); if (!machine.check and machine.output) { timers["Writing"].start(); amplifiedTriples[iTriple].output(outputFile); timers["Writing"].stop(); } } if (machine.check) { vector< Share > maskedAs(nTriplesPerLoop); vector< ShareTriple > maskedTriples(nTriplesPerLoop); for (int j = 0; j < 
nTriplesPerLoop; j++) { maskedTriples[j].amplify(uncheckedTriples[j], G); maskedAs[j] = maskedTriples[j].a[0]; } vector openedAs(nTriplesPerLoop); MC.POpen_Begin(openedAs, maskedAs, globalPlayer); MC.POpen_End(openedAs, maskedAs, globalPlayer); for (int j = 0; j < nTriplesPerLoop; j++) MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), int128(0), globalPlayer); MC.Check(globalPlayer); if (machine.generateBits) generateBitsFromTriples(uncheckedTriples, MC, outputFile); else if (machine.output) for (int j = 0; j < nTriplesPerLoop; j++) uncheckedTriples[j].output(outputFile, 1); } } } for (int i = 0; i < nparties-1; i++) pthread_cond_signal(&ot_multipliers[i]->ready); } } template<> void NPartyTripleGenerator::generateBitsFromTriples( vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) { vector< Share > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop); for (int i = 0; i < nTriplesPerLoop; i++) a_plus_b[i] = triples[i].a[0] + triples[i].b; vector opened(nTriplesPerLoop); MC.POpen_Begin(opened, a_plus_b, globalPlayer); MC.POpen_End(opened, a_plus_b, globalPlayer); for (int i = 0; i < nTriplesPerLoop; i++) a_squared[i] = triples[i].a[0] * opened[i] - triples[i].c[0]; MC.POpen_Begin(opened, a_squared, globalPlayer); MC.POpen_End(opened, a_squared, globalPlayer); Share one(gfp(1), globalPlayer.my_num(), MC.get_alphai()); for (int i = 0; i < nTriplesPerLoop; i++) { gfp root = opened[i].sqrRoot(); if (root.is_zero()) continue; Share bit = (triples[i].a[0] / root + one) / gfp(2); if (machine.output) bit.output(outputFile, false); } } template<> void NPartyTripleGenerator::generateBitsFromTriples( vector< ShareTriple >& triples, MAC_Check& MC, ofstream& outputFile) { throw how_would_that_work(); // warning gymnastics triples[0]; MC.number(); outputFile << ""; } template void NPartyTripleGenerator::start_progress(vector< OTMultiplier* >& ot_multipliers) { for (int i = 0; i < nparties-1; i++) pthread_cond_wait(&ot_multipliers[i]->ready, 
&ot_multipliers[i]->mutex); lock(); signal(); wait(); gettimeofday(&last_lap, 0); for (int i = 0; i < nparties-1; i++) pthread_cond_signal(&ot_multipliers[i]->ready); } void NPartyTripleGenerator::print_progress(int k) { if (thread_num == 0 && my_num == 0) { struct timeval stop; gettimeofday(&stop, 0); if (timeval_diff_in_seconds(&last_lap, &stop) > 1) { double diff = timeval_diff_in_seconds(&machine.start, &stop); double throughput = k * nTriplesPerLoop * machine.nthreads / diff; double remaining = diff * (nloops - k) / k; cout << k << '/' << nloops << ", throughput: " << throughput << ", time left: " << remaining << ", elapsed: " << diff << ", estimated total: " << (diff + remaining) << endl; last_lap = stop; } } } void NPartyTripleGenerator::lock() { pthread_mutex_lock(&mutex); } void NPartyTripleGenerator::unlock() { pthread_mutex_unlock(&mutex); } void NPartyTripleGenerator::signal() { pthread_cond_signal(&ready); } void NPartyTripleGenerator::wait() { pthread_cond_wait(&ready, &mutex); } template void NPartyTripleGenerator::generate(); template void NPartyTripleGenerator::generate();