/*
 * Rep4.cpp
 *
 * Out-of-line member definitions for the Rep4 protocol class (declared in
 * "Rep4.h"). The constructor asserts exactly four players; each party keeps
 * three PRNGs (rep_prngs[0..2]) and three share components per value.
 *
 * NOTE(review): in this copy of the file the template parameter lists and
 * template arguments (after `template`, `Rep4`, `vector`, `array`,
 * `PointerVector`, `SubProcessor`, `.get`, ...) appear to be missing --
 * presumably stripped by a text-extraction step. The code tokens below are
 * left exactly as found; confirm against the upstream source before building.
 */
#include "Rep4.h"
#include "GC/square64.h"
#include "Processor/TruncPrTuple.h"

/**
 * Set up the protocol instance and establish the shared PRNG seeds.
 *
 * rep_prngs[0] is freshly seeded locally and its seed is sent to the players
 * addressed by P.get_player(-1) and P.get_player(-2). rep_prngs[1] and
 * rep_prngs[2] are then seeded from the data received from P.get_player(1)
 * and P.get_player(2) (presumably relative offsets from this party -- see
 * Player::get_player).
 *
 * @param P communication handle; must report exactly four players.
 */
template Rep4::Rep4(Player& P) : my_num(P.my_num()), P(P)
{
    assert(P.num_players() == 4);
    rep_prngs[0].ReSeed();
    octetStreams to_send(P), to_receive;
    // The same local seed goes to both recipients.
    for (int i = 1; i < 3; i++)
        to_send[P.get_player(-i)].append(rep_prngs[0].get_seed(), SEED_SIZE);
    P.send_receive_all(to_send, to_receive);
    for (int i = 1; i < 3; i++)
        rep_prngs[i].SetSeed(to_receive[P.get_player(i)].get_data());
}

/**
 * Set up an instance from already-established PRNGs, without communication.
 *
 * Each of the three local PRNGs is seeded via SetSeed(prngs[i]).
 * No player-count assertion is made on this path.
 */
template Rep4::Rep4(Player& P, prngs_type& prngs) : my_num(P.my_num()), P(P)
{
    for (int i = 0; i < 3; i++)
        rep_prngs[i].SetSeed(prngs[i]);
}

/**
 * Run check() once if any receive or send hash has accumulated data
 * (an entry with size > 0).
 *
 * NOTE(review): check() communicates and throws runtime_error on mismatch,
 * so this destructor can throw -- confirm that this is intended.
 */
template Rep4::~Rep4()
{
    for (auto& x : receive_hashes)
        for (auto& y : x)
            if (y.size > 0)
            {
                check();
                return;
            }
    for (auto& x : send_hashes)
        for (auto& y : x)
            if (y.size > 0)
            {
                check();
                return;
            }
}

/**
 * Create another instance on the same Player, built from this instance's
 * PRNGs through the (Player&, prngs_type&) constructor.
 */
template Rep4 Rep4::branch()
{
    return {P, rep_prngs};
}

/**
 * Reset per-round state: clear the additive-share buffers and bit lengths,
 * reset the send/receive streams, and size the channel matrix.
 *
 * NOTE(review): if `channels` is a std::vector, resize() only initialises
 * newly added rows to false; entries set to true earlier are not cleared
 * here -- confirm this is intended.
 */
template void Rep4::init_mul()
{
    for (auto& x : add_shares)
        x.clear();
    bit_lengths.clear();
    send_os.reset(P);
    receive_os.reset(P);
    channels.resize(P.num_players(), vector(P.num_players(), false));
}

/**
 * Prepare the member result buffer for n_inputs joint inputs; every bit
 * length is set to -1.
 */
template void Rep4::reset_joint_input(int n_inputs)
{
    results.clear();
    results.resize(n_inputs);
    bit_lengths.clear();
    bit_lengths.resize(n_inputs, -1);
}

/**
 * Convenience overload: same as the six-argument version, accumulating into
 * the member `results`.
 */
