Files
MP-SPDZ/Processor/ReplicatedPrep.hpp
2019-04-30 17:25:02 +10:00

314 lines
8.4 KiB
C++

/*
* ReplicatedPrep.hpp
*
*/
#include "ReplicatedPrep.h"
#include "Math/gfp.h"
/**
 * Set up a preprocessing buffer for an optional sub-processor.
 *
 * No protocol is attached yet (the pointer starts out null) and the base
 * player starts at zero; set_protocol() installs the protocol afterwards.
 */
template<class T>
RingPrep<T>::RingPrep(SubProcessor<T>* proc, DataPositions& usage) :
        BufferPrep<T>(usage),
        protocol(nullptr),
        proc(proc),
        base_player(0)
{
}
/**
 * Attach the protocol instance used for generating preprocessing data.
 *
 * When a sub-processor is present, the base player is additionally set to
 * that processor's thread number.
 */
template<class T>
void RingPrep<T>::set_protocol(typename T::Protocol& protocol)
{
    // The parameter hides the member of the same name, hence the this->.
    this->protocol = &protocol;

    if (proc == nullptr)
        return;

    base_player = proc->Proc.thread_num;
}
/**
 * Refill the triple buffer with buffer_size entries via generate_triples().
 * A protocol must have been attached beforehand.
 */
template<class T>
void ReplicatedRingPrep<T>::buffer_triples()
{
    auto attached = this->protocol;
    assert(attached != 0);
    generate_triples(this->triples, this->buffer_size, attached);
}
/**
 * Resize `triples` to `n_triples` entries and fill every entry.
 *
 * Components 0 and 1 come from protocol->get_random(); component 2 is the
 * result the protocol hands back from finalize_mul() for that pair after a
 * single exchange() round covering all entries.
 */
template<class T>
void generate_triples(vector<array<T, 3>>& triples, int n_triples,
        typename T::Protocol* protocol)
{
    triples.resize(n_triples);

    // Queue one multiplication per entry ...
    protocol->init_mul();
    for (auto& entry : triples)
    {
        entry[0] = protocol->get_random();
        entry[1] = protocol->get_random();
        protocol->prepare_mul(entry[0], entry[1]);
    }

    // ... run the communication once ...
    protocol->exchange();

    // ... and collect the results in the same order they were queued.
    for (auto& entry : triples)
        entry[2] = protocol->finalize_mul();
}
/**
 * Pop one buffered triple into (a, b, c), refilling the buffer first if it
 * is empty.
 *
 * Only DATA_TRIPLE is supported; any other dtype raises not_implemented.
 */
template<class T>
void BufferPrep<T>::get_three_no_count(Dtype dtype, T& a, T& b, T& c)
{
    if (dtype != DATA_TRIPLE)
        throw not_implemented();

    if (triples.empty())
        buffer_triples();

    // Triples are consumed from the back of the buffer.
    const auto& next = triples.back();
    a = next[0];
    b = next[1];
    c = next[2];
    triples.pop_back();
}
/**
 * Produce buffer_size square pairs out of buffered triples.
 *
 * For each triple (a, b, c) the sum a + b is opened as e, and the pair
 * (a, a * e - c) is appended to the square buffer. Assuming c = a * b,
 * the second component equals a * a.
 * Requires a sub-processor, whose MC and P are used for the opening.
 */
template<class T>
void RingPrep<T>::buffer_squares()
{
    auto processor = this->proc;
    auto count = this->buffer_size;
    assert(processor != 0);

    vector<T> masked(count), first(count), product(count);
    T second;
    for (int idx = 0; idx < count; idx++)
    {
        this->get_three_no_count(DATA_TRIPLE, first[idx], second,
                product[idx]);
        masked[idx] = first[idx] + second;
    }

    // One batched opening for all sums.
    vector<typename T::open_type> revealed(count);
    processor->MC.POpen(revealed, masked, processor->P);

    for (int idx = 0; idx < count; idx++)
        this->squares.push_back(
                {{first[idx], first[idx] * revealed[idx] - product[idx]}});
}
/**
 * Refill the square buffer with buffer_size pairs using the protocol
 * directly: component 0 is a random value from the protocol and component 1
 * is what finalize_mul() returns for that value multiplied with itself.
 */
