Files
MP-SPDZ/Protocols/Replicated.hpp
2023-12-14 12:17:54 +11:00

429 lines
10 KiB
C++

/*
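* Replicated.hpp
*
* Template implementations for ProtocolBase<T>, ReplicatedBase, and the
* three-party protocol Replicated<T>.
*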
*/
#ifndef PROTOCOLS_REPLICATED_HPP_
#define PROTOCOLS_REPLICATED_HPP_
#include "Replicated.h"
#include "Processor/Processor.h"
#include "Processor/TruncPrTuple.h"
#include "Tools/benchmarking.h"
#include "Tools/Bundle.h"
#include "ReplicatedInput.h"
#include "Rep3Share2k.h"
#include "ReplicatedPO.hpp"
#include "Math/Z2k.hpp"
template<class T>
ProtocolBase<T>::ProtocolBase() :
trunc_pr_counter(0), rounds(0), trunc_rounds(0), dot_counter(0),
bit_counter(0), counter(0)
{
}
// Set up the replicated protocol over player set P (the ReplicatedBase
// part performs the seed exchange). The share type must have exactly two
// components; this file addresses them as [0] and [1].
template<class T>
Replicated<T>::Replicated(Player& P) :
        ReplicatedBase(P)
{
    assert(2 == T::vector_length);
}
template<class T>
Replicated<T>::Replicated(const ReplicatedBase& other) :
ReplicatedBase(other)
{
}
// Three-party setup: pick a fresh seed for shared_prngs[0], send it to the
// next player (relative +1), and seed shared_prngs[1] with the seed received
// from the previous player (relative -1). Afterwards this party's
// shared_prngs[0] matches the next party's shared_prngs[1].
// Unencrypted channels are reported through insecure().
inline ReplicatedBase::ReplicatedBase(Player& P) :
        P(P)
{
    assert(P.num_players() == 3);
    if (!P.is_encrypted())
        insecure("unencrypted communication", false);

    auto& next_prng = shared_prngs[0];
    auto& prev_prng = shared_prngs[1];

    next_prng.ReSeed();
    octetStream seed_msg;
    seed_msg.append(next_prng.get_seed(), SEED_SIZE);
    P.send_relative(1, seed_msg);

    // The same stream is reused to receive the previous player's seed.
    P.receive_relative(-1, seed_msg);
    prev_prng.SetSeed(seed_msg.get_data());
}
// Build a ReplicatedBase without any communication: each shared PRNG is
// seeded from the corresponding PRNG in `prngs` (used by branch()).
// NOTE(review): SetSeed(PRNG&) presumably draws the seed from the source
// PRNG and so advances it — confirm in the PRNG implementation.
inline ReplicatedBase::ReplicatedBase(Player& P, array<PRNG, 2>& prngs) :
        P(P)
{
    shared_prngs[0].SetSeed(prngs[0]);
    shared_prngs[1].SetSeed(prngs[1]);
}
// Derive a new ReplicatedBase for the same player whose shared PRNGs are
// seeded from this instance's shared PRNGs (two-argument constructor).
inline ReplicatedBase ReplicatedBase::branch() const
{
    return ReplicatedBase(P, shared_prngs);
}
// With VERBOSE_COUNT defined, print the collected protocol statistics to
// stderr when the protocol object is destroyed. Otherwise this does nothing.
template<class T>
ProtocolBase<T>::~ProtocolBase()
{
#ifdef VERBOSE_COUNT
    // Multiplication and dot-product counts share the same guard.
    if (counter or rounds)
    {
        cerr << "Number of " << T::type_string() << " multiplications: "
                << counter << " (" << bit_counter << " bits) in " << rounds
                << " rounds" << endl;
        cerr << "Number of " << T::type_string() << " dot products: "
                << dot_counter << endl;
    }
    if (trunc_pr_counter or trunc_rounds)
        cerr << "Number of probabilistic truncations: " << trunc_pr_counter
                << " in " << trunc_rounds << " rounds" << endl;
#endif
}
// Execute the mulrs instruction by forwarding the register arguments
// unchanged to SubProcessor<T>::mulrs.
template<class T>
void ProtocolBase<T>::mulrs(const vector<int>& reg,
SubProcessor<T>& proc)
{
proc.mulrs(reg);
}
// Multiply the pairs multiplicands[begin..end) and store the products in
// products[begin..end). All products are prepared first and then resolved
// after a single call to exchange().
template<class T>
void ProtocolBase<T>::multiply(vector<T>& products,
        vector<pair<T, T> >& multiplicands, int begin, int end,
        SubProcessor<T>& proc)
{
#ifdef VERBOSE_CENTRAL
    fprintf(stderr, "multiply from %d to %d in %d\n", begin, end,
            BaseMachine::thread_num);
#endif
    init(proc.DataF, proc.MC);
    init_mul();

    for (int idx = begin; idx < end; ++idx)
    {
        auto& operands = multiplicands[idx];
        prepare_mul(operands.first, operands.second);
    }

    exchange();

    // finalize_mul() is consumed in the same order as prepare_mul() was called.
    for (int idx = begin; idx < end; ++idx)
        products[idx] = finalize_mul();
}
// Multiply a single pair of shares, using a full prepare/exchange/finalize
// cycle (and hence one exchange) for this one product.
template<class T>
T ProtocolBase<T>::mul(const T& x, const T& y)
{
    init_mul();
    prepare_mul(x, y);
    exchange();
    T product = finalize_mul();
    return product;
}
// Adapter with the prepare_mult() signature: queues x * y through
// prepare_mul(x, y, n). The trailing bool argument is ignored here.
template<class T>
void ProtocolBase<T>::prepare_mult(const T& x, const T& y, int n,
bool)
{
prepare_mul(x, y, n);
}
// Adapter with an output-parameter signature: stores the next
// finalize_mul(n) result in res.
template<class T>
void ProtocolBase<T>::finalize_mult(T& res, int n)
{
res = finalize_mul(n);
}
// Generic dot-product completion: sums the next `length` results of
// finalize_mul(). Each term is counted as a multiplication and the whole
// as one dot product.
template<class T>
T ProtocolBase<T>::finalize_dotprod(int length)
{
    counter += length;
    dot_counter++;
    T sum;
    for (int remaining = length; remaining > 0; remaining--)
        sum += finalize_mul();
    return sum;
}
// Hand out one random share from the `random` buffer (taken from the back),
// refilling the buffer via buffer_random() first if it is empty.
template<class T>
T ProtocolBase<T>::get_random()
{
    bool need_refill = random.empty();
    if (need_refill)
        buffer_random();
    // A refill must have produced at least one element.
    assert(not (need_refill and random.empty()));

    auto next = random.back();
    random.pop_back();
    return next;
}
// Return player numbers 0 .. threshold(n), i.e. the first threshold(n) + 1
// players, where n is the player count of the concrete protocol.
// The dynamic_cast throws std::bad_cast if *this is not a T::Protocol.
template<class T>
vector<int> ProtocolBase<T>::get_relevant_players()
{
    auto& protocol = dynamic_cast<typename T::Protocol&>(*this);
    int n = protocol.P.num_players();
    const auto limit = T::threshold(n) + 1;

    vector<int> res;
    for (int player = 0; player < limit; player++)
        res.push_back(player);
    return res;
}
// Start a new batch of multiplications: drop locally queued shares and
// rewind every communication buffer for writing.
template<class T>
void Replicated<T>::init_mul()
{
    add_shares.clear();
    for (auto& stream : os)
        stream.reset_write_head();
}
// Queue x * y: compute this party's part of the product locally with
// local_mul() (no communication), then mask and queue it for resharing.
template<class T>
void Replicated<T>::prepare_mul(const T& x,
        const T& y, int n)
{
    using share_value = typename T::value_type;
    share_value local_product = x.local_mul(y);
    prepare_reshare(local_product, n);
}
// Mask a locally held value with correlated randomness and queue it:
// the masked value is packed into os[0] (the outgoing buffer of exchange()
// / start_exchange()) and also kept in add_shares for finalize_mul().
//
// shared_prngs[0] is seeded like the next party's shared_prngs[1] (see the
// ReplicatedBase constructor). So as long as all parties draw in lockstep,
// the masks (mask_next - mask_prev) sum to zero over the three parties.
template<class T>
void Replicated<T>::prepare_reshare(const typename T::clear& share,
        int n)
{
    typename T::value_type mask_next;
    typename T::value_type mask_prev;
    mask_next.randomize(shared_prngs[0], n);
    mask_prev.randomize(shared_prngs[1], n);

    auto masked = share + mask_next - mask_prev;
    masked.pack(os[0], n);
    add_shares.push_back(masked);
}
// Send the queued masked values to the neighbouring party and receive the
// counterpart into os[1] via pass_around() with offset 1. Communication is
// skipped when the outgoing buffer is empty. Counts as one round either way.
// NOTE(review): the effect of os[0].append(0) on get_length() is not
// visible here (presumably zero bytes / buffer allocation) — confirm
// octetStream::append semantics.
template<class T>
void Replicated<T>::exchange()
{
    auto& outgoing = os[0];
    auto& incoming = os[1];
    outgoing.append(0);
    bool have_data = outgoing.get_length() > 0;
    if (have_data)
        P.pass_around(outgoing, incoming, 1);
    this->rounds++;
}
// First half of a split exchange: send the outgoing buffer to the next
// player (relative +1) unconditionally and count the round. The matching
// receive happens in stop_exchange().
template<class T>
void Replicated<T>::start_exchange()
{
    auto& outgoing = os[0];
    outgoing.append(0);
    P.send_relative(1, outgoing);
    this->rounds++;
}
// Second half of a split exchange (see start_exchange()): receive from the
// previous player (relative -1) into os[1], where finalize_mul() unpacks it.
// The round counter is not touched here; start_exchange() already counted it.
template<class T>
void Replicated<T>::stop_exchange()
{
P.receive_relative(-1, os[1]);
}
// Complete the next queued multiplication. Component [0] is this party's
// own masked value (queued by prepare_reshare()), component [1] is unpacked
// from the data received into os[1]. Results must be taken in the same
// order in which the products were prepared.
template<class T>
inline T Replicated<T>::finalize_mul(int n)
{
    this->counter++;
    this->bit_counter += n;

    T product;
    product[0] = add_shares.next();
    product[1].unpack(os[1], n);
    return product;
}
// Start a batch of dot products: zero the running accumulator and reset
// the multiplication state.
template<class T>
inline void Replicated<T>::init_dotprod()
{
    dotprod_share.assign_zero();
    init_mul();
}
// Add one term x * y of the current dot product to the local accumulator
// (lazy addition, no communication).
template<class T>
inline void Replicated<T>::prepare_dotprod(const T& x, const T& y)
{
    auto term = x.local_mul(y);
    dotprod_share = dotprod_share.lazy_add(term);
}
// Close the current dot product: normalize the lazily accumulated sum
// (presumably bringing it back to canonical form after lazy_add), reshare
// it as a single masked value like one multiplication, and reset the
// accumulator for the next dot product.
template<class T>
inline void Replicated<T>::next_dotprod()
{
dotprod_share.normalize();
prepare_reshare(dotprod_share);
dotprod_share.assign_zero();
}
// Finish a dot product queued via prepare_dotprod()/next_dotprod(). All
// terms were summed locally and reshared once, so exactly one
// finalize_mul() result is consumed, whatever `length` is.
template<class T>
inline T Replicated<T>::finalize_dotprod(int length)
{
    static_cast<void>(length);
    ++this->dot_counter;
    return finalize_mul();
}
// Produce a random replicated share without communication: each component
// is drawn from the PRNG shared with the corresponding neighbour.
template<class T>
T Replicated<T>::get_random()
{
    T res;
    res[0].randomize(shared_prngs[0]);
    res[1].randomize(shared_prngs[1]);
    return res;
}
// Handle a random-share instruction: fill get_size() consecutive registers
// of S, starting at the instruction's first register argument, by calling
// randoms() with the bit count from instruction.get_n().
template<class T>
void ProtocolBase<T>::randoms_inst(vector<T>& S,
        const Instruction& instruction)
{
    auto first_reg = instruction.get_r(0);
    auto n_bits = instruction.get_n();
    for (int offset = 0; offset < instruction.get_size(); offset++)
        randoms(S[first_reg + offset], n_bits);
}
// Fill res with a random replicated share without communication; each
// component comes from the PRNG shared with the corresponding neighbour.
// randomize_part presumably restricts the value to n_bits bits — confirm.
template<class T>
void Replicated<T>::randoms(T& res, int n_bits)
{
    res[0].randomize_part(shared_prngs[0], n_bits);
    res[1].randomize_part(shared_prngs[1], n_bits);
}
// Probabilistic truncation for three parties when the cleartext domain is
// not of characteristic two (selected by the false_type tag).
//
// regs.size() must be a multiple of four; it is parsed into
// TruncPrTupleWithGap entries (presumably one per group of four arguments)
// giving dest_base, source_base and the k/m parameters used below. Each
// entry covers `size` consecutive registers of S.
//
// Roles are fixed: player 2 (gen_player) generates randomness, player 1
// (comp_player) computes on masked values, player 0 does neither.
// Each entry uses one of two methods depending on info.small_gap():
// - small gap: https://eprint.iacr.org/2019/131, using ReplicatedInput plus
//   one multiplication per element;
// - big gap: https://eprint.iacr.org/2018/403, local shifting plus one
//   message from player 2 to player 1.
// Only small-gap elements are counted in trunc_pr_counter.
template<class T>
template<class U>
void Replicated<T>::trunc_pr(const vector<int>& regs, int size, U& proc,
false_type)
{
assert(regs.size() % 4 == 0);
assert(proc.P.num_players() == 3);
assert(proc.Proc != 0);
typedef typename T::clear value_type;
int gen_player = 2;
int comp_player = 1;
bool generate = P.my_num() == gen_player;
bool compute = P.my_num() == comp_player;
ArgList<TruncPrTupleWithGap<value_type>> infos(regs);
auto& S = proc.get_S();
octetStream cs;
ReplicatedInput<T> input(0, *this);
// use https://eprint.iacr.org/2019/131
bool have_small_gap = false;
// use https://eprint.iacr.org/2018/403
bool have_big_gap = false;
// Determine which methods are needed, so that communication steps are only
// performed when some entry requires them.
for (auto info : infos)
if (info.small_gap())
have_small_gap = true;
else
have_big_gap = true;
if (generate)
{
// Player 2: r for small gaps comes from a local PRNG (known only to
// player 2). Everything destined for player 1 is packed into cs in
// entry/element order.
SeededPRNG G;
for (auto info : infos)
for (int i = 0; i < size; i++)
{
auto& x = S[info.source_base + i];
if (info.small_gap())
{
// Secret-share upper(r) and msb(r) (in that order), and send
// r + x[0] so that player 1 can form the masked value.
auto r = G.get<value_type>();
input.add_mine(info.upper(r));
input.add_mine(info.msb(r));
(r + x[0]).pack(cs);
}
else
{
// r comes from shared_prngs[0], which is seeded like player 0's
// shared_prngs[1], so player 0 draws the same r below (provided
// the PRNGs are in sync). New share of player 2:
// (r, -((-x.sum()) >> m) - r); the second component is also
// sent to player 1.
auto& y = S[info.dest_base + i];
auto r = this->shared_prngs[0].template get<value_type>();
y[1] = -value_type(-value_type(x.sum()) >> info.m) - r;
y[1].pack(cs);
y[0] = r;
}
}
P.send_to(comp_player, cs);
}
else if (have_small_gap)
input.add_other(gen_player);
if (compute)
{
// Player 1 reads cs in exactly the order player 2 wrote it.
P.receive_player(gen_player, cs);
for (auto info : infos)
for (int i = 0; i < size; i++)
{
auto& x = S[info.source_base + i];
if (info.small_gap())
{
// c = (r + player 2's x[0]) + x.sum(); presumably the masked
// value r + x, assuming player 2's x[0] is the component player
// 1 does not hold (TODO confirm share layout). Secret-share
// upper(c) and msb(c), in that order.
auto c = cs.get<value_type>() + x.sum();
input.add_mine(info.upper(c));
input.add_mine(info.msb(c));
}
else
{
// New share of player 1: (value from player 2, x[1] >> m).
auto& y = S[info.dest_base + i];
y[0] = cs.get<value_type>();
y[1] = x[1] >> info.m;
}
}
}
if (have_big_gap and not (compute or generate))
{
// Player 0, big gap only: shift its own x[0] locally and draw the same
// r that player 2 drew from its shared_prngs[0].
for (auto info : infos)
if (info.big_gap())
for (int i = 0; i < size; i++)
{
auto& x = S[info.source_base + i];
auto& y = S[info.dest_base + i];
y[0] = x[0] >> info.m;
y[1] = this->shared_prngs[1].template get<value_type>();
}
}
if (have_small_gap)
{
// NOTE(review): unlike the `else if` guard used for gen_player above,
// this also runs on comp_player itself; presumably harmless for
// ReplicatedInput — confirm.
input.add_other(comp_player);
input.exchange();
init_mul();
for (auto info : infos)
for (int i = 0; i < size; i++)
{
if (info.small_gap())
{
this->trunc_pr_counter++;
// finalize() returns each player's inputs in the order they
// were added: upper value first, then msb.
auto c_prime = input.finalize(comp_player);
auto r_prime = input.finalize(gen_player);
S[info.dest_base + i] = c_prime - r_prime;
auto c_dprime = input.finalize(comp_player);
auto r_msb = input.finalize(gen_player);
// Together with the product subtracted below, this adds
// (a + b - 2ab) << (k - m) for a = msb(r), b = msb(c), which
// is the XOR of the two MSBs when they are bits.
S[info.dest_base + i] += ((r_msb + c_dprime)
<< (info.k - info.m));
prepare_mul(r_msb, c_dprime);
}
}
// One round for all MSB products; they are finalized in the order in
// which they were prepared.
exchange();
for (auto info : infos)
for (int i = 0; i < size; i++)
if (info.small_gap())
S[info.dest_base + i] -= finalize_mul()
<< (info.k - info.m + 1);
}
}
// Probabilistic truncation is not implemented for characteristic-two
// domains (true_type tag); this overload always throws.
template<class T>
template<class U>
void Replicated<T>::trunc_pr(const vector<int>& regs, int size, U& proc,
        true_type)
{
    (void) regs;
    (void) size;
    (void) proc;
    throw runtime_error("trunc_pr not implemented");
}
// Entry point for probabilistic truncation: count one truncation round,
// then pick the implementation by tag dispatch on whether the cleartext
// domain has characteristic two (T::clear::characteristic_two).
template<class T>
template<class U>
void Replicated<T>::trunc_pr(const vector<int>& regs, int size,
        U& proc)
{
    this->trunc_rounds++;
    auto domain_tag = T::clear::characteristic_two;
    trunc_pr(regs, size, proc, domain_tag);
}
#endif