template void Rep4::prepare_joint_input(int sender, int backup, int receiver, int outsider, vector& inputs)
{
    prepare_joint_input(sender, backup, receiver, outsider, inputs, results);
}

/**
 * Local phase of one joint input from (sender, backup) towards receiver.
 *
 * @param sender   party that packs the masked values for sending
 * @param backup   party that only feeds the values into a send hash
 * @param receiver party that does not draw the mask and later unpacks
 * @param outsider party whose offset selects the share slot that takes the
 *                 masked value on sender/backup
 * @param inputs   values to share; MODIFIED IN PLACE on sender and backup
 *                 (the mask is subtracted from each element)
 * @param results  accumulators; must already have one entry per input
 * @throws not_implemented on the sender if P.get_offset(backup) is
 *         neither 1 nor 2
 */
template void Rep4::prepare_joint_input(int sender, int backup, int receiver, int outsider, vector& inputs, vector& results)
{
    channels[sender][receiver] = true;
    if (P.my_num() != receiver)
    {
        // Every non-receiver draws one mask per result from the PRNG slot
        // selected by the receiver's offset and adds it to that share slot.
        int index = P.get_offset(receiver) - 1;
        for (auto& x : results)
        {
            x.r = rep_prngs[index].get();
            x.res[index] += x.r;
        }
        if (P.my_num() == sender or P.my_num() == backup)
        {
            int offset = P.get_offset(outsider) - 1;
            size_t n_results = results.size();
            for (size_t i = 0; i < n_results; i++)
            {
                // `input` aliases inputs[i]: the caller's vector holds the
                // masked values after this loop.
                auto& input = inputs[i];
                input -= results[i].r;
                results[i].res[offset] += input;
            }
        }
    }
    if (P.my_num() == backup)
    {
        // The backup does not send the values; it only records them in the
        // hash for (sender, receiver), which check() transmits later.
        send_hashes[sender][receiver].update(inputs, bit_lengths);
    }
    if (sender == P.my_num())
    {
        assert(inputs.size() == bit_lengths.size());
        // The destination stream index depends on the backup's offset.
        switch (P.get_offset(backup))
        {
        case 2:
            for (size_t i = 0; i < inputs.size(); i++)
                inputs[i].pack(send_os[3 - my_num], bit_lengths[i]);
            break;
        case 1:
            for (size_t i = 0; i < inputs.size(); i++)
                inputs[i].pack(send_os[get_player(-1)], bit_lengths[i]);
            break;
        default:
            throw not_implemented();
        }
    }
    // NOTE(review): append(0) is applied to every send stream on every
    // call; its purpose is not visible from this file -- see octetStream.
    for (auto& x : send_os)
        x.append(0);
}

/**
 * Convenience overload: same as the five-argument version, operating on the
 * member `results`.
 */
template void Rep4::finalize_joint_input(int sender, int backup, int receiver, int outsider)
{
    finalize_joint_input(sender, backup, receiver, outsider, results);
}

/**
 * Receiving phase of one joint input. Does nothing unless this party is the
 * receiver.
 *
 * The receiver unpacks one value per result from the stream selected by
 * P.get_offset(backup), adds it to the selected share slot, and feeds the
 * raw bytes consumed into receive_hashes[sender][backup] for the later
 * comparison in check(). The fourth (outsider) parameter is unused.
 */
template void Rep4::finalize_joint_input(int sender, int backup, int receiver, int, vector& results)
{
    if (P.my_num() == receiver)
    {
        assert(results.size() == bit_lengths.size());
        T res;
        size_t n_results = results.size();
        octetStream* os;
        int index;
        switch (P.get_offset(backup))
        {
        case 2:
            os = &receive_os[get_player(1)];
            index = 2;
            break;
        default:
            os = &receive_os[3 - P.my_num()];
            index = 1;
            break;
        }
        // Remember the read position so the consumed bytes can be hashed.
        auto start = os->get_data_ptr();
        for (size_t i = 0; i < n_results; i++)
        {
            auto& x = results[i];
            // res[1] only serves as a scratch value for unpacking.
            res[1].unpack(*os, bit_lengths[i]);
            x.res[index] += res[1];
        }
        os->consume(0);
        receive_hashes[sender][backup].update(start, os->get_data_ptr() - start);
    }
}

