mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-04-20 03:01:31 -04:00
1140 lines
27 KiB
C++
1140 lines
27 KiB
C++
/*
|
|
* Astra.cpp
|
|
*
|
|
*/
|
|
|
|
#ifndef PROTOCOLS_ASTRA_HPP_
|
|
#define PROTOCOLS_ASTRA_HPP_
|
|
|
|
#include "Astra.h"
|
|
#include "Tools/files.h"
|
|
#include "Tools/ranges.h"
|
|
|
|
template<class T>
|
|
string AstraBase<T>::get_filename(bool preprocessing, const char* name)
|
|
{
|
|
int base = 1 - preprocessing;
|
|
auto res = get_prep_sub_dir<T>(P.num_players() + base, preprocessing) + name
|
|
+ suffix + "-P" + to_string(P.my_num() + base) + "-T"
|
|
+ to_string(BaseMachine::thread_num);
|
|
if (OnlineOptions::singleton.has_option("verbose_astra"))
|
|
fprintf(stderr, "astra filename %s\n", res.c_str());
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
string AstraBase<T>::get_output_filename()
|
|
{
|
|
|
|
return get_filename(not T::real_shares(P), "Outputs");
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::debug()
|
|
{
|
|
if (OnlineOptions::singleton.has_option("verbose_astra"))
|
|
fprintf(stderr, "astra comm %s %p %d %s\n", T::type_string().c_str(), this,
|
|
BaseMachine::thread_num, suffix.c_str());
|
|
}
|
|
|
|
template<class T>
|
|
AstraBase<T>::AstraBase(Player& P) :
|
|
n_mults(0), P(P)
|
|
{
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::set_suffix(const string& suffix)
|
|
{
|
|
this->suffix = "-" + suffix;
|
|
init_prep();
|
|
}
|
|
|
|
template<class T>
|
|
AstraBase<T>::~AstraBase()
|
|
{
|
|
}
|
|
|
|
template<class T>
|
|
Astra<T>::Astra(Player& P) :
|
|
AstraOnlineBase<T>(P)
|
|
{
|
|
}
|
|
|
|
template<class T>
|
|
AstraOnlineBase<T>::AstraOnlineBase(Player& P) :
|
|
AstraBase<T>(P)
|
|
{
|
|
astra_num = P.my_num() + 1;
|
|
}
|
|
|
|
template<class T>
|
|
AstraPrepProtocol<T>::AstraPrepProtocol(Player& P) :
|
|
AstraBase<T>(P), unsplit_input(0), prng_protocol(P),
|
|
prng_protocol_for_input0(P)
|
|
{
|
|
my_num = P.my_num();
|
|
}
|
|
|
|
template<class T>
|
|
AstraPrepProtocol<T>::~AstraPrepProtocol()
|
|
{
|
|
prng_protocol.output_time<T>();
|
|
prng_protocol_for_input0.output_time<T>();
|
|
if (unsplit_input)
|
|
delete unsplit_input;
|
|
}
|
|
|
|
template<class T>
|
|
void AstraOnlineBase<T>::init_prep()
|
|
{
|
|
try
|
|
{
|
|
open_with_check(prep, this->get_filename(false));
|
|
}
|
|
catch (...)
|
|
{
|
|
throw runtime_error(
|
|
"Error with preprocessing in " + this->get_filename(false)
|
|
+ ". You need to run the preprocessing before "
|
|
"the online phase or in parallel.");
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::init_prep()
|
|
{
|
|
if (this->P.my_num() > 0)
|
|
prep.open(this->get_filename(true));
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::init_mul()
|
|
{
|
|
input_pairs.clear();
|
|
inputs.clear();
|
|
results.clear();
|
|
n_mults = 0;
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::prepare_mul(const T& x, const T& y, int)
|
|
{
|
|
this->input_pairs.push_back({x, y});
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::init_dotprod()
|
|
{
|
|
init_mul();
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::prepare_dotprod(const T& x, const T& y)
|
|
{
|
|
prepare_mul(x, y);
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::next_dotprod()
|
|
{
|
|
#define X(Pi) \
|
|
if (my_astra_num() == Pi) \
|
|
{ \
|
|
this->inputs.push_back({}); \
|
|
for (auto& x : this->input_pairs) \
|
|
this->inputs.back() += x[0].local_mul_P##Pi(x[1]); \
|
|
}
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
|
|
n_mults++;
|
|
this->input_pairs.clear();
|
|
}
|
|
|
|
template<class T>
|
|
T AstraBase<T>::finalize_dotprod(int)
|
|
{
|
|
return this->finalize_mul();
|
|
}
|
|
|
|
template<class T>
|
|
T AstraPrepShare<T>::local_mul_P0(const AstraPrepShare& other) const
|
|
{
|
|
return this->sum() * other.sum();
|
|
}
|
|
|
|
template<class T>
|
|
T AstraPrepShare<T>::local_mul_P1(const AstraPrepShare& other) const
|
|
{
|
|
return local_mul_P0(other);
|
|
}
|
|
|
|
template<class T>
|
|
T AstraPrepShare<T>::local_mul_P2(const AstraPrepShare&) const
|
|
{
|
|
return {};
|
|
}
|
|
|
|
template<class T>
|
|
T AstraShare<T>::local_mul_P0(const AstraShare&) const
|
|
{
|
|
throw runtime_error("P0 should be absent");
|
|
}
|
|
|
|
template<class T>
|
|
T AstraShare<T>::local_mul_P1(const AstraShare& other) const
|
|
{
|
|
return m(1) * other.neg_lambda(1) + other.m(1) * neg_lambda(1);
|
|
}
|
|
|
|
template<class T>
|
|
T AstraShare<T>::local_mul_P2(const AstraShare& other) const
|
|
{
|
|
return m(2) * other.m(2) + local_mul_P1(other);
|
|
}
|
|
|
|
template<class T>
|
|
void AstraBase<T>::prepare_exchange()
|
|
{
|
|
n_mults += this->input_pairs.size();
|
|
}
|
|
|
|
template<class T>
|
|
template<int my_num>
|
|
void AstraPrepProtocol<T>::pre()
|
|
{
|
|
auto& inputs = this->inputs;
|
|
auto& results = this->results;
|
|
auto& n_mults = this->n_mults;
|
|
|
|
n_mults += this->input_pairs.size();
|
|
results.reserve(n_mults);
|
|
|
|
if (my_num == 0)
|
|
{
|
|
os.clear();
|
|
os.reserve(n_mults * T::open_type::size());
|
|
}
|
|
else
|
|
{
|
|
os_prep.clear();
|
|
os_prep.reserve(n_mults * 2 * T::open_type::size());
|
|
}
|
|
|
|
T res;
|
|
open_type gamma;
|
|
|
|
for (auto& input : inputs)
|
|
{
|
|
pre_element<my_num>(res);
|
|
pre_gamma<my_num>(res, gamma, input);
|
|
results.push_back(res);
|
|
}
|
|
|
|
for (auto& x : this->input_pairs)
|
|
{
|
|
open_type input;
|
|
#define X(I) if (my_num == I) input = x[0].local_mul_P##I(x[1]);
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
pre_element<my_num>(res);
|
|
pre_gamma<my_num>(res, gamma, input);
|
|
results.push_back(res);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<int my_num>
|
|
void AstraPrepProtocol<T>::pre_element(T& res)
|
|
{
|
|
if (my_num == 0)
|
|
for (int i = 0; i < 2; i++)
|
|
res[i].randomize(prng_protocol.shared_prngs[i]);
|
|
else
|
|
res[1].randomize(prng_protocol.shared_prngs[2 - my_num]);
|
|
}
|
|
|
|
template<class T>
|
|
template<int my_num>
|
|
void AstraPrepProtocol<T>::pre_gamma(T& res, open_type& gamma, const open_type& input)
|
|
{
|
|
if (my_num < 2)
|
|
{
|
|
gamma.randomize(prng_protocol.shared_prngs[my_num]);
|
|
if (my_num == 0)
|
|
os.store_no_resize(input - gamma);
|
|
else
|
|
post(res, gamma);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::exchange()
|
|
{
|
|
CODE_LOCATION
|
|
if (OnlineOptions::singleton.has_option("verbose_astra"))
|
|
fprintf(stderr, "astra exchange %zu\n", this->inputs.size());
|
|
|
|
auto& P = this->P;
|
|
auto& results = this->results;
|
|
int my_num = P.my_num();
|
|
assert(results.size() == 0);
|
|
|
|
#define X(I) if (my_num == I) pre<I>();
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
|
|
if (my_num == 0)
|
|
P.send_to(2, os);
|
|
else if (my_num == 2)
|
|
P.receive_player(0, os);
|
|
|
|
if (my_num == 2)
|
|
{
|
|
if (os.left() < open_type::size() * results.size())
|
|
throw runtime_error("insufficient data in Astra");
|
|
for (auto& res : results)
|
|
post(res, os.get_no_check<open_type>());
|
|
}
|
|
|
|
store(os_prep);
|
|
results.reset();
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::post(T& res, const open_type& gamma)
|
|
{
|
|
os_prep.store_no_resize(gamma);
|
|
os_prep.store_no_resize(res[1]);
|
|
}
|
|
|
|
template<class T>
|
|
T Astra<T>::pre(const open_type& input)
|
|
{
|
|
int my_num = this->astra_num;
|
|
typename T::open_type gamma;
|
|
gamma = os_prep.get_no_check<open_type>();
|
|
T res;
|
|
res[1] = os_prep.get_no_check<open_type>();
|
|
auto m_z = input - res.neg_lambda(my_num) + gamma;
|
|
os.store_no_resize(m_z);
|
|
res.m(my_num) = m_z;
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void Astra<T>::exchange()
|
|
{
|
|
CODE_LOCATION
|
|
if (OnlineOptions::singleton.has_option("verbose_astra"))
|
|
fprintf(stderr, "astra exchange %zu\n", this->inputs.size());
|
|
|
|
auto& P = this->P;
|
|
auto& inputs = this->inputs;
|
|
auto& results = this->results;
|
|
assert(results.size() == 0);
|
|
|
|
size_t n_mults = this->inputs.size() + this->input_pairs.size();
|
|
|
|
int my_num = this->my_astra_num();
|
|
if (my_num > 0)
|
|
{
|
|
this->read(os_prep);
|
|
os.clear();
|
|
os.reserve(n_mults * T::open_type::size());
|
|
|
|
if (os_prep.left() < open_type::size() * n_mults)
|
|
throw runtime_error("insufficient preprocessing");
|
|
|
|
for (auto& input : inputs)
|
|
results.push_back(pre(input));
|
|
|
|
if (my_num == 1)
|
|
for (auto& x : this->input_pairs)
|
|
results.push_back(pre(x[0].local_mul_P1(x[1])));
|
|
else
|
|
for (auto& x : this->input_pairs)
|
|
results.push_back(pre(x[0].local_mul_P2(x[1])));
|
|
|
|
octetStream recv_os;
|
|
P.exchange(1 - P.my_num(), os, recv_os);
|
|
|
|
if (recv_os.left() < open_type::size() * results.size())
|
|
throw runtime_error("insufficient data in Astra");
|
|
|
|
for (auto& res : results)
|
|
res.m(my_num) += recv_os.get_no_check<typename T::open_type>();
|
|
|
|
assert(not os_prep.left());
|
|
}
|
|
|
|
results.reset();
|
|
}
|
|
|
|
template<class T>
|
|
T Astra<T>::finalize_mul(int)
|
|
{
|
|
return this->results.next();
|
|
}
|
|
|
|
template<class T>
|
|
T AstraPrepProtocol<T>::finalize_mul(int)
|
|
{
|
|
return this->results.next();
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
U AstraOnlineBase<T>::read()
|
|
{
|
|
U res;
|
|
octetStream os;
|
|
read(os);
|
|
os.get(res);
|
|
assert(not os.left());
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void AstraOnlineBase<T>::read(octetStream& os)
|
|
{
|
|
if (not prep.is_open())
|
|
init_prep();
|
|
Timer timer;
|
|
TimeScope ts(timer);
|
|
this->debug();
|
|
os.input(prep);
|
|
if (not prep.good())
|
|
throw runtime_error("error in preprocessing reading");
|
|
this->P.comm_stats["Preprocessing transmission"].add(os, ts);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void AstraPrepProtocol<T>::store(const U& value)
|
|
{
|
|
if (this->P.my_num() > 0)
|
|
{
|
|
octetStream os;
|
|
os.store(value);
|
|
store(os);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::store(const octetStream& os)
|
|
{
|
|
if (this->P.my_num() > 0)
|
|
{
|
|
if (not prep.is_open())
|
|
init_prep();
|
|
TimeScope ts(this->P.comm_stats["Preprocessing transmission"].add(os));
|
|
this->debug();
|
|
os.output(prep);
|
|
prep.flush();
|
|
if (not prep.good())
|
|
throw runtime_error("error in preprocessing storing");
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void AstraPrepProtocol<T>::sync(vector<U>& values, Player& P)
|
|
{
|
|
if (P.my_num() == 1)
|
|
{
|
|
if (not outputs.is_open())
|
|
{
|
|
try
|
|
{
|
|
open_with_check(outputs, this->get_output_filename());
|
|
}
|
|
catch (...)
|
|
{
|
|
throw runtime_error("Error with output back channel. "
|
|
"This only works when preprocessing is run in parallel.");
|
|
}
|
|
}
|
|
|
|
Timer timer;
|
|
TimeScope ts(timer);
|
|
octetStream os;
|
|
os.input(outputs);
|
|
os.get(values);
|
|
this->P.comm_stats["Output transmission"].add(os, ts);
|
|
P.send_all(os);
|
|
}
|
|
else
|
|
{
|
|
octetStream os;
|
|
P.receive_player(1, os);
|
|
os.get(values);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void AstraOnlineBase<T>::sync(vector<U>& values, Player& P)
|
|
{
|
|
if (P.my_num() == 0)
|
|
{
|
|
if (not outputs.is_open())
|
|
outputs.open(this->get_output_filename());
|
|
|
|
octetStream os;
|
|
os.store(values);
|
|
TimeScope ts(this->P.comm_stats["Output transmission"].add(os));
|
|
os.output(outputs);
|
|
outputs.flush();
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void AstraPrepProtocol<T>::forward_sync(vector<U>& values)
|
|
{
|
|
store(values);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
void AstraOnlineBase<T>::forward_sync(vector<U>& values)
|
|
{
|
|
values = read<vector<U>>();
|
|
}
|
|
|
|
template<class T>
|
|
int AstraPrepProtocol<T>::rep_index(int i)
|
|
{
|
|
return rep_index(i, my_num);
|
|
}
|
|
|
|
template<class T>
|
|
int AstraPrepProtocol<T>::rep_index(int i, int my_num)
|
|
{
|
|
if (my_num == 0)
|
|
return i;
|
|
else
|
|
{
|
|
if (i == 0)
|
|
// m
|
|
return my_num - 1;
|
|
else
|
|
// lambda
|
|
return 2 - my_num;
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
AstraPrepShare<U> AstraPrepProtocol<T>::from_rep3(const FixedVec<U, 2>& x)
|
|
{
|
|
return from_rep3(x, my_num);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U>
|
|
AstraPrepShare<U> AstraPrepProtocol<T>::from_rep3(const FixedVec<U, 2>& x,
|
|
int my_num)
|
|
{
|
|
AstraPrepShare<U> res;
|
|
for (int i = 0; i < 2; i++)
|
|
res[i] = x[rep_index(i, my_num)];
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
T AstraPrepProtocol<T>::get_random()
|
|
{
|
|
T res;
|
|
for (int i = 0; i < 2; i++)
|
|
res[i].randomize(prng_protocol.shared_prngs[rep_index(i)]);
|
|
store(res);
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
T AstraOnlineBase<T>::get_random()
|
|
{
|
|
return read<T>();
|
|
}
|
|
|
|
template<class T>
|
|
void AstraOnlineBase<T>::randoms_inst(StackedVector<T>& S, const Instruction& instruction)
|
|
{
|
|
octetStream os;
|
|
read(os);
|
|
|
|
for (int j = 0; j < instruction.get_size(); j++)
|
|
{
|
|
auto& res = S[instruction.get_r(0) + j];
|
|
res.unpack(os);
|
|
}
|
|
|
|
assert(not os.left());
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::randoms_inst(StackedVector<T>& S,
|
|
const Instruction& instruction)
|
|
{
|
|
octetStream os;
|
|
|
|
for (int j = 0; j < instruction.get_size(); j++)
|
|
{
|
|
auto& res = S[instruction.get_r(0) + j];
|
|
for (int i = 0; i < 2; i++)
|
|
res[i].randomize_part(
|
|
prng_protocol.shared_prngs[rep_index(i)],
|
|
instruction.get_n());
|
|
res.pack(os);
|
|
}
|
|
|
|
store(os);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void AstraOnlineBase<T>::trunc_pr(const vector<int>& regs, int size,
|
|
SubProcessor<T>& proc, false_type)
|
|
{
|
|
TruncPrTupleList<T> infos(regs, proc.get_S(), size);
|
|
|
|
if (infos.have_big_gap())
|
|
trunc_pr_big_gap(infos.get_big_gap(), size, proc);
|
|
if (infos.have_small_gap())
|
|
this->trunc_pr_small_gap(infos.get_small_gap(), size, proc);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void AstraPrepProtocol<T>::trunc_pr(const vector<int>& regs, int size,
|
|
SubProcessor<T>& proc, false_type)
|
|
{
|
|
TruncPrTupleList<T> infos(regs, proc.get_S(), size);
|
|
|
|
if (infos.have_big_gap())
|
|
trunc_pr_big_gap(infos.get_big_gap(), size, proc);
|
|
if (infos.have_small_gap())
|
|
this->trunc_pr_small_gap(infos.get_small_gap(), size, proc);
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void AstraOnlineBase<T>::trunc_pr_big_gap(
|
|
const TruncPrTupleList<T>& infos, int size,
|
|
SubProcessor<T>&)
|
|
{
|
|
CODE_LOCATION
|
|
typedef typename T::clear value_type;
|
|
bool generate = my_astra_num() == this->gen_player;
|
|
bool compute = my_astra_num() == this->comp_player;
|
|
int my_num = my_astra_num();
|
|
|
|
octetStream cs;
|
|
read(cs);
|
|
|
|
for (auto info : infos)
|
|
{
|
|
assert(info.big_gap());
|
|
this->trunc_pr_counter += size;
|
|
}
|
|
|
|
this->trunc_rounds++;
|
|
|
|
if (generate)
|
|
{
|
|
cs.require<T>(infos.size() * size);
|
|
for (auto info : infos)
|
|
for (auto& y : info.dest_range)
|
|
y = cs.get_no_check<T>();
|
|
}
|
|
|
|
if (compute)
|
|
{
|
|
cs.require<value_type>(infos.size() * size);
|
|
for (auto info : infos)
|
|
{
|
|
auto dest_it = info.dest_range.begin();
|
|
for (auto& x : info.source_range)
|
|
{
|
|
{
|
|
auto& y = *dest_it++;
|
|
y.neg_lambda(my_num) = cs.get_no_check<value_type>();
|
|
y.set_common_m(x.common_m() >> info.m);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if (not (compute or generate))
|
|
{
|
|
for (auto info : infos)
|
|
if (info.big_gap())
|
|
{
|
|
cs.require<value_type>(size);
|
|
auto dest_it = info.dest_range.begin();
|
|
for (auto& x : info.source_range)
|
|
{
|
|
auto& y = *dest_it++;
|
|
y.neg_lambda(my_num) = cs.get_no_check<value_type>();
|
|
y.set_common_m(x.common_m() >> info.m);
|
|
}
|
|
}
|
|
}
|
|
|
|
assert(not cs.left());
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void AstraPrepProtocol<T>::trunc_pr_big_gap(
|
|
const TruncPrTupleList<T>& infos, int size,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
CODE_LOCATION
|
|
assert(proc.P.num_players() == 3);
|
|
typedef typename T::clear value_type;
|
|
auto& S = proc.get_S();
|
|
bool generate = my_astra_num() == this->gen_player;
|
|
bool compute = my_astra_num() == this->comp_player;
|
|
|
|
octetStream cs;
|
|
|
|
for (auto info : infos)
|
|
{
|
|
assert(info.big_gap());
|
|
this->trunc_pr_counter += size;
|
|
}
|
|
|
|
this->trunc_rounds++;
|
|
|
|
if (generate)
|
|
{
|
|
SeededPRNG G;
|
|
for (auto info : infos)
|
|
{
|
|
auto dest_it = info.dest_range.begin();
|
|
for (auto& x : info.source_range)
|
|
{
|
|
auto& y = *dest_it++;
|
|
auto r =
|
|
prng_protocol.shared_prngs[1].template get<value_type>();
|
|
y[0] = -value_type(
|
|
-value_type(x.neg_lambda_sum(my_num)) >> info.m) - r;
|
|
y[0].pack(cs);
|
|
y[1] = r;
|
|
// store(y);
|
|
}
|
|
}
|
|
|
|
this->P.send_to(this->comp_player, cs);
|
|
}
|
|
|
|
if (compute)
|
|
{
|
|
this->P.receive_player(this->gen_player, cs);
|
|
octetStream os;
|
|
for (auto info : infos)
|
|
{
|
|
cs.require<value_type>(size);
|
|
for (int i = 0; i < size; i++)
|
|
{
|
|
{
|
|
auto& y = S[info.dest_base + i];
|
|
y.neg_lambda(my_num) = cs.get_no_check<value_type>();
|
|
os.store(y.neg_lambda(my_num));
|
|
}
|
|
}
|
|
}
|
|
store(os);
|
|
}
|
|
|
|
if (not (compute or generate))
|
|
{
|
|
for (auto info : infos)
|
|
if (info.big_gap())
|
|
for (int i = 0; i < size; i++)
|
|
{
|
|
auto& y = S[info.dest_base + i];
|
|
y.neg_lambda(my_num) =
|
|
prng_protocol.shared_prngs[0].template get<
|
|
value_type>();
|
|
cs.store(y.neg_lambda(my_num));
|
|
}
|
|
store(cs);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
template<int>
|
|
void AstraBase<T>::trunc_pr_small_gap(
|
|
const TruncPrTupleList<T>& infos, int size,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
CODE_LOCATION
|
|
for (auto info : infos)
|
|
{
|
|
assert(info.small_gap());
|
|
this->trunc_pr_counter += size;
|
|
}
|
|
|
|
this->trunc_rounds++;
|
|
|
|
this->gen_values.clear();
|
|
|
|
#define X(I) if (my_astra_num() == I) \
|
|
this->template trunc_pr_small_gap_finish<I>(infos, size, proc);
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::add_gen(octetStream& cs, const typename T::open_type& value)
|
|
{
|
|
T res;
|
|
res[1].randomize(prng_protocol_for_input0.shared_prngs[1]);
|
|
res[0] = value - res[1];
|
|
this->gen_values.push_back(res);
|
|
cs.store_no_resize(res[0]);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U, int MY_NUM>
|
|
void AstraPrepShare<T>::pre_input0(const clear& input, typename U::Protocol& protocol)
|
|
{
|
|
assert(MY_NUM == 0);
|
|
protocol.add_gen(protocol.cs, input);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U, int MY_NUM>
|
|
AstraPrepShare<T> AstraPrepShare<T>::post_input0(typename U::Protocol& protocol)
|
|
{
|
|
This r_prime;
|
|
switch (MY_NUM)
|
|
{
|
|
case 0:
|
|
r_prime = protocol.gen_values.next();
|
|
break;
|
|
case 1:
|
|
protocol.cs.get_no_check(r_prime[1]);
|
|
break;
|
|
case 2:
|
|
r_prime[1].randomize(
|
|
protocol.prng_protocol_for_input0.shared_prngs[0]);
|
|
protocol.cs.store_no_resize(r_prime[1]);
|
|
break;
|
|
}
|
|
return r_prime;
|
|
}
|
|
|
|
template<class T>
|
|
template<class U, int MY_NUM>
|
|
AstraShare<T> AstraShare<T>::post_input0(typename U::Protocol& protocol)
|
|
{
|
|
This res;
|
|
protocol.cs_prep.get_no_check(res[1]);
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::init_input0(size_t n_inputs)
|
|
{
|
|
cs.reset_write_head();
|
|
cs.reserve<open_type>(n_inputs);
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::exchange_input0(size_t n_inputs)
|
|
{
|
|
auto& P = this->P;
|
|
if (P.my_num() == 0)
|
|
{
|
|
P.send_to(1, cs);
|
|
this->gen_values.reset();
|
|
this->gen_values.require(n_inputs);
|
|
}
|
|
else if (P.my_num() == 1)
|
|
{
|
|
P.receive_player(0, cs);
|
|
cs.require<open_type>(n_inputs);
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::finalize_input0(size_t n_inputs)
|
|
{
|
|
if (this->P.my_num() != 0)
|
|
{
|
|
cs.reset_read_head();
|
|
cs.require<open_type>(n_inputs);
|
|
store(cs);
|
|
}
|
|
|
|
assert(not this->gen_values.left());
|
|
}
|
|
|
|
template<class T>
|
|
void AstraOnlineBase<T>::exchange_input0(size_t n_inputs)
|
|
{
|
|
read(cs_prep);
|
|
cs_prep.require<open_type>(n_inputs);
|
|
}
|
|
|
|
template<class T>
|
|
template<int MY_NUM>
|
|
void AstraBase<T>::trunc_pr_small_gap_finish(
|
|
const TruncPrTupleList<T>& infos, int size,
|
|
SubProcessor<T>& proc)
|
|
{
|
|
bool generate = MY_NUM == gen_player;
|
|
|
|
auto& protocol = proc.protocol;
|
|
protocol.init_mul();
|
|
gen_values.reset();
|
|
|
|
init_reduced_mul(size * infos.size());
|
|
init_input0(size * infos.size());
|
|
|
|
for (auto info : infos)
|
|
{
|
|
if (info.small_gap())
|
|
{
|
|
auto source_it = info.source_range.begin();
|
|
for (auto& x : info.dest_range)
|
|
{
|
|
auto y = *source_it++;
|
|
T c_prime, c_dprime;
|
|
typename T::clear c_msb, r_msb;
|
|
|
|
if (generate and not T::real_shares(protocol.P))
|
|
{
|
|
auto r = y.sum() + info.add_before() - 1;
|
|
r_msb = r.msb();
|
|
T::template pre_input0<T, MY_NUM>(r >> info.m, protocol);
|
|
}
|
|
|
|
if (not generate and T::real_shares(protocol.P))
|
|
{
|
|
auto c = y.common_m();
|
|
c_prime.m(-1) = c >> info.m;
|
|
c_msb = c.msb();
|
|
}
|
|
|
|
T r_dprime, prod;
|
|
T::template pre_reduced_mul<T, MY_NUM>(r_dprime, c_dprime, prod,
|
|
protocol, r_msb, c_msb);
|
|
|
|
x = c_prime;
|
|
x += info.correction_shift(prod - (r_dprime + c_dprime));
|
|
}
|
|
}
|
|
}
|
|
|
|
exchange_reduced_mul(size * infos.size());
|
|
exchange_input0(size * infos.size());
|
|
|
|
for (auto info : infos)
|
|
for (auto& x : info.dest_range)
|
|
{
|
|
if (info.small_gap())
|
|
{
|
|
auto r_prime = T::template post_input0<T, MY_NUM>(protocol);
|
|
auto res = T::template post_reduced_mul<T, MY_NUM>(protocol);
|
|
x += r_prime - T::constant(info.subtract_after() - 1);
|
|
x[0] += info.correction_shift(res.second);
|
|
}
|
|
}
|
|
|
|
finalize_input0(size * infos.size());
|
|
}
|
|
|
|
template<class T>
|
|
void AstraOnlineBase<T>::unsplit(StackedVector<T>& dest,
|
|
StackedVector<typename T::bit_type>& source,
|
|
const Instruction& instruction)
|
|
{
|
|
if (instruction.get_start().size() == 1)
|
|
return unsplit1(dest, source, instruction);
|
|
|
|
CODE_LOCATION
|
|
int n_bits = instruction.get_size();
|
|
int unit = T::bit_type::default_length;
|
|
assert(instruction.get_start().size() == 2);
|
|
int my_num = my_astra_num();
|
|
|
|
octetStream os;
|
|
read(os);
|
|
os.require<T>(n_bits);
|
|
|
|
for (int i = 0; i < DIV_CEIL(n_bits, unit); i++)
|
|
{
|
|
// conversion symmetric in binary
|
|
auto x = T::bit_type::from_rep3(source.at(instruction.get_r(0) + i));
|
|
int left = min(unit, n_bits - unit * i);
|
|
typename vector<T>::iterator its[2];
|
|
for (int k = 0; k < 2; k++)
|
|
its[k] = dest.iterator_for_size(
|
|
instruction.get_start()[k] + i * unit, left);
|
|
for (int j = 0; j < left; j++)
|
|
{
|
|
auto& y = *its[0]++;
|
|
os.get_no_check(y);
|
|
y = T::from_rep3(y);
|
|
(*its[1]++).m(my_num) = x.m(my_num).get_bit(j);
|
|
}
|
|
}
|
|
}
|
|
|
|
template<class T>
|
|
void AstraPrepProtocol<T>::unsplit(StackedVector<T>& dest,
|
|
StackedVector<typename T::bit_type>& source,
|
|
const Instruction& instruction)
|
|
{
|
|
if (instruction.get_start().size() == 1)
|
|
return unsplit1(dest, source, instruction);
|
|
|
|
CODE_LOCATION
|
|
int n_bits = instruction.get_size();
|
|
int unit = T::bit_type::default_length;
|
|
assert(instruction.get_start().size() == 2);
|
|
auto& P = this->P;
|
|
|
|
if (not unsplit_input)
|
|
unsplit_input = new ReplicatedInput<Rep3Share<typename T::clear>>(P);
|
|
auto& input = *unsplit_input;
|
|
input.reset_all(P);
|
|
|
|
if (P.my_num() == 0)
|
|
{
|
|
int n_blocks = DIV_CEIL(n_bits, unit);
|
|
auto it = source.iterator_for_size(instruction.get_r(0), n_blocks);
|
|
for (int i = 0; i < n_blocks; i++)
|
|
{
|
|
auto x = (*it++).sum();
|
|
for (int j = 0; j < min(unit, n_bits - unit * i); j++)
|
|
input.add_mine(x.get_bit(j));
|
|
}
|
|
}
|
|
else
|
|
for (int i = 0; i < n_bits; i++)
|
|
input.add_other(0);
|
|
|
|
#define X(I) if (P.my_num() == I) unsplit_finish<I>(dest, instruction);
|
|
X(0) X(1) X(2)
|
|
#undef X
|
|
}
|
|
|
|
template<class T>
|
|
template<int MY_NUM>
|
|
void AstraPrepProtocol<T>::unsplit_finish(StackedVector<T>& dest,
|
|
const Instruction& instruction)
|
|
{
|
|
int n_bits = instruction.get_size();
|
|
|
|
assert(unsplit_input);
|
|
auto& input = *unsplit_input;
|
|
input.exchange();
|
|
os.reset_write_head();
|
|
os.reserve<T>(n_bits);
|
|
|
|
for (auto it : DoubleRange(dest, instruction.get_start()[0],
|
|
instruction.get_start()[1], n_bits))
|
|
{
|
|
auto& x = it.first;
|
|
x = from_rep3(input.finalize_offset(-MY_NUM), MY_NUM);
|
|
os.store_no_resize(x);
|
|
it.second = {};
|
|
}
|
|
|
|
store(os);
|
|
}
|
|
|
|
template<class T>
|
|
template<class U, int MY_NUM>
|
|
void AstraShare<T>::pre_reduced_mul(U& a, U& b, U& c,
|
|
Astra<U>& protocol, const T&, const T& bb)
|
|
{
|
|
auto& os_prep = protocol.os_prep;
|
|
os_prep.get_no_check(a[1]);
|
|
os_prep.get_no_check(c[1]);
|
|
auto Mi = bb * a.neg_lambda(-1) - c.neg_lambda(-1);
|
|
protocol.os.store_no_resize(Mi);
|
|
b.m(-1) = bb;
|
|
c.m(-1) = Mi;
|
|
protocol.results.push_back({});
|
|
}
|
|
|
|
template<class T>
|
|
template<class U, int MY_NUM>
|
|
pair<U, T> AstraShare<T>::post_reduced_mul(typename U::Protocol& protocol)
|
|
{
|
|
pair<U, T> res;
|
|
res.first = protocol.results.next();
|
|
res.second = protocol.recv_os.template get_no_check<clear>();
|
|
return res;
|
|
}
|
|
|
|
template<class T>
|
|
void Astra<T>::init_reduced_mul(size_t n_mul)
|
|
{
|
|
this->read(os_prep);
|
|
os_prep.require<clear>(2 * n_mul);
|
|
|
|
os.reset_write_head();
|
|
os.reserve<clear>(n_mul);
|
|
auto& results = this->results;
|
|
results.clear();
|
|
results.reserve(n_mul);
|
|
}
|
|
|
|
template<class T>
|
|
void Astra<T>::exchange_reduced_mul(size_t n_mul)
|
|
{
|
|
auto& P = this->P;
|
|
P.pass_around(os, recv_os, 1);
|
|
recv_os.require<clear>(n_mul);
|
|
this->results.reset();
|
|
assert(this->results.left() == n_mul);
|
|
}
|
|
|
|
template<class T>
|
|
void Astra<T>::unsplit1(StackedVector<T>& dest,
|
|
StackedVector<typename T::bit_type>& source,
|
|
const Instruction& instruction)
|
|
{
|
|
CODE_LOCATION
|
|
size_t n_bits = instruction.get_size();
|
|
assert(instruction.get_start().size() == 1);
|
|
auto& results = this->results;
|
|
|
|
init_reduced_mul(n_bits);
|
|
auto range = BlockRange(source, instruction.get_r(0), n_bits);
|
|
|
|
for (auto& x : range)
|
|
{
|
|
for (typename T::clear bit : BitLeftRange(x.m(-1), x, range))
|
|
{
|
|
T a, b, c;
|
|
T::template pre_reduced_mul<T, -1>(a, b, c, *this, {}, bit);
|
|
results.back() = a + b - 2 * c;
|
|
}
|
|
}
|
|
|
|
exchange_reduced_mul(n_bits);
|
|
|
|
for (auto& x : Range(dest, instruction.get_start()[0], n_bits))
|
|
{
|
|
auto res = T::template post_reduced_mul<T, 1>(*this);
|
|
x = res.first;
|
|
x.m(-1) -= res.second * 2;
|
|
}
|
|
|
|
assert(not recv_os.left());
|
|
assert(not os_prep.left());
|
|
}
|
|
|
|
#endif
|