template<class T>
void ReplicatedRingPrep<T>::buffer_squares()
{
    auto prot = this->protocol;
    assert(prot != 0);

    auto& pairs = this->squares;
    pairs.resize(this->buffer_size);

    // Queue all self-multiplications, then exchange once.
    prot->init_mul(this->proc);
    for (auto& pair : pairs)
    {
        pair[0] = prot->get_random();
        prot->prepare_mul(pair[0], pair[0]);
    }
    prot->exchange();

    // Results are collected in queueing order.
    for (auto& pair : pairs)
        pair[1] = prot->finalize_mul();
}
/**
 * Refill the inverse buffer by delegating to the generic
 * BufferPrep<T>::buffer_inverses() with a freshly constructed MAC checker
 * and the protocol's player object.
 */
template<class T>
void ReplicatedPrep<T>::buffer_inverses()
{
    assert(this->protocol != 0);

    typename T::MAC_Check MC;
    BufferPrep<T>::buffer_inverses(MC, this->protocol->P);
}
/**
 * Derive inverse pairs from buffer_size triples.
 *
 * The third component of every triple is opened; for each non-zero opened
 * value v, the pair (triple[0], triple[1] / v) is appended to the inverse
 * buffer. Assuming triple[2] = triple[0] * triple[1], the second component
 * is the inverse of the first.
 *
 * @throws runtime_error if the inverse buffer is still empty afterwards.
 */
template<class T>
void BufferPrep<T>::buffer_inverses(MAC_Check_Base<T>& MC, Player& P)
{
    // Named differently from the member buffer `triples`, which
    // get_three_no_count() consumes from.
    vector<array<T, 3>> fresh(buffer_size);
    vector<T> products;
    for (auto& triple : fresh)
    {
        get_three_no_count(DATA_TRIPLE, triple[0], triple[1], triple[2]);
        products.push_back(triple[2]);
    }

    vector<typename T::open_type> opened;
    MC.POpen(opened, products, P);

    // Zero products cannot be divided by, so those triples are dropped.
    for (size_t idx = 0; idx < products.size(); idx++)
    {
        if (opened[idx] != 0)
            inverses.push_back({{fresh[idx][0], fresh[idx][1] / opened[idx]}});
    }
    fresh.clear();

    if (inverses.empty())
        throw runtime_error("products were all zero");

    MC.Check(P);
}
/**
 * Pop one buffered pair into (a, b).
 *
 * DATA_SQUARE takes from the square buffer (refilled once if empty);
 * DATA_INVERSE takes from the inverse buffer (refilled until non-empty).
 * Any other dtype raises not_implemented.
 */
template<class T>
void BufferPrep<T>::get_two_no_count(Dtype dtype, T& a, T& b)
{
    if (dtype == DATA_SQUARE)
    {
        if (squares.empty())
            buffer_squares();
        const auto& pair = squares.back();
        a = pair[0];
        b = pair[1];
        squares.pop_back();
        return;
    }

    if (dtype == DATA_INVERSE)
    {
        // Keep refilling while the buffer is still empty.
        while (inverses.empty())
            buffer_inverses();
        const auto& pair = inverses.back();
        a = pair[0];
        b = pair[1];
        inverses.pop_back();
        return;
    }

    throw not_implemented();
}
/**
 * Element-wise res[i] = x[i] + y[i] - 2 * (x[i] * y[i]) for the first
 * buffer_size entries, with the products computed by `prot` in a single
 * exchange round. For inputs holding 0/1 values this is the XOR.
 *
 * `res` may be the same vector as `x`: all multiplications are queued
 * before `res` is resized or written, and each res[i] is written only
 * after x[i] has been read.
 */