/**
 * Index of the player at the given offset from this party, modulo 4
 * (computed with a bit mask, so negative offsets wrap as well).
 */
template int Rep4::get_player(int offset)
{
    return (my_num + offset) & 3;
}

/**
 * Queue one multiplication: compute the five local cross terms of x*y and
 * append them to add_shares[0..4], recording n_bits as the bit length.
 */
template void Rep4::prepare_mul(const T& x, const T& y, int n_bits)
{
    auto a = get_addshares(x, y);
    for (int i = 0; i < 5; i++)
        add_shares[i].push_back(a[i]);
    bit_lengths.push_back(n_bits);
}

/**
 * Compute this party's local product terms for x*y.
 *
 * Slots get_player(-1) and get_player(0) receive
 * (x[i] + x[i+1]) * y[i] + x[i] * y[i+1] for i = 0 and i = 1 respectively;
 * slot 4 receives x[0]*y[2] + x[2]*y[0]. The other two of slots 0..3 are
 * not written here (their value depends on how the element type
 * default-initialises -- TODO confirm).
 */
template array Rep4::get_addshares(const T& x, const T& y)
{
    array res;
    for (int i = 0; i < 2; i++)
        res[get_player(i - 1)] = (x[i] + x[i + 1]) * y[i] + x[i] * y[i + 1];
    res[4] = x[0] * y[2] + x[2] * y[0];
    return res;
}

/** Start a batch of dot products: reset round state and the accumulator. */
template void Rep4::init_dotprod()
{
    init_mul();
    dotprod_shares = {};
}

/** Add the local product terms of x*y to the current dot-product accumulator. */
template void Rep4::prepare_dotprod(const T& x, const T& y)
{
    auto a = get_addshares(x, y);
    for (int i = 0; i < 5; i++)
        dotprod_shares[i] += a[i];
}

/**
 * Close the current dot product: queue the accumulated terms like a single
 * multiplication (bit length -1) and reset the accumulator.
 */
template void Rep4::next_dotprod()
{
    for (int i = 0; i < 5; i++)
        add_shares[i].push_back(dotprod_shares[i]);
    bit_lengths.push_back(-1);
    dotprod_shares = {};
}

/**
 * Communication round for all queued multiplications.
 *
 * Six joint inputs (sender, backup, receiver, outsider) are prepared, one
 * send_receive_all() is performed, and the same six are finalized in the
 * same order. add_shares[4] is used for two of them. All six accumulate
 * into the member `results`.
 */
template void Rep4::exchange()
{
    auto& a = add_shares;
    results.clear();
    results.resize(a[4].size());
    prepare_joint_input(0, 1, 3, 2, a[0]);
    prepare_joint_input(1, 2, 0, 3, a[1]);
    prepare_joint_input(2, 3, 1, 0, a[2]);
    prepare_joint_input(3, 0, 2, 1, a[3]);
    prepare_joint_input(0, 2, 3, 1, a[4]);
    prepare_joint_input(1, 3, 2, 0, a[4]);
    P.send_receive_all(channels, send_os, receive_os);
    finalize_joint_input(0, 1, 3, 2);
    finalize_joint_input(1, 2, 0, 3);
    finalize_joint_input(2, 3, 1, 0);
    finalize_joint_input(3, 0, 2, 1);
    finalize_joint_input(0, 2, 3, 1);
    finalize_joint_input(1, 3, 2, 0);
}

/**
 * Return the next multiplication result and bump the counter.
 *
 * With n_bits == -1 the result is returned as is; otherwise the call is
 * dispatched on T::clear::binary to the tagged overloads below.
 */
