Files
MP-SPDZ/BMR/RealGarbleWire.hpp
Marcel Keller 6a424539c9 SoftSpokenOT.
2022-08-25 13:23:18 +10:00

282 lines
6.9 KiB
C++

/*
* RealGarbleWire.cpp
*
*/
#include "RealGarbleWire.h"
#include "RealProgramParty.h"
#include "Protocols/MascotPrep.h"
template<class T>
void RealGarbleWire<T>::garble(PRFOutputs& prf_output,
const RealGarbleWire<T>& left, const RealGarbleWire<T>& right)
{
auto& party = RealProgramParty<T>::s();
assert(party.prep != 0);
party.prep->get_one(DATA_BIT, mask);
auto& inputter = *party.garble_inputter;
int n = party.N.num_players();
int me = party.N.my_num();
inputter.add_from_all(int128(keys[0][me].r));
for (int k = 0; k < 4; k++)
for (int j = 0; j < n; j++)
inputter.add_from_all(int128(prf_output[j].for_garbling(k).r));
assert(party.shared_proc != 0);
assert(party.garble_protocol != 0);
auto& protocol = *party.garble_protocol;
protocol.prepare_mul(left.mask, right.mask);
GarbleJob<T> job(left.mask, right.mask, mask);
party.garble_jobs.push_back(job);
}
template<class T>
GarbleJob<T>::GarbleJob(T lambda_u, T lambda_v, T lambda_w) :
lambda_u(lambda_u), lambda_v(lambda_v), lambda_w(lambda_w)
{
}
template<class T>
void GarbleJob<T>::middle_round(RealProgramParty<T>& party, Protocol& second_protocol)
{
int n = party.N.num_players();
int me = party.N.my_num();
assert(party.garble_protocol != 0);
auto& protocol = *party.garble_protocol;
lambda_uv = protocol.finalize_mul();
#ifdef DEBUG_MASK
cout << "lambda_u " << party.MC->POpen(lambda_u, *party.P) << endl;
cout << "lambda_v " << party.MC->POpen(lambda_v, *party.P) << endl;
cout << "lambda_w " << party.MC->POpen(lambda_w, *party.P) << endl;
cout << "lambda_uv " << party.MC->POpen(lambda_uv, *party.P) << endl;
#endif
for (int alpha = 0; alpha < 2; alpha++)
for (int beta = 0; beta < 2; beta++)
for (int j = 0; j < n; j++)
{
second_protocol.prepare_mul(party.shared_delta(j),
lambda_uv + lambda_v * alpha + lambda_u * beta
+ T::constant(alpha * beta, me, party.MC->get_alphai())
+ lambda_w);
}
}
template<class T>
void GarbleJob<T>::last_round(RealProgramParty<T>& party, Inputter& inputter,
Protocol& second_protocol, vector<T>& wires)
{
int n = party.N.num_players();
auto& protocol = second_protocol;
vector<T> base_keys;
for (int i = 0; i < n; i++)
base_keys.push_back(inputter.finalize(i));
for (int k = 0; k < 4; k++)
for (int j = 0; j < n; j++)
{
wires.push_back({});
auto& wire = wires.back();
for (int i = 0; i < n; i++)
wire += inputter.finalize(i);
wire += base_keys[j];
wire += protocol.finalize_mul();
}
}
template<class T>
void RealGarbleWire<T>::XOR(const RealGarbleWire<T>& left, const RealGarbleWire<T>& right)
{
PRFRegister::XOR(left, right);
mask = left.mask + right.mask;
}
template<class T>
void RealGarbleWire<T>::inputb(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputb(inputter, processor, args,
inputter.party.P->my_num());
}
template<class T>
void RealGarbleWire<T>::inputbvec(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputbvec(inputter, input_processor, args,
*inputter.party.P);
}
template<class T>
GarbleInputter<T>::GarbleInputter() :
party(RealProgramParty<T>::s()), oss(*party.P)
{
}
template<class T>
void RealGarbleWire<T>::my_input(Input& inputter, bool, int n_bits)
{
assert(n_bits == 1);
inputter.tuples.push_back({this, inputter.party.P->my_num()});
}
template<class T>
void RealGarbleWire<T>::other_input(Input& inputter, int from)
{
inputter.tuples.push_back({this, from});
}
template<class T>
void GarbleInputter<T>::exchange()
{
assert(party.shared_proc != 0);
auto& inputter = party.shared_proc->input;
inputter.reset_all(*party.P);
for (auto& tuple : tuples)
{
int from = tuple.second;
party_id_t from_id = from + 1;
tuple.first->PRFRegister::input(from_id, -1);
if (from_id == party.get_id())
{
char my_mask;
my_mask = party.prng.get_bit();
party.garble_input_masks.serialize(my_mask);
inputter.add_mine(my_mask);
#ifdef DEBUG_MASK
cout << "my mask: " << (int)my_mask << endl;
#endif
}
else
{
inputter.add_other(from);
}
}
inputter.exchange();
for (auto& tuple : tuples)
tuple.first->mask = (inputter.finalize(tuple.second));
// important to make sure that mask is a bit
try
{
for (auto& tuple : tuples)
tuple.first->mask.force_to_bit();
}
catch (not_implemented& e)
{
assert(party.P != 0);
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul();
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)
- tuple.first->mask);
protocol.exchange();
vector<T> to_check;
to_check.reserve(tuples.size());
for (size_t i = 0; i < tuples.size(); i++)
{
to_check.push_back(protocol.finalize_mul());
}
try
{
party.MC->CheckFor(0, to_check, *party.P);
}
catch (mac_fail&)
{
throw runtime_error("input mask not a bit");
}
}
#ifdef DEBUG_MASK
cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl;
#endif
}
template<class T>
void RealGarbleWire<T>::finalize_input(GarbleInputter<T>&, int, int)
{
}
template<class T>
void RealGarbleWire<T>::public_input(bool value)
{
PRFRegister::public_input(value);
mask = {};
}
template<class T>
void RealGarbleWire<T>::random()
{
// no need to randomize keys
PRFRegister::public_input(0);
auto& party = RealProgramParty<T>::s();
assert(party.prep != 0);
party.prep->get_one(DATA_BIT, mask);
// this is necessary to match the fake BMR evaluation phase
party.store_wire(*this);
keys[0].serialize(party.wires);
}
template<class T>
void RealGarbleWire<T>::output()
{
PRFRegister::output();
auto& party = RealProgramParty<T>::s();
assert(party.MC != 0);
assert(party.P != 0);
auto m = party.MC->open(mask, *party.P);
party.garble_output_masks.push_back(m.get_bit(0));
party.taint();
#ifdef DEBUG_MASK
cout << "output mask: " << m << endl;
#endif
}
template<class T>
void RealGarbleWire<T>::store(NoMemory& dest,
const vector<GC::WriteAccess<GC::Secret<RealGarbleWire> > >& accesses)
{
(void) dest;
auto& party = RealProgramParty<T>::s();
for (auto access : accesses)
for (auto& reg : access.source.get_regs())
{
party.push_spdz_wire(SPDZ_STORE, reg);
}
}
template<class T>
void RealGarbleWire<T>::load(
vector<GC::ReadAccess<GC::Secret<RealGarbleWire> > >& accesses,
const NoMemory& source)
{
PRFRegister::load(accesses, source);
auto& party = RealProgramParty<T>::s();
assert(party.prep != 0);
for (auto access : accesses)
for (auto& reg : access.dest.get_regs())
{
party.prep->get_one(DATA_BIT, reg.mask);
party.push_spdz_wire(SPDZ_LOAD, reg);
}
}
template<class T>
void RealGarbleWire<T>::convcbit(Integer& dest, const GC::Clear& source,
GC::Processor<GC::Secret<RealGarbleWire>>& processor)
{
(void) source;
auto& party = RealProgramParty<T>::s();
processor.untaint();
dest = party.convcbit;
}