mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-13 23:48:15 -05:00
695 lines
17 KiB
C++
695 lines
17 KiB
C++
/*
|
|
* Replicated.cpp
|
|
*
|
|
*/
|
|
|
|
#ifndef PROTOCOLS_REPLICATED_HPP_
|
|
#define PROTOCOLS_REPLICATED_HPP_
|
|
|
|
#include "Replicated.h"
|
|
#include "Processor/Processor.h"
|
|
#include "Processor/TruncPrTuple.h"
|
|
#include "Tools/benchmarking.h"
|
|
#include "Tools/Bundle.h"
|
|
#include "Tools/DoubleRange.h"
|
|
#include "Tools/ranges.h"
|
|
|
|
#include "ReplicatedInput.h"
|
|
#include "Rep3Share2k.h"
|
|
#include "Astra.h"
|
|
|
|
#include "ReplicatedPO.hpp"
|
|
#include "Math/Z2k.hpp"
|
|
|
|
template<class T>
|
|
ProtocolBase<T>::ProtocolBase() :
|
|
trunc_pr_counter(0), trunc_pr_big_counter(0),
|
|
rounds(0), trunc_rounds(0), dot_counter(0),
|
|
bit_counter(0), counter(0)
|
|
{
|
|
buffer_size = OnlineOptions::singleton.batch_size;
|
|
}
|
|
|
|
template<class T>
|
|
Replicated<T>::Replicated(Player& P) :
|
|
ReplicatedBase(P), fast_mode(false)
|
|
{
|
|
assert(T::vector_length == 2);
|
|
}
|
|
|
|
template<class T>
|
|
Replicated<T>::Replicated(const ReplicatedBase& other) :
|
|
Replicated(other.P)
|
|
{
|
|
}
|
|
|
|
template<class T>
|
|
Replicated<T>::~Replicated()
|
|
{
|
|
for (auto x : helper_inputs)
|
|
delete x;
|
|
output_time<T>();
|
|
}
|
|
|
|
template<class T>
|
|
void ReplicatedBase::output_time()
|
|
{
|
|
#ifdef VERBOSE
|
|
if (OnlineOptions::singleton.verbose)
|
|
{
|
|
auto time = randomness_time();
|
|
if (time)
|
|
cout << T::type_string() << " randomness: " << time << " seconds" << endl;
|
|
}
|
|
#endif
|
|
}
|
|
|
|
inline
|
|
double ReplicatedBase::randomness_time()
|
|
{
|
|
return shared_prngs[0].timer.elapsed() + shared_prngs[1].timer.elapsed();
|
|
}
|
|
|
|
template<class T>
|
|
double Replicated<T>::randomness_time()
|
|
{
|
|
auto res = ReplicatedBase::randomness_time();
|
|
for (auto& x : helper_inputs)
|
|
res += x->randomness_time();
|
|
return res;
|
|
}
|
|
|
|
inline ReplicatedBase::ReplicatedBase(Player& P) : P(P)
|
|
{
|
|
assert(P.num_players() == 3);
|
|
if (not P.is_encrypted())
|
|
insecure("unencrypted communication", false);
|
|
|
|
shared_prngs[0].ReSeed();
|
|
octetStream os;
|
|
os.append(shared_prngs[0].get_seed(), SEED_SIZE);
|
|
P.pass_around(os, 1);
|
|
shared_prngs[1].SetSeed(os.get_data());
|
|
}
|
|
|
|
inline ReplicatedBase::ReplicatedBase(Player& P, array<PRNG, 2>& prngs) :
|
|
P(P)
|
|
{
|
|
for (int i = 0; i < 2; i++)
|
|
shared_prngs[i].SetSeed(prngs[i]);
|
|
}
|
|
|
|
inline ReplicatedBase ReplicatedBase::branch() const
|
|
{
|
|
return {P, shared_prngs};
|
|
}
|
|
|
|
template<class T>
|
|
ProtocolBase<T>::~ProtocolBase()
|
|
{
|
|
#ifdef VERBOSE_COUNT
|
|
if (counter or rounds)
|
|
cerr << "Number of " << T::type_string() << " multiplications: "
|
|
<< counter << " (" << bit_counter << " bits) in " << rounds
|
|
<< " rounds" << endl;
|
|
if (counter or rounds)
|
|
cerr << "Number of " << T::type_string() << " dot products: " << dot_counter << endl;
|
|
if (trunc_pr_counter or trunc_rounds)
|
|
cerr << "Number of probabilistic truncations: " << trunc_pr_counter << " in " << trunc_rounds << " rounds" << endl;
|
|
#endif
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::mulrs(const vector<int>& reg,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
proc.mulrs(reg);
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::multiply(vector<T>& products,
|
|
vector<pair<T, T> >& multiplicands, int begin, int end,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
#ifdef VERBOSE_CENTRAL
|
|
fprintf(stderr, "multiply from %d to %d in %d\n", begin, end,
|
|
BaseMachine::thread_num);
|
|
#endif
|
|
|
|
init(proc.DataF, proc.MC);
|
|
init_mul();
|
|
for (int i = begin; i < end; i++)
|
|
prepare_mul(multiplicands[i].first, multiplicands[i].second);
|
|
exchange();
|
|
for (int i = begin; i < end; i++)
|
|
products[i] = finalize_mul();
|
|
}
|
|
|
|
template<class T>
|
|
T ProtocolBase<T>::mul(const T& x, const T& y)
|
|
{
|
|
init_mul();
|
|
prepare_mul(x, y);
|
|
exchange();
|
|
return finalize_mul();
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::prepare_mult(const T& x, const T& y, int n,
|
|
bool)
|
|
{
|
|
prepare_mul(x, y, n);
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::finalize_mult(T& res, int n)
|
|
{
|
|
res = finalize_mul(n);
|
|
}
|
|
|
|
template<class T>
|
|
T ProtocolBase<T>::finalize_dotprod(int length)
|
|
{
|
|
counter += length;
|
|
dot_counter++;
|
|
T res;
|
|
for (int i = 0; i < length; i++)
|
|
res += finalize_mul();
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
T ProtocolBase<T>::get_random()
|
|
{
|
|
if (random.empty())
|
|
{
|
|
buffer_random();
|
|
assert(not random.empty());
|
|
}
|
|
|
|
auto res = random.back();
|
|
random.pop_back();
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
vector<int> ProtocolBase<T>::get_relevant_players()
|
|
{
|
|
vector<int> res;
|
|
int n = dynamic_cast<typename T::Protocol&>(*this).P.num_players();
|
|
for (int i = 0; i < T::threshold(n) + 1; i++)
|
|
res.push_back(i);
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::init_mul()
|
|
{
|
|
if (os[1].left() or add_shares.left())
|
|
throw runtime_error("unused data in Rep3");
|
|
|
|
for (auto& o : os)
|
|
o.reset_write_head();
|
|
add_shares.clear();
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::set_fast_mode(bool change)
|
|
{
|
|
fast_mode = change;
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::prepare_mul(const T& x,
|
|
const T& y, int n)
|
|
{
|
|
typename T::value_type add_share = x.local_mul(y);
|
|
if (not T::clear::binary or fast_mode)
|
|
add_shares.push_back(add_share);
|
|
else
|
|
prepare_reshare(add_share, n);
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::prepare_mul_fast(const T& x,
|
|
const T& y)
|
|
{
|
|
add_shares.push_back(x.local_mul(y));
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::prepare_reshare(const typename T::clear& share,
|
|
int n)
|
|
{
|
|
if (T::clear::binary)
|
|
{
|
|
typename T::value_type tmp[2];
|
|
for (int i = 0; i < 2; i++)
|
|
tmp[i].randomize(shared_prngs[i], n);
|
|
auto add_share = share + tmp[0] - tmp[1];
|
|
add_shares.push_back(add_share);
|
|
add_share.pack(os[0], n);
|
|
}
|
|
else
|
|
add_shares.push_back(share);
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::prepare_exchange()
|
|
{
|
|
if (not T::clear::binary or fast_mode)
|
|
{
|
|
os[0].reserve(add_shares.size() * T::clear::size());
|
|
for (auto& add_share : add_shares)
|
|
{
|
|
add_share += shared_prngs[0];
|
|
add_share -= shared_prngs[1];
|
|
os[0].append_no_resize((octet*) add_share.get_ptr(),
|
|
add_share.size());
|
|
}
|
|
}
|
|
add_shares.reset();
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::exchange()
|
|
{
|
|
CODE_LOCATION
|
|
prepare_exchange();
|
|
os[0].append(0);
|
|
if (os[0].get_length() > 0)
|
|
P.pass_around(os[0], os[1], 1);
|
|
this->rounds++;
|
|
check_received();
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::start_exchange()
|
|
{
|
|
prepare_exchange();
|
|
os[0].append(0);
|
|
P.send_relative(1, os[0]);
|
|
this->rounds++;
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::stop_exchange()
|
|
{
|
|
P.receive_relative(-1, os[1]);
|
|
check_received();
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::check_received()
|
|
{
|
|
if (not T::clear::binary or fast_mode)
|
|
if (os[1].left() < T::clear::size() * add_shares.left())
|
|
throw runtime_error("insufficient information received in Rep3");
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::add_mul(int n)
|
|
{
|
|
// counted in SubProcessor
|
|
// this->counter++;
|
|
if (T::clear::binary)
|
|
this->bit_counter += n < 0 ? T::default_length : n;
|
|
}
|
|
|
|
template<class T>
|
|
inline T Replicated<T>::finalize_mul(int n)
|
|
{
|
|
this->add_mul(n);
|
|
T result;
|
|
result[0] = add_shares.next();
|
|
if (T::clear::binary and not fast_mode)
|
|
result[1].unpack(os[1], n);
|
|
else
|
|
result[1].assign(os[1].consume_no_check(T::clear::size()));
|
|
return result;
|
|
}
|
|
|
|
template<class T>
|
|
T Replicated<T>::finalize_mul_fast()
|
|
{
|
|
this->add_mul(-1);
|
|
T result;
|
|
result[0] = add_shares.next();
|
|
result[1].assign(os[1].consume_no_check(T::clear::size()));
|
|
return result;
|
|
}
|
|
|
|
template<class T>
|
|
inline void Replicated<T>::init_dotprod()
|
|
{
|
|
init_mul();
|
|
dotprod_share.assign_zero();
|
|
}
|
|
|
|
template<class T>
|
|
inline void Replicated<T>::prepare_dotprod(const T& x, const T& y)
|
|
{
|
|
dotprod_share = dotprod_share.lazy_add(x.local_mul(y));
|
|
}
|
|
|
|
template<class T>
|
|
inline void Replicated<T>::next_dotprod()
|
|
{
|
|
dotprod_share.normalize();
|
|
prepare_reshare(dotprod_share);
|
|
dotprod_share.assign_zero();
|
|
}
|
|
|
|
template<class T>
|
|
inline T Replicated<T>::finalize_dotprod(int length)
|
|
{
|
|
(void) length;
|
|
this->dot_counter++;
|
|
return finalize_mul();
|
|
}
|
|
|
|
template<class T>
|
|
T Replicated<T>::get_random()
|
|
{
|
|
return ReplicatedBase::get_random<typename T::open_type>();
|
|
}
|
|
|
|
template<class T>
|
|
FixedVec<T, 2> ReplicatedBase::get_random()
|
|
{
|
|
FixedVec<T, 2> res;
|
|
randomize(res);
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void ReplicatedBase::randomize(FixedVec<T, 2>& res)
|
|
{
|
|
for (int i = 0; i < 2; i++)
|
|
res[i].randomize(shared_prngs[i]);
|
|
}
|
|
|
|
template<class T>
|
|
void ProtocolBase<T>::randoms_inst(StackedVector<T>& S,
|
|
const Instruction& instruction)
|
|
{
|
|
for (int j = 0; j < instruction.get_size(); j++)
|
|
{
|
|
auto& res = S[instruction.get_r(0) + j];
|
|
randoms(res, instruction.get_n());
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void Replicated<T>::randoms(T& res, int n_bits)
|
|
{
|
|
for (int i = 0; i < 2; i++)
|
|
res[i].randomize_part(shared_prngs[i], n_bits);
|
|
}
|
|
|
|
template<class T>
|
|
ReplicatedInput<T>& Replicated<T>::get_helper_input(size_t i)
|
|
{
|
|
while (i >= helper_inputs.size())
|
|
helper_inputs.push_back(new ReplicatedInput<T>(0, *this));
|
|
return *helper_inputs.at(i);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void Replicated<T>::trunc_pr(const vector<int>& regs, int size, U& proc,
|
|
false_type)
|
|
{
|
|
CODE_LOCATION
|
|
assert(regs.size() % 4 == 0);
|
|
assert(proc.P.num_players() == 3);
|
|
assert(proc.Proc != 0);
|
|
bool generate = P.my_num() == gen_player;
|
|
bool compute = P.my_num() == comp_player;
|
|
auto& S = proc.get_S();
|
|
TruncPrTupleList<T> infos(regs, S, size);
|
|
|
|
octetStream cs;
|
|
|
|
auto& input = get_helper_input();
|
|
input.reset_all(P);
|
|
|
|
// use https://eprint.iacr.org/2018/403
|
|
bool have_big_gap = infos.have_big_gap();
|
|
|
|
for (auto info : infos)
|
|
if (info.small_gap())
|
|
{
|
|
this->trunc_pr_counter += size;
|
|
}
|
|
else
|
|
{
|
|
have_big_gap = true;
|
|
this->trunc_pr_big_counter += size;
|
|
}
|
|
|
|
if (generate)
|
|
{
|
|
SeededPRNG G;
|
|
for (auto info : infos)
|
|
{
|
|
if (info.big_gap())
|
|
{
|
|
auto dest_it = info.dest_range.begin();
|
|
cs.reserve(size * value_type::size());
|
|
for (auto& x : info.source_range)
|
|
{
|
|
auto& y = *dest_it++;
|
|
auto r = this->shared_prngs[0].template get<value_type>();
|
|
y[1] = x.sum().signed_rshift(info.m) - r;
|
|
cs.store_no_resize(y[1]);
|
|
y[0] = r;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (have_big_gap)
|
|
P.send_to(comp_player, cs);
|
|
}
|
|
|
|
if (compute)
|
|
{
|
|
if (have_big_gap)
|
|
P.receive_player(gen_player, cs);
|
|
for (auto info : infos)
|
|
{
|
|
if (info.small_gap())
|
|
for (auto& x: info.source_range)
|
|
{
|
|
auto c = x.sum() + info.add_before() - 1;
|
|
input.add_mine(c >> info.m);
|
|
input.add_mine(c.msb());
|
|
}
|
|
else
|
|
{
|
|
auto dest_it = info.dest_range.begin();
|
|
if (cs.left() < size_t(value_type::size() * size))
|
|
throw runtime_error("insufficient data in trunc_pr");
|
|
|
|
for (auto& x: info.source_range)
|
|
{
|
|
auto& y = *dest_it++;
|
|
y[0] = cs.get_no_check<value_type>();
|
|
y[1] = x[1].signed_rshift(info.m);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if (have_big_gap and not (compute or generate))
|
|
{
|
|
for (auto info : infos)
|
|
if (info.big_gap())
|
|
for (int i = 0; i < size; i++)
|
|
{
|
|
auto& x = S[info.source_base + i];
|
|
auto& y = S[info.dest_base + i];
|
|
y[0] = x[0].signed_rshift(info.m);
|
|
y[1] = this->shared_prngs[1].template get<value_type>();
|
|
}
|
|
}
|
|
|
|
#define X(I) if (P.my_num() == I) trunc_pr_finish<I>(infos, input);
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
}
|
|
|
|
template<class T>
|
|
template<int MY_NUM>
|
|
void Replicated<T>::trunc_pr_finish(TruncPrTupleList<T>& infos,
|
|
ReplicatedInput<T>& input)
|
|
{
|
|
if (infos.have_small_gap())
|
|
{
|
|
input.add_other(comp_player);
|
|
input.exchange();
|
|
auto& input2 = get_helper_input(1);
|
|
input2.reset_all(P);
|
|
|
|
for (auto info : infos)
|
|
if (info.small_gap())
|
|
{
|
|
auto it = info.source_range.begin();
|
|
for (auto& y : info.dest_range)
|
|
{
|
|
const int comp_offset = comp_player - MY_NUM;
|
|
auto c_prime = input.finalize_offset(comp_offset);
|
|
auto c_dprime = input.finalize_offset(comp_offset);
|
|
|
|
T r_prime, r_msb;
|
|
if (MY_NUM != comp_player)
|
|
{
|
|
int index = MY_NUM == 2 ? 0 : 1;
|
|
auto r = (*it++)[index];
|
|
r_prime[index] = r >> info.m;
|
|
r_msb[index] = r.msb();
|
|
input2.add_mine(r_msb.local_mul(c_dprime));
|
|
}
|
|
|
|
y = c_prime + r_prime;
|
|
y -= info.correction_shift(r_msb + c_dprime);
|
|
}
|
|
}
|
|
|
|
input2.add_other(0);
|
|
input2.add_other(2);
|
|
input2.exchange();
|
|
|
|
for (auto info : infos)
|
|
if (info.small_gap())
|
|
for (auto &y : info.dest_range)
|
|
y += info.correction_shift(input2.finalize_offset(2 - MY_NUM)
|
|
+ input2.finalize_offset(0 - MY_NUM))
|
|
- T::constant(info.subtract_after() - 1, MY_NUM);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void Replicated<T>::trunc_pr(const vector<int>& regs, int size, U& proc,
|
|
true_type)
|
|
{
|
|
(void) regs, (void) size, (void) proc;
|
|
throw runtime_error("trunc_pr not implemented");
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void Replicated<T>::trunc_pr(const vector<int>& regs, int size,
|
|
U& proc)
|
|
{
|
|
this->trunc_rounds++;
|
|
trunc_pr(regs, size, proc, T::clear::characteristic_two);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void Replicated<T>::unsplit(StackedVector<T>& dest,
|
|
StackedVector<typename T::bit_type>& source,
|
|
const Instruction& instruction)
|
|
{
|
|
CODE_LOCATION
|
|
int n_bits = instruction.get_size();
|
|
assert(instruction.get_start().size() > 0);
|
|
assert(instruction.get_start().size() <= 2);
|
|
auto& P = this->P;
|
|
|
|
auto& input = get_helper_input();
|
|
input.reset_all(P);
|
|
|
|
if (P.my_num() == 0)
|
|
{
|
|
auto range = BlockRange(source, instruction.get_r(0), n_bits);
|
|
input.prepare(n_bits);
|
|
auto dest_it = dest.iterator_for_size(instruction.get_start()[0],
|
|
n_bits);
|
|
for (auto& x : range)
|
|
for (auto bit : BitLeftRange(x.sum(), x, range))
|
|
input.add_mine_prepared(*dest_it++, bit);
|
|
}
|
|
else
|
|
input.add_other(0);
|
|
|
|
#define X(I) if (P.my_num() == I) unsplit_finish<I>(dest, source, instruction);
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
}
|
|
|
|
template<class T>
|
|
template<int MY_NUM>
|
|
void Replicated<T>::unsplit_finish(StackedVector<T>& dest,
|
|
StackedVector<typename T::bit_type>& source,
|
|
const Instruction& instruction)
|
|
{
|
|
int n_bits = instruction.get_size();
|
|
|
|
auto& input = get_helper_input();
|
|
input.exchange();
|
|
auto range = BlockRange(source, instruction.get_r(0), n_bits);
|
|
|
|
if (instruction.get_start().size() == 2)
|
|
{
|
|
if (MY_NUM != 0)
|
|
{
|
|
for (auto it : DoubleRange(dest, instruction.get_start()[0],
|
|
instruction.get_start()[1], n_bits))
|
|
{
|
|
auto& x = it.first;
|
|
x = input.finalize_offset(-MY_NUM);
|
|
it.second = {};
|
|
}
|
|
|
|
auto dest_it = dest.iterator_for_size(instruction.get_start()[1], n_bits);
|
|
for (auto& x : range)
|
|
for (auto bit : BitLeftRange(x[MY_NUM - 1], x, range))
|
|
(*dest_it++)[MY_NUM - 1] = bit;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
auto& input2 = get_helper_input(1);
|
|
input2.reset_all(P);
|
|
|
|
if (MY_NUM != 0)
|
|
{
|
|
auto dest_it = dest.iterator_for_size(instruction.get_start()[0],
|
|
n_bits);
|
|
input2.prepare(n_bits);
|
|
for (auto& x : range)
|
|
for (auto bit : BitLeftRange(x[MY_NUM - 1], x, range))
|
|
{
|
|
T a, c;
|
|
a[MY_NUM - 1] = bit;
|
|
T b = input.finalize_offset(-MY_NUM);
|
|
input2.add_mine_prepared(c, a.local_mul(b));
|
|
(*dest_it++) = a + b - 2 * c;
|
|
}
|
|
}
|
|
|
|
input2.add_other(1);
|
|
input2.add_other(2);
|
|
input2.exchange();
|
|
|
|
auto dest_range = Range(dest, instruction.get_start()[0], n_bits);
|
|
if (MY_NUM == 0)
|
|
for (auto& x : dest_range)
|
|
{
|
|
x -= 2
|
|
* (input2.finalize_offset(1 - MY_NUM)
|
|
+ input2.finalize_offset(2 - MY_NUM));
|
|
}
|
|
else
|
|
{
|
|
for (auto& x : dest_range)
|
|
x -= 2 * input2.finalize_offset(MY_NUM == 1 ? 1 : -1);
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif
|