template T Rep4::finalize_mul(int n_bits)
{
    this->counter++;
    if (n_bits == -1)
        return results.next().res;
    else
        return finalize_mul(n_bits, T::clear::binary);
}

/** Overload selected by true_type: next result masked to n_bits. */
template template T Rep4::finalize_mul(int n_bits, true_type)
{
    return results.next().res.mask(n_bits);
}

/**
 * Overload selected by false_type: always fails.
 * @throws runtime_error unconditionally
 */
template template T Rep4::finalize_mul(int, false_type)
{
    throw runtime_error("bit-wise multiplication not supported");
}

/**
 * Return the next dot-product result; the argument is ignored.
 * (finalize_mul is called without arguments, so a default for n_bits is
 * presumably declared in the header.)
 */
template T Rep4::finalize_dotprod(int)
{
    return finalize_mul();
}

/**
 * Exchange the accumulated send hashes and compare them with the local
 * receive hashes.
 *
 * For each of the three other players, four hashes (one per j in 0..3) are
 * concatenated and sent; after the exchange the same number are consumed
 * from each peer and compared with receive_hashes[j][peer].
 *
 * @throws runtime_error naming the sender index and backup on mismatch
 */
template void Rep4::check()
{
    octetStreams to_send(P);
    for (int i = 1; i < 4; i++)
        for (int j = 0; j < 4; j++)
            to_send[P.get_player(i)].concat(send_hashes[j][P.get_player(i)].final());
    octetStreams to_receive;
    P.send_receive_all(to_send, to_receive);
    octetStream tmp;
    for (int i = 1; i < 4; i++)
        for (int j = 0; j < 4; j++)
        {
            to_receive[P.get_player(-i)].consume(tmp, Hash::hash_length);
            if (receive_hashes[j][P.get_player(-i)].final() != tmp)
                throw runtime_error(
                        "hash mismatch for sender " + to_string(j) + " and backup " + to_string(P.get_player(-i)));
        }
}

/** Return a value whose three components are drawn from rep_prngs[0..2]. */
template T Rep4::get_random()
{
    T res;
    for (int i = 0; i < 3; i++)
        res[i].randomize(rep_prngs[i]);
    return res;
}

/**
 * Fill the three components of res via randomize_part(rep_prngs[i], n_bits).
 */
template void Rep4::randoms(T& res, int n_bits)
{
    for (int i = 0; i < 3; i++)
        res[i].randomize_part(rep_prngs[i], n_bits);
}

/**
 * Entry point for truncation; dispatches on T::clear::characteristic_two to
 * the tagged overloads below.
 */
template void Rep4::trunc_pr(const vector& regs, int size, SubProcessor& proc)
{
    trunc_pr<0>(regs, size, proc, T::clear::characteristic_two);
}

/**
 * Overload selected by true_type: not supported.
 * @throws runtime_error unconditionally
 */
template template void Rep4::trunc_pr(const vector&, int, SubProcessor&, true_type)
{
    throw runtime_error("only implemented for integer-like domains");
}

/**
 * Truncation for the false_type case.
 *
 * @param regs groups of four entries, each group describing one truncation
 *             (consumed through the `infos` tuples; see TruncPrTuple.h)
 * @param size number of consecutive registers handled per group
 * @param proc processor whose secret registers are read and written
 * @throws runtime_error if the hash of the received masked values does not
 *         match (parties 2 and 3)
 *
 * Outline as visible here: mask each source with a random value r, open the
 * masked value c towards parties 2 and 3, jointly input the msb/upper parts
 * of r (from parties 0/1) and of c (from parties 2/3), then combine them.
 */
