mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
331 lines
7.8 KiB
C++
331 lines
7.8 KiB
C++
/*
|
|
* Party.h
|
|
*
|
|
*/
|
|
|
|
#ifndef PROTOCOL_PARTY_H_
|
|
#define PROTOCOL_PARTY_H_
|
|
|
|
#include <mutex>
#include <thread>

#include <boost/atomic.hpp>
#include <boost/thread.hpp>

#include "Register.h"
#include "GarbledGate.h"
#include "network/Node.h"
#include "CommonParty.h"
#include "SpdzWire.h"
#include "AndJob.h"

#include "GC/Machine.h"
#include "GC/Program.h"
#include "GC/Processor.h"
#include "GC/Secret.h"
#include "GC/RuntimeBranching.h"
#include "Tools/Worker.h"
|
|
|
|
class BooleanCircuit;

// Party id 0 (named SERVER_ID; presumably the coordinating party — confirm).
#define SERVER_ID (0)

// Size of the message-type field of input-key messages; 16 bytes so that
// memory will be aligned.
#define INPUT_KEYS_MSG_TYPE_SIZE (16)

#ifndef N_EVAL_THREADS
// Number of evaluation threads; override at build time with
// -DN_EVAL_THREADS=<n>. The default uses every logical (hyper-threaded) core
// reported by the platform, which is beneficial if only one AES unit is
// available per physical core. Note that std::thread::hardware_concurrency()
// may return 0 when the value is not computable.
// Fully qualified so the expansion does not depend on some other header
// having brought `thread` into scope via a using-directive.
#define N_EVAL_THREADS (std::thread::hardware_concurrency())
#endif
|
|
|
|
|
|
// Execution statistics record: a minimum and an accumulated total (units are
// not visible here — presumably timings; TODO confirm).
//
// Declared as a named struct rather than `typedef struct { ... } exec_props_t;`:
// an unnamed class that gets its linkage name from a typedef must be
// C-compatible, which excludes default member initializers (P1766R1, C++20;
// diagnosed by GCC and Clang). The type is still spelled `exec_props_t`.
struct exec_props_t {
    unsigned long min = 0;
    unsigned long long acc = 0;
};
|
|
|
|
class PartyProperties
|
|
{
|
|
protected:
|
|
party_id_t _id;
|
|
|
|
Timer online_timer;
|
|
|
|
Key delta;
|
|
|
|
public:
|
|
PartyProperties() : _id(-1) {}
|
|
|
|
party_id_t get_id() { return _id; }
|
|
Key get_delta() { return delta; }
|
|
|
|
};
|
|
|
|
class BaseParty : virtual public CommonFakeParty, virtual public PartyProperties
|
|
{
|
|
public:
|
|
BaseParty();
|
|
virtual ~BaseParty();
|
|
|
|
/* From NodeUpdatable class */
|
|
void NodeReady();
|
|
void NewMessage(int from, ReceivedMsg& msg);
|
|
void NodeAborted(struct sockaddr_in* from) { (void)from; }
|
|
|
|
void Start();
|
|
|
|
protected:
|
|
// int _num_evaluation_threads;
|
|
struct timeval _start_online_net, _end_online_net;
|
|
|
|
virtual void _compute_prfs_outputs(Key* keys) = 0;
|
|
void _send_prfs();
|
|
|
|
virtual void _process_external_received(char* externals,
|
|
party_id_t from) = 0;
|
|
virtual void _process_all_external_received(char* externals) = 0;
|
|
virtual void _process_input_keys(Key* keys, party_id_t from) = 0;
|
|
virtual void _process_all_input_keys(char* keys) = 0;
|
|
|
|
virtual void store_garbled_circuit(ReceivedMsg& msg) = 0;
|
|
virtual void _check_evaluate() = 0;
|
|
|
|
virtual void mask_output(ReceivedMsg& msg) = 0;
|
|
virtual void mask_input(ReceivedMsg& msg) = 0;
|
|
|
|
void done();
|
|
|
|
virtual void start_online_round() = 0;
|
|
|
|
virtual void receive_spdz_wires(ReceivedMsg& msg) = 0;
|
|
};
|
|
|
|
class Party : public BaseParty, public CommonCircuitParty {
|
|
friend class BooleanCircuit;
|
|
|
|
public:
|
|
Party(const char* netmap_file, const char* circuit_file, party_id_t id, const std::string input, int numthreads=5, int numtries=2);
|
|
virtual ~Party();
|
|
|
|
/* TEST methods */
|
|
|
|
private:
|
|
wire_id_t _IO;
|
|
|
|
std::string _all_input;
|
|
|
|
int _NUMTHREADS;
|
|
int _NUMTRIES;
|
|
|
|
vector<GarbledGate> _garbled_tbl;
|
|
|
|
vector<char> _input;
|
|
|
|
vector<char> input_masks;
|
|
vector<char>::iterator input_mask;
|
|
|
|
SendBuffer _external_values_msg;
|
|
int _num_externals_msg_received;
|
|
std::mutex _process_externals_mx;
|
|
|
|
SendBuffer _input_wire_keys_msg;
|
|
int _num_inputkeys_msg_received;
|
|
std::mutex _process_keys_mx;
|
|
|
|
std::mutex _sync_mx;
|
|
|
|
#ifdef __PURE_SHE__
|
|
Key* _sqr_keys;
|
|
inline Key* _sqr_key(party_id_t i, wire_id_t w,int b) {return _sqr_keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
|
|
#endif
|
|
|
|
inline Key& _key(party_id_t i, wire_id_t w,int b) {return registers[w][b][i-1] ; }
|
|
inline KeyVector& _garbled_entry(gate_id_t g, int entry) {return _garbled_tbl[g-1][entry];}
|
|
vector<GarbledGate>::iterator get_garbled_tbl_end() { return _garbled_tbl.begin() + garbled_tbl_size; }
|
|
void resize_garbled_tbl() { _garbled_tbl.resize(_G, _N); garbled_tbl_size = _G; }
|
|
|
|
void _initialize_input();
|
|
void _generate_prf_inputs();
|
|
void _compute_prfs_outputs(Key* keys);
|
|
void _print_keys();
|
|
|
|
void _generate_external_values_msg();
|
|
void _process_external_received(char* externals,
|
|
party_id_t from);
|
|
void _process_all_external_received(char* externals);
|
|
void _print_input_keys_checksum();
|
|
void _process_input_keys(Key* keys, party_id_t from);
|
|
void _process_all_input_keys(char* keys);
|
|
|
|
void _print_input_keys_msg();
|
|
void _print_keys_of_party(Key *keys, int id);
|
|
void _printf_garbled_table();
|
|
|
|
void store_garbled_circuit(ReceivedMsg& msg);
|
|
void load_garbled_circuit() {}
|
|
|
|
void _check_evaluate();
|
|
|
|
void receive_keys(Key* keys);
|
|
|
|
void receive_spdz_wires(ReceivedMsg& msg) { (void)msg; }
|
|
|
|
void start_online_round();
|
|
|
|
void mask_output(ReceivedMsg& msg);
|
|
void mask_input(ReceivedMsg& msg);
|
|
|
|
int get_n_inputs();
|
|
};
|
|
|
|
class ProgramParty : virtual public CommonParty, virtual public PartyProperties, public GC::RuntimeBranching
|
|
{
|
|
protected:
|
|
friend class PRFRegister;
|
|
friend class EvalRegister;
|
|
friend class Register;
|
|
|
|
vector<char> prf_output;
|
|
|
|
deque<octetStream> spdz_wires[SPDZ_OP_N];
|
|
size_t spdz_storage;
|
|
size_t garbled_storage;
|
|
vector<size_t> spdz_counters;
|
|
|
|
Worker<AndJob>* eval_threads;
|
|
vector<AndJob> and_jobs;
|
|
|
|
ReceivedMsgStore output_masks_store;
|
|
ReceivedMsgStore input_masks_store;
|
|
|
|
GC::Machine< GC::Secret<EvalRegister> > machine;
|
|
GC::Processor<GC::Secret<EvalRegister> > processor;
|
|
GC::Program<GC::Secret<EvalRegister> > program;
|
|
|
|
GC::Machine< GC::Secret<PRFRegister> > prf_machine;
|
|
GC::Processor<GC::Secret<PRFRegister> > prf_processor;
|
|
|
|
void store_garbled_circuit(ReceivedMsg& msg);
|
|
void load_garbled_circuit();
|
|
|
|
virtual void _check_evaluate() = 0;
|
|
virtual void done() = 0;
|
|
|
|
virtual void receive_keys(Register& reg) = 0;
|
|
virtual void receive_all_keys(Register& reg, bool external) = 0;
|
|
virtual void process_prf_output(PRFOutputs& prf_output,
|
|
PRFRegister* out, const PRFRegister* left, const PRFRegister* right) = 0;
|
|
|
|
void start_online_round();
|
|
|
|
void mask_output(ReceivedMsg& msg) { output_masks_store.push(msg); }
|
|
void mask_input(ReceivedMsg& msg) { input_masks_store.push(msg); }
|
|
|
|
public:
|
|
static ProgramParty* singleton;
|
|
|
|
LocalBuffer garbled_circuit;
|
|
ReceivedMsgStore garbled_circuits;
|
|
|
|
LocalBuffer output_masks;
|
|
LocalBuffer input_masks;
|
|
|
|
Player* P;
|
|
Names N;
|
|
|
|
int threshold;
|
|
|
|
Integer convcbit;
|
|
|
|
static ProgramParty& s();
|
|
|
|
ProgramParty();
|
|
virtual ~ProgramParty();
|
|
|
|
void reset();
|
|
|
|
void store_wire(const Register& reg);
|
|
void load_wire(Register& reg);
|
|
};
|
|
|
|
template<class T>
|
|
class ProgramPartySpec : public ProgramParty
|
|
{
|
|
static ProgramPartySpec* singleton;
|
|
|
|
protected:
|
|
GC::Memory<T> dynamic_memory;
|
|
|
|
void _check_evaluate();
|
|
|
|
public:
|
|
typename T::MAC_Check* MC;
|
|
|
|
static ProgramPartySpec& s();
|
|
|
|
ProgramPartySpec();
|
|
~ProgramPartySpec();
|
|
|
|
void load(string progname);
|
|
|
|
void get_spdz_wire(SpdzOp op, DualWire<T>& spdz_wire);
|
|
};
|
|
|
|
#ifdef SPDZ_AUTH
|
|
typedef ProgramPartySpec<Share<gf2n_long>> FakeProgramPartySuper;
|
|
#else
|
|
typedef ProgramPartySpec<GC::Memory<AuthValue>> FakeProgramPartySuper;
|
|
#endif
|
|
|
|
class FakeProgramParty : virtual public BaseParty, virtual public FakeProgramPartySuper
|
|
{
|
|
Key* keys_for_prf;
|
|
|
|
void _compute_prfs_outputs(Key* keys);
|
|
|
|
void _process_external_received(char* externals,
|
|
party_id_t from) { (void)externals; (void)from; }
|
|
void _process_all_external_received(char* externals) { (void)externals; }
|
|
void _process_input_keys(Key* keys, party_id_t from)
|
|
{ (void)keys; (void)from; }
|
|
void _process_all_input_keys(char* keys) { (void)keys; }
|
|
|
|
void store_garbled_circuit(ReceivedMsg& msg) { ProgramParty::store_garbled_circuit(msg); }
|
|
|
|
void _check_evaluate();
|
|
|
|
void receive_keys(Register& reg);
|
|
void receive_all_keys(Register& reg, bool external);
|
|
void process_prf_output(PRFOutputs& prf_output, PRFRegister* out,
|
|
const PRFRegister* left, const PRFRegister* right);
|
|
|
|
void receive_spdz_wires(ReceivedMsg& msg);
|
|
|
|
void start_online_round() { FakeProgramPartySuper::start_online_round(); }
|
|
|
|
void mask_output(ReceivedMsg& msg) { ProgramParty::mask_output(msg); }
|
|
void mask_input(ReceivedMsg& msg) { ProgramParty::mask_input(msg); }
|
|
|
|
void done() { BaseParty::done(); }
|
|
|
|
public:
|
|
FakeProgramParty(int argc, const char** argv);
|
|
~FakeProgramParty();
|
|
};
|
|
|
|
/// Access the registered ProgramParty instance.
/// @throws runtime_error("no singleton") if ProgramParty::singleton is null.
inline ProgramParty& ProgramParty::s()
{
    // Fail with an exception rather than dereference a null pointer.
    if (singleton == nullptr)
        throw runtime_error("no singleton");
    return *singleton;
}
|
|
|
|
template<class T>
|
|
inline ProgramPartySpec<T>& ProgramPartySpec<T>::s()
|
|
{
|
|
if (singleton)
|
|
return *singleton;
|
|
else
|
|
throw runtime_error("no singleton");
|
|
}
|
|
|
|
#endif /* PROTOCOL_PARTY_H_ */
|