template<class T>
void XOR(vector<T>& res, vector<T>& x, vector<T>& y, int buffer_size,
        typename T::Protocol& prot, SubProcessor<T>* proc)
{
    typedef typename T::clear clear_type;

    prot.init_mul(proc);
    for (int idx = 0; idx < buffer_size; idx++)
        prot.prepare_mul(x[idx], y[idx]);
    prot.exchange();

    res.resize(buffer_size);
    clear_type two = clear_type(1) + clear_type(1);
    for (int idx = 0; idx < buffer_size; idx++)
    {
        auto product = prot.finalize_mul();
        res[idx] = x[idx] + y[idx] - product * two;
    }
}
/**
 * Bit generation for shares over gfp.
 *
 * With at most 10 relevant players, this defers to
 * ReplicatedRingPrep::buffer_bits(). Otherwise bits are derived from
 * squares: for each pair (a, s) obtained via DATA_SQUARE, s is opened and,
 * if non-zero, (a / sqrt(s) + 1) / 2 is appended to `bits`.
 *
 * @throws runtime_error if `bits` is still empty after the square-based path.
 */
template<template<class U> class T>
void buffer_bits_spec(ReplicatedPrep<T<gfp>>& prep, vector<T<gfp>>& bits,
        typename T<gfp>::Protocol& prot)
{
    if (prot.get_n_relevant_players() <= 10)
    {
        prep.ReplicatedRingPrep<T<gfp>>::buffer_bits();
        return;
    }

    vector<array<T<gfp>, 2>> pairs(prep.buffer_size);
    vector<T<gfp>> square_shares;
    for (auto& pair : pairs)
    {
        prep.get_two(DATA_SQUARE, pair[0], pair[1]);
        square_shares.push_back(pair[1]);
    }

    vector<gfp> opened;
    typename T<gfp>::MAC_Check().POpen(opened, square_shares, prot.P);

    auto one = T<gfp>(1, prot.P.my_num());
    for (size_t idx = 0; idx < square_shares.size(); idx++)
    {
        // A zero square has no usable root to divide by; skip it.
        if (opened[idx] != 0)
            bits.push_back((pairs[idx][0] / opened[idx].sqrRoot() + one) / 2);
    }
    pairs.clear();

    if (bits.empty())
        throw runtime_error("squares were all zero");
}
/**
 * Generate buffer_size bits by combining bits input by several players.
 *
 * Each of the protocol's "relevant" players, starting from base_player and
 * wrapping around the player count, inputs buffer_size bits drawn from a
 * SeededPRNG. The per-player share vectors are then folded together with
 * XOR() into the bit buffer. base_player is incremented at the end so the
 * next call starts from a different player.
 */
template<class T>
void RingPrep<T>::buffer_bits()
{
    assert(protocol != 0);
    auto& prot = *protocol;
    auto& P = prot.P;
    auto n_bits = this->buffer_size;
    int n_inputters = prot.get_n_relevant_players();

    vector<vector<T>> shares(n_inputters, vector<T>(n_bits));
    typename T::Input input(proc, P);
    for (int player = 0; player < P.num_players(); player++)
        input.reset(player);

    for (int k = 0; k < n_inputters; k++)
    {
        auto& column = shares[k];
        int sender = (base_player + k) % P.num_players();
        if (sender == P.my_num())
        {
            // This party's turn: sample the bits locally and send them.
            SeededPRNG G;
            for (int j = 0; j < n_bits; j++)
                input.add_mine(G.get_bit());
            input.send_mine();
            for (auto& share : column)
                share = input.finalize_mine();
        }
        else
        {
            // Another party's turn: receive its message and finalize.
            for (int j = 0; j < n_bits; j++)
                input.add_other(sender);
            octetStream os;
            P.receive_player(sender, os, true);
            for (auto& share : column)
                input.finalize_other(sender, share, os);
        }
    }

    // Fold all contributions together; the first two seed the accumulator.
    auto& result = this->bits;
    XOR(result, shares[0], shares[1], n_bits, prot, proc);
    for (int k = 2; k < n_inputters; k++)
        XOR(result, result, shares[k], n_bits, prot, proc);

    base_player++;
}
/**
 * Specialization for Rep3Share<gf2n>: bits are taken from random shares.
 *
 * Each random share from the protocol yields gf2n::degree() bits, taken
 * from the lowest position upwards via `& 1` and `>>= 1`. Enough shares
 * are drawn to cover buffer_size bits, rounded up to whole shares.
 */
