// Mirror of https://github.com/data61/MP-SPDZ.git
// (synced 2026-01-08 21:18:03 -05:00; 545 lines, 14 KiB, C++)
/*
 * Rep4.cpp
 *
 */
|
|
|
|
#include "Rep4.h"
|
|
#include "GC/square64.h"
|
|
#include "Processor/TruncPrTuple.h"
|
|
|
|
template<class T>
|
|
Rep4<T>::Rep4(Player& P) :
|
|
my_num(P.my_num()), P(P)
|
|
{
|
|
assert(P.num_players() == 4);
|
|
|
|
rep_prngs[0].ReSeed();
|
|
|
|
octetStreams to_send(P), to_receive;
|
|
for (int i = 1; i < 3; i++)
|
|
to_send[P.get_player(-i)].append(rep_prngs[0].get_seed(), SEED_SIZE);
|
|
|
|
P.send_receive_all(to_send, to_receive);
|
|
|
|
for (int i = 1; i < 3; i++)
|
|
rep_prngs[i].SetSeed(to_receive[P.get_player(i)].get_data());
|
|
}
|
|
|
|
template<class T>
|
|
Rep4<T>::Rep4(Player& P, prngs_type& prngs) :
|
|
my_num(P.my_num()), P(P)
|
|
{
|
|
for (int i = 0; i < 3; i++)
|
|
rep_prngs[i].SetSeed(prngs[i]);
|
|
}
|
|
|
|
template<class T>
|
|
Rep4<T>::~Rep4()
|
|
{
|
|
for (auto& x : receive_hashes)
|
|
for (auto& y : x)
|
|
if (y.size > 0)
|
|
{
|
|
check();
|
|
return;
|
|
}
|
|
|
|
for (auto& x : send_hashes)
|
|
for (auto& y : x)
|
|
if (y.size > 0)
|
|
{
|
|
check();
|
|
return;
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
Rep4<T> Rep4<T>::branch()
|
|
{
|
|
return {P, rep_prngs};
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::init_mul()
|
|
{
|
|
for (auto& x : add_shares)
|
|
x.clear();
|
|
bit_lengths.clear();
|
|
|
|
send_os.reset(P);
|
|
receive_os.reset(P);
|
|
channels.resize(P.num_players(), vector<bool>(P.num_players(), false));
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::reset_joint_input(int n_inputs)
|
|
{
|
|
results.clear();
|
|
results.resize(n_inputs);
|
|
bit_lengths.clear();
|
|
bit_lengths.resize(n_inputs, -1);
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::prepare_joint_input(int sender, int backup, int receiver,
|
|
int outsider, vector<open_type>& inputs)
|
|
{
|
|
prepare_joint_input(sender, backup, receiver, outsider, inputs, results);
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::prepare_joint_input(int sender, int backup, int receiver,
|
|
int outsider, vector<open_type>& inputs, vector<ResTuple>& results)
|
|
{
|
|
channels[sender][receiver] = true;
|
|
|
|
if (P.my_num() != receiver)
|
|
{
|
|
int index = P.get_offset(receiver) - 1;
|
|
for (auto& x : results)
|
|
{
|
|
x.r = rep_prngs[index].get();
|
|
x.res[index] += x.r;
|
|
}
|
|
|
|
if (P.my_num() == sender or P.my_num() == backup)
|
|
{
|
|
int offset = P.get_offset(outsider) - 1;
|
|
size_t n_results = results.size();
|
|
for (size_t i = 0; i < n_results; i++)
|
|
{
|
|
auto& input = inputs[i];
|
|
input -= results[i].r;
|
|
results[i].res[offset] += input;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (P.my_num() == backup)
|
|
{
|
|
send_hashes[sender][receiver].update(inputs, bit_lengths);
|
|
}
|
|
|
|
if (sender == P.my_num())
|
|
{
|
|
assert(inputs.size() == bit_lengths.size());
|
|
switch (P.get_offset(backup))
|
|
{
|
|
case 2:
|
|
for (size_t i = 0; i < inputs.size(); i++)
|
|
inputs[i].pack(send_os[3 - my_num], bit_lengths[i]);
|
|
break;
|
|
case 1:
|
|
for (size_t i = 0; i < inputs.size(); i++)
|
|
inputs[i].pack(send_os[get_player(-1)], bit_lengths[i]);
|
|
break;
|
|
default:
|
|
throw not_implemented();
|
|
}
|
|
}
|
|
|
|
for (auto& x : send_os)
|
|
x.append(0);
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::finalize_joint_input(int sender, int backup, int receiver,
|
|
int outsider)
|
|
{
|
|
finalize_joint_input(sender, backup, receiver, outsider, results);
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::finalize_joint_input(int sender, int backup, int receiver,
|
|
int, vector<ResTuple>& results)
|
|
{
|
|
if (P.my_num() == receiver)
|
|
{
|
|
assert(results.size() == bit_lengths.size());
|
|
T res;
|
|
size_t n_results = results.size();
|
|
octetStream* os;
|
|
int index;
|
|
switch (P.get_offset(backup))
|
|
{
|
|
case 2:
|
|
os = &receive_os[get_player(1)];
|
|
index = 2;
|
|
break;
|
|
default:
|
|
os = &receive_os[3 - P.my_num()];
|
|
index = 1;
|
|
break;
|
|
}
|
|
|
|
auto start = os->get_data_ptr();
|
|
for (size_t i = 0; i < n_results; i++)
|
|
{
|
|
auto& x = results[i];
|
|
res[1].unpack(*os, bit_lengths[i]);
|
|
x.res[index] += res[1];
|
|
}
|
|
|
|
os->consume(0);
|
|
receive_hashes[sender][backup].update(start,
|
|
os->get_data_ptr() - start);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
int Rep4<T>::get_player(int offset)
|
|
{
|
|
return (my_num + offset) & 3;
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::prepare_mul(const T& x, const T& y, int n_bits)
|
|
{
|
|
auto a = get_addshares(x, y);
|
|
for (int i = 0; i < 5; i++)
|
|
add_shares[i].push_back(a[i]);
|
|
bit_lengths.push_back(n_bits);
|
|
}
|
|
|
|
template<class T>
|
|
array<typename T::open_type, 5> Rep4<T>::get_addshares(const T& x, const T& y)
|
|
{
|
|
array<open_type, 5> res;
|
|
for (int i = 0; i < 2; i++)
|
|
res[get_player(i - 1)] =
|
|
(x[i] + x[i + 1]) * y[i] + x[i] * y[i + 1];
|
|
res[4] = x[0] * y[2] + x[2] * y[0];
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::init_dotprod()
|
|
{
|
|
init_mul();
|
|
dotprod_shares = {};
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::prepare_dotprod(const T& x, const T& y)
|
|
{
|
|
auto a = get_addshares(x, y);
|
|
for (int i = 0; i < 5; i++)
|
|
dotprod_shares[i] += a[i];
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::next_dotprod()
|
|
{
|
|
for (int i = 0; i < 5; i++)
|
|
add_shares[i].push_back(dotprod_shares[i]);
|
|
bit_lengths.push_back(-1);
|
|
dotprod_shares = {};
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::exchange()
|
|
{
|
|
auto& a = add_shares;
|
|
results.clear();
|
|
results.resize(a[4].size());
|
|
prepare_joint_input(0, 1, 3, 2, a[0]);
|
|
prepare_joint_input(1, 2, 0, 3, a[1]);
|
|
prepare_joint_input(2, 3, 1, 0, a[2]);
|
|
prepare_joint_input(3, 0, 2, 1, a[3]);
|
|
prepare_joint_input(0, 2, 3, 1, a[4]);
|
|
prepare_joint_input(1, 3, 2, 0, a[4]);
|
|
P.send_receive_all(channels, send_os, receive_os);
|
|
finalize_joint_input(0, 1, 3, 2);
|
|
finalize_joint_input(1, 2, 0, 3);
|
|
finalize_joint_input(2, 3, 1, 0);
|
|
finalize_joint_input(3, 0, 2, 1);
|
|
finalize_joint_input(0, 2, 3, 1);
|
|
finalize_joint_input(1, 3, 2, 0);
|
|
}
|
|
|
|
template<class T>
|
|
T Rep4<T>::finalize_mul(int n_bits)
|
|
{
|
|
this->counter++;
|
|
if (n_bits == -1)
|
|
return results.next().res;
|
|
else
|
|
return finalize_mul(n_bits, T::clear::binary);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
T Rep4<T>::finalize_mul(int n_bits, true_type)
|
|
{
|
|
return results.next().res.mask(n_bits);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
T Rep4<T>::finalize_mul(int, false_type)
|
|
{
|
|
throw runtime_error("bit-wise multiplication not supported");
|
|
}
|
|
|
|
template<class T>
|
|
T Rep4<T>::finalize_dotprod(int)
|
|
{
|
|
return finalize_mul();
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::check()
|
|
{
|
|
octetStreams to_send(P);
|
|
for (int i = 1; i < 4; i++)
|
|
for (int j = 0; j < 4; j++)
|
|
to_send[P.get_player(i)].concat(send_hashes[j][P.get_player(i)].final());
|
|
|
|
octetStreams to_receive;
|
|
P.send_receive_all(to_send, to_receive);
|
|
|
|
octetStream tmp;
|
|
for (int i = 1; i < 4; i++)
|
|
for (int j = 0; j < 4; j++)
|
|
{
|
|
to_receive[P.get_player(-i)].consume(tmp, Hash::hash_length);
|
|
if (receive_hashes[j][P.get_player(-i)].final() != tmp)
|
|
throw runtime_error(
|
|
"hash mismatch for sender " + to_string(j)
|
|
+ " and backup " + to_string(P.get_player(-i)));
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
T Rep4<T>::get_random()
|
|
{
|
|
T res;
|
|
for (int i = 0; i < 3; i++)
|
|
res[i].randomize(rep_prngs[i]);
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::randoms(T& res, int n_bits)
|
|
{
|
|
for (int i = 0; i < 3; i++)
|
|
res[i].randomize_part(rep_prngs[i], n_bits);
|
|
}
|
|
|
|
template<class T>
|
|
void Rep4<T>::trunc_pr(const vector<int>& regs, int size,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
trunc_pr<0>(regs, size, proc, T::clear::characteristic_two);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void Rep4<T>::trunc_pr(const vector<int>&, int, SubProcessor<T>&, true_type)
|
|
{
|
|
throw runtime_error("only implemented for integer-like domains");
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void Rep4<T>::trunc_pr(const vector<int>& regs, int size,
|
|
SubProcessor<T>& proc, false_type)
|
|
{
|
|
assert(regs.size() % 4 == 0);
|
|
this->trunc_pr_counter += size * regs.size() / 4;
|
|
typedef typename T::open_type open_type;
|
|
|
|
vector<TruncPrTupleWithGap<open_type>> infos;
|
|
for (size_t i = 0; i < regs.size(); i += 4)
|
|
infos.push_back({regs, i});
|
|
|
|
PointerVector<T> rs(size * infos.size());
|
|
for (int i = 2; i < 4; i++)
|
|
{
|
|
int index = P.get_offset(i) - 1;
|
|
if (index >= 0)
|
|
for (auto& r : rs)
|
|
r[index].randomize(rep_prngs[index]);
|
|
}
|
|
|
|
vector<T> cs;
|
|
cs.reserve(rs.size());
|
|
for (auto& info : infos)
|
|
{
|
|
for (int j = 0; j < size; j++)
|
|
cs.push_back(proc.get_S_ref(info.source_base + j) + rs.next());
|
|
}
|
|
|
|
octetStream c_os;
|
|
vector<open_type> eval_inputs;
|
|
if (P.my_num() < 2)
|
|
{
|
|
if (P.my_num() == 0)
|
|
for (auto& c : cs)
|
|
(c[1] + c[2]).pack(c_os);
|
|
else
|
|
for (auto& c : cs)
|
|
(c[1] + c[0]).pack(c_os);
|
|
P.send_to(2 + P.my_num(), c_os);
|
|
P.send_to(3 - P.my_num(), c_os.hash());
|
|
}
|
|
else
|
|
{
|
|
P.receive_player(P.my_num() - 2, c_os);
|
|
octetStream hash;
|
|
P.receive_player(3 - P.my_num(), hash);
|
|
if (hash != c_os.hash())
|
|
throw runtime_error("hash mismatch in joint message passing");
|
|
PointerVector<open_type> open_cs;
|
|
if (P.my_num() == 2)
|
|
for (auto& c : cs)
|
|
open_cs.push_back(c_os.get<open_type>() + c[1] + c[2]);
|
|
else
|
|
for (auto& c : cs)
|
|
open_cs.push_back(c_os.get<open_type>() + c[1] + c[0]);
|
|
for (auto& info : infos)
|
|
for (int j = 0; j < size; j++)
|
|
{
|
|
auto c = open_cs.next();
|
|
auto c_prime = info.upper(c);
|
|
if (not info.big_gap())
|
|
{
|
|
auto c_msb = info.msb(c);
|
|
eval_inputs.push_back(c_msb);
|
|
}
|
|
eval_inputs.push_back(c_prime);
|
|
}
|
|
}
|
|
|
|
PointerVector<open_type> inputs;
|
|
bool generate = proc.P.my_num() < 2;
|
|
if (generate)
|
|
{
|
|
inputs.reserve(2 * rs.size());
|
|
rs.reset();
|
|
for (auto& info : infos)
|
|
for (int j = 0; j < size; j++)
|
|
{
|
|
auto& r = rs.next();
|
|
if (not info.big_gap())
|
|
inputs.push_back(info.msb(r.sum()));
|
|
inputs.push_back(info.upper(r.sum()));
|
|
}
|
|
}
|
|
|
|
init_mul();
|
|
size_t n_inputs = max(inputs.size(), eval_inputs.size());
|
|
reset_joint_input(n_inputs);
|
|
PointerVector<ResTuple> gen_results(n_inputs);
|
|
PointerVector<ResTuple> eval_results(n_inputs);
|
|
prepare_joint_input(0, 1, 3, 2, inputs, gen_results);
|
|
prepare_joint_input(2, 3, 1, 0, eval_inputs, eval_results);
|
|
P.send_receive_all(channels, send_os, receive_os);
|
|
finalize_joint_input(0, 1, 3, 2, gen_results);
|
|
finalize_joint_input(2, 3, 1, 0, eval_results);
|
|
|
|
init_mul();
|
|
for (auto& info : infos)
|
|
for (int j = 0; j < size; j++)
|
|
{
|
|
if (not info.big_gap())
|
|
prepare_mul(gen_results.next().res, eval_results.next().res);
|
|
gen_results.next();
|
|
eval_results.next();
|
|
}
|
|
|
|
if (not add_shares[0].empty())
|
|
exchange();
|
|
|
|
eval_results.reset();
|
|
gen_results.reset();
|
|
|
|
for (auto& info : infos)
|
|
for (int j = 0; j < size; j++)
|
|
{
|
|
if (info.big_gap())
|
|
proc.get_S_ref(info.dest_base + j) = eval_results.next().res
|
|
- gen_results.next().res;
|
|
else
|
|
{
|
|
auto b = gen_results.next().res + eval_results.next().res
|
|
- 2 * finalize_mul();
|
|
proc.get_S_ref(info.dest_base + j) = eval_results.next().res
|
|
- gen_results.next().res + (b << (info.k - info.m));
|
|
}
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void Rep4<T>::split(vector<T>& dest, const vector<int>& regs, int n_bits,
|
|
const U* source, int n_inputs)
|
|
{
|
|
assert(regs.size() / n_bits == 2);
|
|
assert(n_bits <= 64);
|
|
int unit = GC::Clear::N_BITS;
|
|
int my_num = P.my_num();
|
|
int i0 = -1;
|
|
|
|
switch (my_num)
|
|
{
|
|
case 0:
|
|
i0 = 1;
|
|
break;
|
|
case 1:
|
|
i0 = 0;
|
|
break;
|
|
case 2:
|
|
i0 = 1;
|
|
break;
|
|
case 3:
|
|
i0 = 0;
|
|
break;
|
|
}
|
|
|
|
vector<BitVec> to_share;
|
|
init_mul();
|
|
|
|
for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
|
|
{
|
|
int start = k * unit;
|
|
int m = min(unit, n_inputs - start);
|
|
|
|
square64 square;
|
|
|
|
for (int j = 0; j < m; j++)
|
|
{
|
|
auto& input_share = source[j + start];
|
|
auto input_value = input_share[i0] + input_share[i0 + 1];
|
|
square.rows[j] = Integer(input_value).get();
|
|
}
|
|
|
|
square.transpose(m, n_bits);
|
|
|
|
for (int j = 0; j < n_bits; j++)
|
|
{
|
|
to_share.push_back(square.rows[j]);
|
|
bit_lengths.push_back(m);
|
|
}
|
|
}
|
|
|
|
array<PointerVector<ResTuple>, 2> results;
|
|
for (auto& x : results)
|
|
x.resize(to_share.size());
|
|
prepare_joint_input(0, 1, 3, 2, to_share, results[0]);
|
|
prepare_joint_input(2, 3, 1, 0, to_share, results[1]);
|
|
P.send_receive_all(channels, send_os, receive_os);
|
|
finalize_joint_input(0, 1, 3, 2, results[0]);
|
|
finalize_joint_input(2, 3, 1, 0, results[1]);
|
|
|
|
for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++)
|
|
{
|
|
for (int j = 0; j < n_bits; j++)
|
|
for (int i = 0; i < 2; i++)
|
|
{
|
|
auto res = results[i].next().res;
|
|
dest.at(regs.at(2 * j + i) + k) = res;
|
|
}
|
|
}
|
|
}
|