template template void Rep4::trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type)
{
    assert(regs.size() % 4 == 0);
    this->trunc_pr_counter += size * regs.size() / 4;
    typedef typename T::open_type open_type;
    vector> infos;
    for (size_t i = 0; i < regs.size(); i += 4)
        infos.push_back({regs, i});
    // One random mask per truncated value. Only the components selected via
    // P.get_offset(2) and P.get_offset(3) are randomised, and only when the
    // resulting index is non-negative.
    PointerVector rs(size * infos.size());
    for (int i = 2; i < 4; i++)
    {
        int index = P.get_offset(i) - 1;
        if (index >= 0)
            for (auto& r : rs)
                r[index].randomize(rep_prngs[index]);
    }
    // cs = source + mask, in the same order as rs.
    vector cs;
    cs.reserve(rs.size());
    for (auto& info : infos)
    {
        for (int j = 0; j < size; j++)
            cs.push_back(proc.get_S_ref(info.source_base + j) + rs.next());
    }
    octetStream c_os;
    vector eval_inputs;
    if (P.my_num() < 2)
    {
        // Parties 0 and 1 send the sum of two of their components to party
        // my_num + 2 and a hash of the same stream to party 3 - my_num.
        if (P.my_num() == 0)
            for (auto& c : cs)
                (c[1] + c[2]).pack(c_os);
        else
            for (auto& c : cs)
                (c[1] + c[0]).pack(c_os);
        P.send_to(2 + P.my_num(), c_os);
        P.send_to(3 - P.my_num(), c_os.hash());
    }
    else
    {
        // Parties 2 and 3 receive the values from one party and the hash
        // from the other, and abort on mismatch.
        P.receive_player(P.my_num() - 2, c_os);
        octetStream hash;
        P.receive_player(3 - P.my_num(), hash);
        if (hash != c_os.hash())
            throw runtime_error("hash mismatch in joint message passing");
        // Add two local components to each received value to obtain c.
        PointerVector open_cs;
        if (P.my_num() == 2)
            for (auto& c : cs)
                open_cs.push_back(c_os.get() + c[1] + c[2]);
        else
            for (auto& c : cs)
                open_cs.push_back(c_os.get() + c[1] + c[0]);
        // Per value: msb(c) first (only without a big gap), then upper(c).
        for (auto& info : infos)
            for (int j = 0; j < size; j++)
            {
                auto c = open_cs.next();
                auto c_prime = info.upper(c);
                if (not info.big_gap())
                {
                    auto c_msb = info.msb(c);
                    eval_inputs.push_back(c_msb);
                }
                eval_inputs.push_back(c_prime);
            }
    }
    // Parties 0 and 1 build the matching inputs from the mask itself, in
    // the same msb-then-upper order.
    PointerVector inputs;
    bool generate = proc.P.my_num() < 2;
    if (generate)
    {
        inputs.reserve(2 * rs.size());
        rs.reset();
        for (auto& info : infos)
            for (int j = 0; j < size; j++)
            {
                auto& r = rs.next();
                if (not info.big_gap())
                    inputs.push_back(info.msb(r.sum()));
                inputs.push_back(info.upper(r.sum()));
            }
    }
    init_mul();
    // Each party fills only one of the two vectors, so the larger size is
    // the number of joint inputs.
    size_t n_inputs = max(inputs.size(), eval_inputs.size());
    reset_joint_input(n_inputs);
    PointerVector gen_results(n_inputs);
    PointerVector eval_results(n_inputs);
    prepare_joint_input(0, 1, 3, 2, inputs, gen_results);
    prepare_joint_input(2, 3, 1, 0, eval_inputs, eval_results);
    P.send_receive_all(channels, send_os, receive_os);
    finalize_joint_input(0, 1, 3, 2, gen_results);
    finalize_joint_input(2, 3, 1, 0, eval_results);
    init_mul();
    for (auto& info : infos)
        for (int j = 0; j < size; j++)
        {
            // Without a big gap the first entry of each pair (the msb
            // shares) is multiplied; the unconditional next() calls then
            // skip the upper-part entry.
            if (not info.big_gap())
                prepare_mul(gen_results.next().res, eval_results.next().res);
            gen_results.next();
            eval_results.next();
        }
    // Only communicate if at least one multiplication was queued.
    if (not add_shares[0].empty())
        exchange();
    eval_results.reset();
    gen_results.reset();
    for (auto& info : infos)
        for (int j = 0; j < size; j++)
        {
            if (info.big_gap())
                proc.get_S_ref(info.dest_base + j) = eval_results.next().res - gen_results.next().res;
            else
            {
                // b = x + y - 2xy over the two msb shares (equals XOR if
                // both are bits -- presumably the case here); the
                // difference of the upper parts is then corrected by
                // b shifted left by (k - m).
                auto b = gen_results.next().res + eval_results.next().res - 2 * finalize_mul();
                proc.get_S_ref(info.dest_base + j) = eval_results.next().res - gen_results.next().res + (b << (info.k - info.m));
            }
        }
}