template<>
inline
void RingPrep<Rep3Share<gf2n>>::buffer_bits()
{
    assert(protocol != 0);

    const auto bits_per_share = gf2n::degree();
    const auto n_shares = DIV_CEIL(buffer_size, gf2n::degree());
    for (int s = 0; s < n_shares; s++)
    {
        Rep3Share<gf2n> share = protocol->get_random();
        for (int pos = 0; pos < bits_per_share; pos++)
        {
            bits.push_back(share & 1);
            share >>= 1;
        }
    }
}
/**
 * Bit generation for shares over gf2n: always defers to
 * ReplicatedRingPrep::buffer_bits(). `bits` and `prot` are unused here.
 */
template<template<class U> class T>
void buffer_bits_spec(ReplicatedPrep<T<gf2n>>& prep, vector<T<gf2n>>& bits,
        typename T<gf2n>::Protocol& prot)
{
    // Parameters kept only to mirror the gfp overload's signature.
    (void) bits;
    (void) prot;

    prep.ReplicatedRingPrep<T<gf2n>>::buffer_bits();
}
/**
 * Refill the bit buffer through the buffer_bits_spec() overload matching
 * the share type. A protocol must have been attached beforehand.
 */
template<class T>
void ReplicatedPrep<T>::buffer_bits()
{
    assert(this->protocol != 0);

    auto& prot = *this->protocol;
    buffer_bits_spec(*this, this->bits, prot);
}
/**
 * Pop one buffered bit into `a`, refilling the bit buffer for as long as it
 * is empty.
 *
 * Only DATA_BIT is supported; any other dtype raises not_implemented.
 */
template<class T>
void BufferPrep<T>::get_one_no_count(Dtype dtype, T& a)
{
    if (dtype == DATA_BIT)
    {
        while (bits.empty())
            buffer_bits();

        a = bits.back();
        bits.pop_back();
    }
    else
        throw not_implemented();
}
/**
 * Pop one buffered input mask for player `i`: `a` receives the share and
 * `x` the associated value.
 *
 * buffer_inputs(i) is called first when no buffer exists for that player
 * or the player's buffer is empty.
 */
template<class T>
void BufferPrep<T>::get_input_no_count(T& a, typename T::open_type& x, int i)
{
    const bool no_buffer = inputs.size() <= static_cast<size_t>(i);
    if (no_buffer || inputs.at(i).empty())
        buffer_inputs(i);

    // Looked up only after a possible refill.
    auto& pending = inputs[i];
    a = pending.back().share;
    x = pending.back().value;
    pending.pop_back();
}
/**
 * Default input buffering: not available here, always raises
 * not_implemented.
 */
template<class T>
inline void BufferPrep<T>::buffer_inputs(int player)
{
    static_cast<void>(player); // unused in this default
    throw not_implemented();
}
/**
 * Tagged/vectorized retrieval: not available here, always raises
 * not_implemented. All parameters are unused.
 */
template<class T>
void BufferPrep<T>::get_no_count(vector<T>& S, DataTag tag,
        const vector<int>& regs, int vector_size)
{
    static_cast<void>(S);
    static_cast<void>(tag);
    static_cast<void>(regs);
    static_cast<void>(vector_size);
    throw not_implemented();
}