mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
965 lines
25 KiB
C++
965 lines
25 KiB
C++
/*
|
|
* Register.cpp
|
|
*
|
|
*/
|
|
|
|
#include "Register.h"
|
|
#include "Wire.h"
|
|
#include "GarbledGate.h"
|
|
#include "Gate.h"
|
|
#include "Party.h"
|
|
#include "TrustedParty.h"
|
|
#include "CommonParty.h"
|
|
#include "Register_inline.h"
|
|
#include "RealGarbleWire.h"
|
|
|
|
#include "prf.h"
|
|
|
|
#include "GC/Secret.h"
|
|
#include "GC/Secret_inline.h"
|
|
#include "GC/Processor.h"
|
|
#include "Tools/FlexBuffer.h"
|
|
|
|
#include "GC/Processor.hpp"
|
|
|
|
#include <unistd.h>
|
|
|
|
ostream& EvalRegister::out = cout;
|
|
|
|
int Register::counter = 0;
|
|
|
|
void Register::init(int n_parties)
|
|
{
|
|
keys.init(n_parties);
|
|
mask = NO_SIGNAL;
|
|
external = NO_SIGNAL;
|
|
}
|
|
|
|
void Register::init(int rfd, int n_parties) {
|
|
(void)rfd;
|
|
mask = CommonParty::s().prng.get_uchar();
|
|
#ifdef DEBUG
|
|
mask = counter & 1;
|
|
counter++;
|
|
#endif
|
|
#ifdef DEBUG_MASK
|
|
mask = 0;
|
|
#endif
|
|
mask = mask>0 ? 1 : 0;
|
|
keys.init(n_parties);
|
|
keys.randomize();
|
|
for (int i = 0; i < 2; i++)
|
|
for (size_t j = 0; j < keys[i].size(); j++)
|
|
if (keys[i][j].get_signal() != i)
|
|
keys[i][j] ^= Key(1);
|
|
}
|
|
|
|
void Register::set_eval_keys()
|
|
{
|
|
if (external == NO_SIGNAL)
|
|
throw exception();
|
|
garbled_entry = keys[external];
|
|
}
|
|
|
|
void Register::set_eval_keys(Key* keys, int n_parties, int except)
|
|
{
|
|
this->keys.copy_from(keys, n_parties, except);
|
|
set_eval_keys();
|
|
}
|
|
|
|
void Register::set_external_key(party_id_t i, const Key& key)
|
|
{
|
|
check_external();
|
|
garbled_entry[i-1] = key;
|
|
keys[external][i-1] = key;
|
|
//keys[1-external][i-1] = 0;
|
|
}
|
|
|
|
void Register::reset_non_external_key(party_id_t i)
|
|
{
|
|
check_external();
|
|
keys[1-external][i-1] = 0;
|
|
}
|
|
|
|
void Register::set_external(char ext)
|
|
{
|
|
external = ext;
|
|
check_external();
|
|
garbled_entry = keys[external];
|
|
}
|
|
|
|
void Register::check_mask() const
|
|
{
|
|
#ifdef SIGNAL_CHECK
|
|
if (mask * (1 - mask) != 0)
|
|
{
|
|
if (mask == NO_SIGNAL)
|
|
throw runtime_error("no signal in mask");
|
|
else
|
|
throw runtime_error("invalid mask");
|
|
}
|
|
#endif
|
|
}
|
|
|
|
void Register::set_mask(char mask)
|
|
{
|
|
#ifdef DEBUG_MASK
|
|
cout << "setting mask: " << (int)mask << endl;
|
|
#endif
|
|
this->mask = mask;
|
|
check_mask();
|
|
}
|
|
|
|
void Register::check_external() const
|
|
{
|
|
#ifdef SIGNAL_CHECK
|
|
if (external * (1 - external) != 0)
|
|
{
|
|
if (external == NO_SIGNAL)
|
|
throw runtime_error("no signal in external value");
|
|
else
|
|
throw runtime_error("invalid external value");
|
|
}
|
|
#endif
|
|
}
|
|
|
|
void Register::check_signal_key(int my_id, KeyVector& garbled_entry)
|
|
{
|
|
if (garbled_entry[my_id - 1] != keys[get_external()][my_id - 1])
|
|
{
|
|
if (garbled_entry[my_id - 1] == keys[1 - get_external()][my_id - 1])
|
|
throw runtime_error("key matches inverted signal");
|
|
else
|
|
{
|
|
cout << "signal " << (int)get_external() << ", ";
|
|
cout << garbled_entry[my_id - 1] << " none of ";
|
|
for (int i = 0; i < 2; i++)
|
|
cout << keys[i][my_id - 1] << " ";
|
|
cout << endl;
|
|
throw runtime_error("key doesn't match signal");
|
|
}
|
|
}
|
|
}
|
|
|
|
void Register::print_input(int id)
|
|
{
|
|
for (int x = 0; x < 2; x++)
|
|
{
|
|
printf("k^I_{%u,%d}: ",id,x);
|
|
for (unsigned int i = 0; i < keys[x].size(); i++)
|
|
std::cout << keys[x][i] << " ";
|
|
std::cout << std::endl;
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void EvalRegister::andrs(GC::Processor<GC::Secret<EvalRegister> >& processor,
|
|
const vector<int>& args)
|
|
{
|
|
ProgramParty& party = ProgramParty::s();
|
|
int total = 0;
|
|
for (size_t j = 0; j < args.size(); j += 4)
|
|
total += args[j];
|
|
if (total < party.threshold)
|
|
{
|
|
// run in single thread
|
|
processor.andrs(args);
|
|
return;
|
|
}
|
|
|
|
int max_gates_per_thread = (total + N_EVAL_THREADS - 1) / N_EVAL_THREADS;
|
|
int i_thread = 0, i_gate = 0;
|
|
party.and_jobs[0].reset(processor.S, args, 0, party.next_gate(0),
|
|
total, party.get_n_parties());
|
|
for (size_t j = 0; j < args.size(); j += 4)
|
|
{
|
|
AndJob& and_job = party.and_jobs[i_thread];
|
|
GC::Secret<EvalRegister>& dest = processor.S[args[j + 1]];
|
|
dest.resize_regs(args[j]);
|
|
processor.complexity += args[j];
|
|
for (int i = 0; i < args[j]; i++)
|
|
{
|
|
and_job.gates[i_gate].unserialize(party.garbled_circuit,
|
|
party.get_n_parties());
|
|
i_gate++;
|
|
}
|
|
and_job.end = j + 4;
|
|
if (i_gate >= max_gates_per_thread or and_job.end >= args.size())
|
|
{
|
|
party.eval_threads[i_thread].request(and_job);
|
|
gate_id_t gate_id = party.next_gate(i_gate);
|
|
i_gate = 0;
|
|
// advance to next thread but only if not on least thread
|
|
if(and_job.end < args.size())
|
|
party.and_jobs[++i_thread].reset(processor.S, args, and_job.end,
|
|
gate_id, total, party.get_n_parties());
|
|
}
|
|
}
|
|
|
|
for (int i = 0; i <= i_thread; i++)
|
|
party.eval_threads[i].done();
|
|
}
|
|
|
|
void EvalRegister::op(const ProgramRegister& left, const ProgramRegister& right, Function func)
|
|
{
|
|
(void)func;
|
|
#ifdef DEBUG
|
|
cout << "eval left " << &left << " " << dec << " " << left.get_id() << endl;
|
|
cout << "eval right " << &right << " " << dec << " " << right.get_id() << endl;
|
|
#endif
|
|
ProgramParty& party = *ProgramParty::singleton;
|
|
GarbledGate gate(party.get_n_parties());
|
|
party.next_gate(gate);
|
|
gate.unserialize(party.garbled_circuit, party.get_n_parties());
|
|
party.prf_output.resize(PAD_TO_8(party.get_n_parties()) * sizeof(__m128i));
|
|
Register::eval(left, right, gate, party._id, party.prf_output.data(),
|
|
get_id(), left.get_id(), right.get_id());
|
|
}
|
|
|
|
void Register::eval(const Register& left, const Register& right, GarbledGate& gate,
|
|
party_id_t my_id, char* prf_output, int w_o, int w_l, int w_r)
|
|
{
|
|
(void)w_o;
|
|
(void)w_l;
|
|
(void)w_r;
|
|
size_t n_parties = CommonParty::singleton->get_n_parties();
|
|
int sig_l = left.get_external();
|
|
int sig_r = right.get_external();
|
|
int entry = 2 * sig_l + sig_r;
|
|
|
|
#ifdef DEBUG_MASK
|
|
cout << "input signals: " << sig_l << " " << sig_r << endl;
|
|
#endif
|
|
|
|
#ifdef DEBUG
|
|
gate.print();
|
|
cout << "picking " << entry << endl;
|
|
#endif
|
|
|
|
garbled_entry = gate[entry];
|
|
int ext_l = entry%2 ? 1 : 0 ;
|
|
int ext_r = entry<2 ? 0 : 1 ;
|
|
|
|
#ifdef DEBUG
|
|
printf("left input");
|
|
phex(gate.input(ext_l, 1), 16);
|
|
printf("right input");
|
|
phex(gate.input(ext_r, 1), 16);
|
|
unsigned g = gate.id;
|
|
#endif
|
|
|
|
Key k;
|
|
for(party_id_t i=1; i<=n_parties; i++) {
|
|
#ifdef DEBUG
|
|
std::cout << "using key: " << left.external_key(i) << endl;
|
|
#endif
|
|
PRF_chunk(left.external_key(i), gate.input(ext_l, 1), prf_output, PAD_TO_8(n_parties));
|
|
for(party_id_t j=1; j<=n_parties; j++) {
|
|
k = *(Key*)(prf_output+16*(j-1));
|
|
#ifdef DEBUG
|
|
printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_l, sig_l,ext_l,g,j);
|
|
std::cout << k << std::endl;
|
|
#endif
|
|
garbled_entry[j-1] -= k;
|
|
}
|
|
|
|
#ifdef DEBUG
|
|
std::cout << "using key: " << right.external_key(i) << endl;
|
|
#endif
|
|
PRF_chunk(right.external_key(i), gate.input(ext_r, 1) , prf_output, PAD_TO_8(n_parties));
|
|
for(party_id_t j=1; j<=n_parties; j++) {
|
|
k = *(Key*)(prf_output+16*(j-1));
|
|
#ifdef DEBUG
|
|
printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_r, sig_r,ext_r,g,j);
|
|
std::cout << k << std::endl;
|
|
#endif
|
|
garbled_entry[j-1] -= k;
|
|
}
|
|
}
|
|
|
|
// for(party_id_t i=1; i<=n_parties; i++) {
|
|
// std::cout << garbled_entry[i-1] << " ";
|
|
// }
|
|
// std::cout << std::endl;
|
|
|
|
external = garbled_entry[my_id - 1].get_signal();
|
|
|
|
#ifdef DEBUG_MASK
|
|
cout << "output signal: " << (int)external << endl;
|
|
#endif
|
|
|
|
#ifdef DEBUG
|
|
std::cout << "k^"<<my_id<<"_{"<<w_o<<","<<(int)external<<"} = " << key(my_id, external) << std::endl;
|
|
#endif
|
|
}
|
|
|
|
void GarbleRegister::op(const Register& left, const Register& right, Function func)
|
|
{
|
|
TrustedProgramParty& party = *TrustedProgramParty::singleton;
|
|
party.load_wire(*this);
|
|
Gate _;
|
|
Register::garble(left, right, func, &_, 0, party.prf_outputs,
|
|
party.buffers[TYPE_GARBLED_CIRCUIT]);
|
|
}
|
|
|
|
void Register::garble(const Register& left, const Register& right,
|
|
Function func, Gate* gate, int g,
|
|
vector<ReceivedMsg>& prf_outputs, SendBuffer& buffer)
|
|
{
|
|
(void)func;
|
|
(void)gate;
|
|
(void)g;
|
|
size_t n_parties = CommonParty::singleton->get_n_parties();
|
|
GarbledGate garbled_gate(n_parties);
|
|
garbled_gate.reset();
|
|
KeyVector& gg_A = garbled_gate[0];
|
|
KeyVector& gg_B = garbled_gate[1];
|
|
KeyVector& gg_C = garbled_gate[2];
|
|
KeyVector& gg_D = garbled_gate[3];
|
|
|
|
for(party_id_t i=1; i<=n_parties; i++) {
|
|
#ifdef DEBUG
|
|
std::cout << "adding prfs of party " << i << std::endl ;
|
|
#endif
|
|
Key left_i_j, right_i_j;
|
|
for (party_id_t j=1; j<=n_parties; j++) {
|
|
PRFTuple party_prfs;
|
|
prf_outputs[i - 1].unserialize(party_prfs);
|
|
//A
|
|
#ifdef DEBUG
|
|
left_i_j = *(Key*)party_prfs.outputs[0][0][0];
|
|
right_i_j = *(Key*)party_prfs.outputs[1][0][0];
|
|
std::cout << "A" << std::endl;
|
|
cout << gg_A[j-1] << std::endl;
|
|
cout << left_i_j << std::endl;
|
|
cout << right_i_j << std::endl;
|
|
#endif
|
|
gg_A[j-1] += party_prfs.for_garbling(0);
|
|
// cout << gg_A[j-1] << std::endl<< std::endl;
|
|
|
|
//B
|
|
#ifdef DEBUG
|
|
left_i_j = *(Key*)party_prfs.outputs[0][0][1];
|
|
right_i_j = *(Key*)party_prfs.outputs[1][1][0];
|
|
std::cout << "B" << std::endl;
|
|
cout << gg_B[j-1] << std::endl;
|
|
cout << left_i_j << std::endl;
|
|
cout << right_i_j << std::endl;
|
|
#endif
|
|
gg_B[j-1] += party_prfs.for_garbling(1);
|
|
// cout << gg_B[j-1] << std::endl<< std::endl;
|
|
|
|
//C
|
|
#ifdef DEBUG
|
|
left_i_j = *(Key*)party_prfs.outputs[0][1][0];
|
|
right_i_j = *(Key*)party_prfs.outputs[1][0][1];
|
|
std::cout << "C" << std::endl;
|
|
cout << gg_C[j-1] << std::endl;
|
|
cout << left_i_j << std::endl;
|
|
cout << right_i_j << std::endl;
|
|
#endif
|
|
gg_C[j-1] += party_prfs.for_garbling(2);
|
|
// cout << gg_C[j-1] << std::endl<< std::endl;
|
|
|
|
//D
|
|
#ifdef DEBUG
|
|
left_i_j = *(Key*)party_prfs.outputs[0][1][1];
|
|
right_i_j = *(Key*)party_prfs.outputs[1][1][1];
|
|
std::cout << "D" << std::endl;
|
|
cout << gg_D[j-1] << std::endl;
|
|
cout << left_i_j << std::endl;
|
|
cout << right_i_j << std::endl;
|
|
#endif
|
|
gg_D[j-1] += party_prfs.for_garbling(3);
|
|
// cout << gg_D[j-1] << std::endl<< std::endl;
|
|
}
|
|
}
|
|
|
|
//Adding the hidden keys
|
|
left.check_mask();
|
|
right.check_mask();
|
|
check_mask();
|
|
char maskl = left.mask;
|
|
char maskr = right.mask;
|
|
char masko = mask;
|
|
#ifdef DEBUG
|
|
printf("\ngate %u, leftwire=%lu, rightwire=%lu, outwire=%lu: func=%d%d%d%d, msk_l=%d, msk_r=%d, msk_o=%d\n"
|
|
, g,gate->_left, gate->_right, gate->_out
|
|
,func[0],func[1],func[2],func[3], maskl, maskr, masko);
|
|
#endif
|
|
|
|
// printf("\n");
|
|
// printf("maskl=%d, maskr=%d, masko=%d\n",maskl,maskr,masko);
|
|
// printf("gate func = %d%d%d%d\n",gate->_func[0],gate->_func[1],gate->_func[2],gate->_func[3]);
|
|
bool xa = func[2*maskl+maskr] != masko;
|
|
bool xb = func[2*maskl+(1-maskr)] != masko;
|
|
bool xc = func[2*(1-maskl)+maskr] != masko;
|
|
bool xd = func[2*(1-maskl)+(1-maskr)] != masko;
|
|
|
|
#ifdef DEBUG
|
|
printf("xa=%d, xb=%d, xc=%d, xd=%d\n", xa,xb,xc,xd);
|
|
#endif
|
|
|
|
// these are the 0-keys
|
|
Register& outwire = *this;
|
|
keys.init(n_parties);
|
|
KeyVector& keyxa = outwire[xa];
|
|
KeyVector& keyxb = outwire[xb];
|
|
KeyVector& keyxc = outwire[xc];
|
|
KeyVector& keyxd = outwire[xd];
|
|
|
|
for(party_id_t i=1; i<=n_parties; i++) {
|
|
#ifdef DEBUG
|
|
std::cout << "adding to A = " << keyxa[i-1] << std::endl;
|
|
std::cout << "adding to B = " << keyxb[i-1] << std::endl;
|
|
std::cout << "adding to C = " << keyxc[i-1] << std::endl;
|
|
std::cout << "adding to D = " << keyxd[i-1] << std::endl;
|
|
#endif
|
|
gg_A[i-1] += keyxa[i-1];
|
|
gg_B[i-1] += keyxb[i-1];
|
|
gg_C[i-1] += keyxc[i-1];
|
|
gg_D[i-1] += keyxd[i-1];
|
|
}
|
|
garbled_gate.serialize_no_allocate(buffer);
|
|
#ifdef DEBUG
|
|
garbled_gate.print();
|
|
#endif
|
|
}
|
|
|
|
void PRFRegister::output()
|
|
{
|
|
ProgramParty& party = *ProgramParty::singleton;
|
|
party.store_wire(*this);
|
|
}
|
|
|
|
void PRFRegister::op(const PRFRegister& left, const PRFRegister& right, Function func)
|
|
{
|
|
(void)func;
|
|
#ifdef DEBUG
|
|
cout << "prf op " << &left << " " << &right << endl;
|
|
cout << "sizes " << left.keys[0].size() << " " << right.keys[0].size() << endl;
|
|
#endif
|
|
const Register* in_wires[2] = { &left, &right };
|
|
ProgramParty& party = *ProgramParty::singleton;
|
|
party.receive_keys(*this);
|
|
GarbledGate gate(party.get_n_parties());
|
|
PRFOutputs prf_output(party.get_n_parties());
|
|
gate.compute_prfs_outputs(in_wires, party._id, prf_output, party.new_gate());
|
|
party.process_prf_output(prf_output, this, &left, &right);
|
|
#ifdef DEBUG_FREE_XOR
|
|
int i = ProgramParty::s()._id - 1;
|
|
Key delta = party.get_delta();
|
|
if (delta != (left.keys[0][i] ^ left.keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
if (delta != (right.keys[0][i] ^ right.keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
if (delta != (keys[0][i] ^ keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
#endif
|
|
}
|
|
|
|
void PRFRegister::input(party_id_t from, char value)
|
|
{
|
|
(void)from;
|
|
(void)value;
|
|
ProgramParty& party = *ProgramParty::singleton;
|
|
party.receive_keys(*this);
|
|
party.store_wire(*this);
|
|
#ifdef DEBUG
|
|
cout << "(PRF) input from " << from << ":" << endl;
|
|
keys.print(get_id());
|
|
#endif
|
|
}
|
|
|
|
void PRFRegister::public_input(bool value)
|
|
{
|
|
ProgramParty& party = ProgramParty::s();
|
|
int i = party.get_id() - 1;
|
|
keys[value][i] = 0;
|
|
keys[1 - value][i] = party.get_delta();
|
|
set_mask(0);
|
|
}
|
|
|
|
void PRFRegister::random()
|
|
{
|
|
ProgramParty& party = ProgramParty::s();
|
|
ProgramParty::s().receive_keys(*this);
|
|
ProgramParty::s().receive_all_keys(*this, 0);
|
|
party.store_wire(*this);
|
|
keys[0].serialize(party.wires);
|
|
#ifdef DEBUG
|
|
cout << "(PRF) random:" << endl;
|
|
keys.print(get_id());
|
|
#endif
|
|
}
|
|
|
|
class InputAccess
|
|
{
|
|
party_id_t from;
|
|
size_t n_bits;
|
|
GC::Secret<EvalRegister>& dest;
|
|
GC::Processor<GC::Secret<EvalRegister> >& processor;
|
|
ProgramParty& party;
|
|
InputArgs args;
|
|
|
|
public:
|
|
InputAccess(const InputArgs& args,
|
|
GC::Processor<GC::Secret<EvalRegister> >& processor) :
|
|
from(args.from + 1), n_bits(args.n_bits), dest(
|
|
processor.S[args.dest]), processor(processor), party(
|
|
ProgramParty::s()), args(args)
|
|
{
|
|
if (from > unsigned(party.get_n_parties()) or n_bits > 100)
|
|
throw runtime_error("invalid input parameters");
|
|
}
|
|
|
|
void prepare_masks(octetStream& os)
|
|
{
|
|
dest.resize_regs(n_bits);
|
|
for (auto& reg : dest.get_regs())
|
|
party.load_wire(reg);
|
|
if (from == party.get_id())
|
|
{
|
|
long long in = processor.get_input(args.params);
|
|
for (size_t i = 0; i < n_bits; i++)
|
|
{
|
|
auto& reg = dest.get_reg(i);
|
|
reg.input_helper((in >> i) & 1, os);
|
|
}
|
|
}
|
|
}
|
|
|
|
void received_masks(vector<octetStream>& oss)
|
|
{
|
|
size_t id = party.get_id() - 1;
|
|
for (auto& reg : dest.get_regs())
|
|
{
|
|
if (party.get_id() != from)
|
|
{
|
|
char ext;
|
|
oss[from - 1].unserialize(ext);
|
|
reg.set_external(ext);
|
|
}
|
|
oss[id].serialize(reg.get_garbled_entry()[id]);
|
|
}
|
|
}
|
|
|
|
void received_labels(vector<octetStream>& oss)
|
|
{
|
|
for (auto& reg : dest.get_regs())
|
|
{
|
|
for (party_id_t id = 1; id < (size_t)party.get_n_parties() + 1; id++)
|
|
{
|
|
Key key;
|
|
if (id != party.get_id())
|
|
{
|
|
oss[id - 1].unserialize(key);
|
|
reg.set_external_key(id, key);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <>
|
|
void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
|
|
const vector<int>& args)
|
|
{
|
|
auto& party = ProgramParty::s();
|
|
vector<octetStream> oss(party.get_n_parties());
|
|
octetStream& my_os = oss[party.get_id() - 1];
|
|
vector<InputAccess> accesses;
|
|
InputArgList a(args);
|
|
for (auto x : a)
|
|
{
|
|
accesses.push_back({x , processor});
|
|
}
|
|
for (auto& access : accesses)
|
|
access.prepare_masks(my_os);
|
|
party.P->Broadcast_Receive(oss, true);
|
|
my_os.reset_write_head();
|
|
for (auto& access : accesses)
|
|
access.received_masks(oss);
|
|
party.P->Broadcast_Receive(oss, true);
|
|
for (auto& access : accesses)
|
|
access.received_labels(oss);
|
|
}
|
|
|
|
void EvalRegister::convcbit(Integer& dest, const GC::Clear& source,
|
|
GC::Processor<GC::Secret<EvalRegister>>& processor)
|
|
{
|
|
auto& party = ProgramParty::s();
|
|
dest = source;
|
|
party.convcbit = source;
|
|
processor.untaint();
|
|
}
|
|
|
|
void EvalRegister::input_helper(char value, octetStream& os)
|
|
{
|
|
set_mask(ProgramParty::s().input_masks.pop_front());
|
|
set_external(get_mask() ^ value);
|
|
os.serialize(get_external());
|
|
}
|
|
|
|
EvalInputter::EvalInputter() :
|
|
party(ProgramParty::s()), oss(*party.P)
|
|
{
|
|
}
|
|
|
|
void EvalRegister::my_input(EvalInputter& inputter, bool input, int n_bits)
|
|
{
|
|
assert(n_bits == 1);
|
|
auto& party = inputter.party;
|
|
party.load_wire(*this);
|
|
input_helper(input, inputter.oss.mine);
|
|
inputter.tuples.push_back({this, party.P->my_num()});
|
|
}
|
|
|
|
void EvalRegister::other_input(EvalInputter& inputter, int from)
|
|
{
|
|
auto& party = inputter.party;
|
|
party.load_wire(*this);
|
|
inputter.tuples.push_back({this, from});
|
|
}
|
|
|
|
void EvalInputter::exchange()
|
|
{
|
|
party.P->Broadcast_Receive(oss, true);
|
|
for (auto& tuple : tuples)
|
|
{
|
|
if (tuple.from != party.P->my_num())
|
|
{
|
|
char ext;
|
|
oss[tuple.from].unserialize(ext);
|
|
tuple.reg->set_external(ext);
|
|
}
|
|
}
|
|
|
|
size_t id = party.get_id() - 1;
|
|
for (auto& os : oss)
|
|
os.reset_write_head();
|
|
|
|
for (auto& tuple : tuples)
|
|
{
|
|
oss[id].serialize(tuple.reg->get_garbled_entry()[id]);
|
|
#ifdef DEBUG_COMM
|
|
cout << "send " << garbled_entry[id] << ", "
|
|
<< oss[id].get_length() << " bytes from " << id << endl;
|
|
#endif
|
|
}
|
|
|
|
party.P->Broadcast_Receive(oss, true);
|
|
}
|
|
|
|
void EvalRegister::finalize_input(EvalInputter& inputter, int, int)
|
|
{
|
|
auto& party = inputter.party;
|
|
size_t id = party.get_id() - 1;
|
|
for (size_t i = 0; i < (size_t)party.get_n_parties(); i++)
|
|
{
|
|
if (i != id)
|
|
inputter.oss[i].unserialize(garbled_entry[i]);
|
|
}
|
|
keys[external] = garbled_entry;
|
|
#ifdef DEBUG
|
|
cout << "(Input) input from " << from << ":" << endl;
|
|
keys.print(get_id());
|
|
cout << garbled_entry << endl;
|
|
#endif
|
|
}
|
|
|
|
void EvalRegister::public_input(bool value)
|
|
{
|
|
ProgramParty& party = ProgramParty::s();
|
|
set_mask(0);
|
|
set_external(value);
|
|
for (int i = 0; i < party.get_n_parties(); i++)
|
|
set_external_key(i + 1, 0);
|
|
}
|
|
|
|
void EvalRegister::random()
|
|
{
|
|
auto& party = ProgramParty::s();
|
|
party.load_wire(*this);
|
|
keys[0].unserialize(party.wires, party.get_n_parties());
|
|
set_external(0);
|
|
}
|
|
|
|
void RandomRegister::randomize()
|
|
{
|
|
TrustedProgramParty& party = *TrustedProgramParty::singleton;
|
|
party.random_timer.start();
|
|
init(party.randomfd, party._N);
|
|
party.random_timer.stop();
|
|
keys[1] = keys[0] ^ party.get_deltas();
|
|
party.add_keys(*this);
|
|
}
|
|
|
|
void RandomRegister::op(const Register& left, const Register& right,
|
|
Function func)
|
|
{
|
|
(void)left;
|
|
(void)right;
|
|
(void)func;
|
|
randomize();
|
|
TrustedProgramParty& party = *TrustedProgramParty::singleton;
|
|
party.new_gate();
|
|
party.store_wire(*this);
|
|
}
|
|
|
|
void RandomRegister::input(party_id_t from, char value)
|
|
{
|
|
(void)value;
|
|
randomize();
|
|
auto& party = TrustedProgramParty::s();
|
|
party.store_wire(*this);
|
|
party.msg_input_masks[from - 1].push_back(get_mask());
|
|
#ifdef DEBUG
|
|
cout << "(Random) input from " << from << ":" << endl;
|
|
keys.print(get_id());
|
|
#endif
|
|
}
|
|
|
|
void RandomRegister::public_input(bool value)
|
|
{
|
|
auto& party = TrustedProgramParty::s();
|
|
keys.init(party.get_n_parties());
|
|
for (int i = 0; i < party.get_n_parties(); i++)
|
|
{
|
|
keys[value][i] = 0;
|
|
keys[1 - value][i] = party.delta(i);
|
|
}
|
|
set_mask(0);
|
|
party.store_wire(*this);
|
|
}
|
|
|
|
void GarbleRegister::input(party_id_t from, char value)
|
|
{
|
|
(void)from;
|
|
(void)value;
|
|
TrustedProgramParty::s().load_wire(*this);
|
|
}
|
|
|
|
void GarbleRegister::public_input(bool value)
|
|
{
|
|
(void)value;
|
|
TrustedProgramParty::s().load_wire(*this);
|
|
}
|
|
|
|
void RandomRegister::random()
|
|
{
|
|
randomize();
|
|
TrustedProgramParty::s().add_all_keys(*this, 0);
|
|
TrustedProgramParty::s().store_wire(*this);
|
|
#ifdef DEBUG
|
|
cout << "random mask: " << get_mask() << endl;
|
|
cout << "(Random) random:" << endl;
|
|
keys.print(get_id());
|
|
#endif
|
|
}
|
|
|
|
void GarbleRegister::random()
|
|
{
|
|
TrustedProgramParty::s().load_wire(*this);
|
|
}
|
|
|
|
void RandomRegister::output()
|
|
{
|
|
TrustedProgramParty::s().msg_output_masks.push_back(get_mask());
|
|
}
|
|
|
|
void EvalRegister::output()
|
|
{
|
|
ProgramParty& party = ProgramParty::s();
|
|
party.load_wire(*this);
|
|
set_mask(party.output_masks.pop_front());
|
|
#ifdef DEBUG_REGS
|
|
cout << "check " << get_id() << endl;
|
|
#endif
|
|
check_signal_key(party.get_id(), garbled_entry);
|
|
party.taint();
|
|
}
|
|
|
|
void RandomRegister::XOR(const Register& left, const Register& right)
|
|
{
|
|
mask = left.get_mask() ^ right.get_mask();
|
|
keys[0] = left.keys[0] ^ right.keys[0];
|
|
keys[1] = keys[0] ^ TrustedProgramParty::singleton->get_deltas();
|
|
}
|
|
|
|
void GarbleRegister::XOR(const Register& left, const Register& right)
|
|
{
|
|
mask = left.get_mask() ^ right.get_mask();
|
|
}
|
|
|
|
void PRFRegister::XOR(const Register& left, const Register& right)
|
|
{
|
|
int i = ProgramParty::s()._id - 1;
|
|
Key delta = ProgramParty::s().get_delta();
|
|
keys[0][i] = left.keys[0][i] ^ right.keys[0][i];
|
|
keys[1][i] = keys[0][i] ^ delta;
|
|
#ifdef DEBUG
|
|
cout << "PRF XOR" << endl;
|
|
cout << "delta " << delta << endl;
|
|
cout << endl;
|
|
cout << "res" << endl;
|
|
keys.print(get_id());
|
|
cout << "left" << endl;
|
|
left.keys.print(left.get_id());
|
|
cout << "right" << endl;
|
|
right.keys.print(right.get_id());
|
|
#endif
|
|
#ifdef DEBUG_FREE_XOR
|
|
if (delta != (left.keys[0][i] ^ left.keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
if (delta != (right.keys[0][i] ^ right.keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
if (delta != (keys[0][i] ^ keys[1][i]))
|
|
throw runtime_error("inconsistent delta");
|
|
#endif
|
|
}
|
|
|
|
void EvalRegister::XOR(const Register& left, const Register& right)
|
|
{
|
|
external = left.get_external() ^ right.get_external();
|
|
garbled_entry = left.get_garbled_entry() ^ right.get_garbled_entry();
|
|
#ifdef DEBUG
|
|
cout << "Eval XOR *" << get_id() << " = *" << left.get_id() << " ^ *" << right.get_id() << endl;
|
|
for (size_t i = 0; i < garbled_entry.size(); i++)
|
|
cout << garbled_entry[i] << " = " << left.get_garbled_entry()[i]
|
|
<< " ^ " << right.get_garbled_entry()[i] << endl;
|
|
#endif
|
|
}
|
|
|
|
template <>
|
|
void RandomRegister::store(NoMemory& mem,
|
|
const vector< GC::WriteAccess< GC::Secret<RandomRegister> > >& accesses)
|
|
{
|
|
(void)mem;
|
|
for (auto access : accesses)
|
|
{
|
|
for (auto& reg : access.source.get_regs())
|
|
TrustedProgramParty::s().store_spdz_wire(SPDZ_STORE, reg);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void RandomRegister::load(vector<GC::ReadAccess< GC::Secret<RandomRegister> > >& accesses,
|
|
const NoMemory& source)
|
|
{
|
|
(void)source;
|
|
for (auto access : accesses)
|
|
for (auto& reg : access.dest.get_regs())
|
|
{
|
|
((RandomRegister*)®)->randomize();
|
|
TrustedProgramParty::s().store_spdz_wire(SPDZ_LOAD, reg);
|
|
TrustedProgramParty::s().store_wire(reg);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void GarbleRegister::load(vector<GC::ReadAccess< GC::Secret<GarbleRegister> > >& accesses,
|
|
const NoMemory& source)
|
|
{
|
|
(void)source;
|
|
for (auto access : accesses)
|
|
for (auto& reg : access.dest.get_regs())
|
|
TrustedProgramParty::s().load_wire(reg);
|
|
}
|
|
|
|
KeyVector KeyVector::operator^(const KeyVector& other) const
|
|
{
|
|
if (size() != other.size())
|
|
throw runtime_error("size mismatch");
|
|
KeyVector res;
|
|
res.resize(size());
|
|
for (size_t i = 0; i < size(); i++)
|
|
res[i] = (*this)[i] ^ other[i];
|
|
return res;
|
|
}
|
|
|
|
ostream& operator<<(ostream& os, const KeyVector& kv)
|
|
{
|
|
for (size_t i = 0; i < kv.size(); i++)
|
|
os << kv[i] << " ";
|
|
return os;
|
|
}
|
|
|
|
template <int I>
|
|
KeyTuple<I> KeyTuple<I>::operator^(const KeyTuple<I>& other) const
|
|
{
|
|
KeyTuple<I> res;
|
|
for (int i = 0; i < I; i++)
|
|
res[i] = (*this)[i] ^ other[i];
|
|
return res;
|
|
}
|
|
|
|
template <int I>
|
|
void KeyTuple<I>::copy_to(Key* dest) {
|
|
for (int b = 0; b < I; b++)
|
|
avx_memcpy(dest + b * keys[0].size(), keys[b].data(), part_size());
|
|
}
|
|
|
|
template <int I>
|
|
void KeyTuple<I>::copy_from(Key* source, int n_parties, int except)
|
|
{
|
|
#ifdef DEBUG
|
|
cout << "skip copying for " << except << endl;
|
|
#endif
|
|
for (int b = 0; b < I; b++)
|
|
{
|
|
keys[b].resize(n_parties);
|
|
for (int i = 0; i < n_parties; i++)
|
|
{
|
|
if (i != except)
|
|
keys[b][i] = *source;
|
|
source++;
|
|
#ifdef DEBUG
|
|
cout << "copy from " << b << " " << i << " " << keys[b][i] << endl;
|
|
#endif
|
|
}
|
|
}
|
|
}
|
|
|
|
template <int I>
|
|
long KeyTuple<I>::counter = 0;
|
|
|
|
template<int I>
|
|
void KeyTuple<I>::randomize()
|
|
{
|
|
for (int i = 0; i < I; i++)
|
|
{
|
|
CommonParty::s().prng.get_octets((octet*)keys[i].data(), part_size());
|
|
#ifdef DEBUG
|
|
for (unsigned j = 0; j < keys[i].size(); j++)
|
|
{
|
|
keys[i][j] = { 0, (counter << 16) + (i << 8) + j };
|
|
counter++;
|
|
}
|
|
#endif
|
|
}
|
|
}
|
|
|
|
template<int I>
|
|
void KeyTuple<I>::print(int wire_id) const
|
|
{
|
|
for(int b=0; b<I; b++) {
|
|
for (party_id_t i=1; i<=keys[b].size(); i++) {
|
|
printf("k^%d_{%u,%d}: ",i,wire_id,b);
|
|
std::cout << keys[b][i-1] << std::endl;
|
|
}
|
|
}
|
|
}
|
|
|
|
template<int I>
|
|
void KeyTuple<I>::print(int wire_id, party_id_t pid)
|
|
{
|
|
for(int b=0; b<I; b++) {
|
|
printf("k^%d_{%u,%d}: ",pid,wire_id,b);
|
|
std::cout << keys[b][pid-1] << std::endl;
|
|
}
|
|
}
|
|
|
|
template class KeyTuple<2>;
|
|
template class KeyTuple<4> ;
|