/**
 * Split shared inputs into bit-sliced shares.
 *
 * @param dest     destination registers, written via dest.at(...)
 * @param regs     register indices; asserted to hold 2 * n_bits entries
 *                 (regs.at(2 * j + i) is used for bit j, summand i)
 * @param n_bits   number of bits per input, at most 64
 * @param source   array of n_inputs input shares
 * @param n_inputs number of inputs
 *
 * Inputs are processed in chunks of GC::Clear::N_BITS. For each chunk the
 * sum of two local share components is bit-transposed via square64, and the
 * resulting rows are jointly input twice, once from parties (0, 1) and once
 * from parties (2, 3).
 */
template template void Rep4::split(vector& dest, const vector& regs, int n_bits, const U* source, int n_inputs)
{
    assert(regs.size() / n_bits == 2);
    assert(n_bits <= 64);
    int unit = GC::Clear::N_BITS;
    int my_num = P.my_num();
    // First of the two adjacent share components this party sums:
    // parties 0 and 2 use components 1 and 2, parties 1 and 3 use 0 and 1.
    int i0 = -1;
    switch (my_num)
    {
    case 0:
        i0 = 1;
        break;
    case 1:
        i0 = 0;
        break;
    case 2:
        i0 = 1;
        break;
    case 3:
        i0 = 0;
        break;
    }
    vector to_share;
    init_mul();
    for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
    {
        int start = k * unit;
        // The last chunk may hold fewer than `unit` inputs.
        int m = min(unit, n_inputs - start);
        square64 square;
        for (int j = 0; j < m; j++)
        {
            auto& input_share = source[j + start];
            auto input_value = input_share[i0] + input_share[i0 + 1];
            square.rows[j] = Integer(input_value).get();
        }
        // After transposition, row j holds bit j of the m chunk inputs.
        square.transpose(m, n_bits);
        for (int j = 0; j < n_bits; j++)
        {
            to_share.push_back(square.rows[j]);
            bit_lengths.push_back(m);
        }
    }
    // Local result buffers (this shadows the member `results`).
    array, 2> results;
    for (auto& x : results)
        x.resize(to_share.size());
    // NOTE(review): both calls receive the same to_share vector, and
    // prepare_joint_input modifies its inputs in place on sender/backup.
    // Each party is sender or backup in only one of the two calls.
    prepare_joint_input(0, 1, 3, 2, to_share, results[0]);
    prepare_joint_input(2, 3, 1, 0, to_share, results[1]);
    P.send_receive_all(channels, send_os, receive_os);
    finalize_joint_input(0, 1, 3, 2, results[0]);
    finalize_joint_input(2, 3, 1, 0, results[1]);
    // Results are consumed chunk by chunk, bit by bit, in the order they
    // were queued; summand i of bit j goes to register
    // regs.at(2 * j + i) + k.
    for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
    {
        for (int j = 0; j < n_bits; j++)
            for (int i = 0; i < 2; i++)
            {
                auto res = results[i].next().res;
                dest.at(regs.at(2 * j + i) + k) = res;
            }
    }
}