ORAM in SPDZ-BMR.

This commit is contained in:
Marcel Keller
2018-03-07 12:25:45 +00:00
parent e43be00836
commit 2f50444b93
340 changed files with 8450 additions and 1839 deletions

1
.gitignore vendored
View File

@@ -44,6 +44,7 @@ callgrind.out.*
# Compiled source #
###################
Programs/Source/*
Programs/Bytecode/*
Programs/Schedules/*
Programs/Public-Input/*

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "Auth/MAC_Check.h"
@@ -366,16 +366,76 @@ void Direct_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S
this->CheckIfNeeded(P);
}
template<class T>
Passing_MAC_Check<T>::Passing_MAC_Check(const T& ai, Names& Nms, int num) :
Separate_MAC_Check<T>(ai, Nms, num)
{
}
template<class T, int t>
void passing_add_openings(vector<T>& values, octetStream& os)
{
octetStream new_os;
for (unsigned int i=0; i<values.size(); i++)
{
T tmp;
tmp.unpack(os);
(tmp + values[i]).pack(new_os);
}
os = new_os;
}
template<class T>
void Passing_MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
values.resize(S.size());
this->os.reset_write_head();
for (unsigned int i=0; i<S.size(); i++)
{
S[i].get_share().pack(this->os);
values[i] = S[i].get_share();
}
this->AddToMacs(S);
for (int i = 0; i < P.num_players() - 1; i++)
{
P.pass_around(this->os);
if (T::t() == 2)
passing_add_openings<T,2>(values, this->os);
else
passing_add_openings<T,0>(values, this->os);
}
for (unsigned int i = 0; i < values.size(); i++)
{
T tmp;
tmp.unpack(this->os);
this->vals.push_back(tmp);
}
}
template<class T>
void Passing_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
(void)S;
this->GetValues(values);
this->popen_cnt += values.size();
this->CheckIfNeeded(P);
}
template class MAC_Check<gfp>;
template class Direct_MAC_Check<gfp>;
template class Parallel_MAC_Check<gfp>;
template class Passing_MAC_Check<gfp>;
template class MAC_Check<gf2n>;
template class Direct_MAC_Check<gf2n>;
template class Parallel_MAC_Check<gf2n>;
template class Passing_MAC_Check<gf2n>;
#ifdef USE_GF2N_LONG
template class MAC_Check<gf2n_short>;
template class Direct_MAC_Check<gf2n_short>;
template class Parallel_MAC_Check<gf2n_short>;
template class Passing_MAC_Check<gf2n_short>;
#endif

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _MAC_Check
#define _MAC_Check
@@ -161,6 +161,16 @@ public:
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
};
template <class T>
class Passing_MAC_Check : public Separate_MAC_Check<T>
{
public:
Passing_MAC_Check(const T& ai, Names& Nms, int thread_num);
void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
};
enum mc_timer { SEND, RECV_ADD, BCAST, RECV_SUM, SEED, COMMIT, WAIT_SUMMER, RECV, SUM, SELECT, MAX_TIMER };

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "Auth/Subroutines.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _Subroutines
#define _Subroutines

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Summer.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Summer.h

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "Math/gf2n.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _fake_stuff

38
BMR/AndJob.cpp Normal file
View File

@@ -0,0 +1,38 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* AndJob.cpp
*
*/
#include "AndJob.h"
#include "Party.h"
#include "Register_inline.h"
int AndJob::run()
{
#ifdef DEBUG_AND_JOB
printf("thread %d: run and job from %d to %d with %d gates\n",
pthread_self(), start, end, gates.size());
#endif
__m128i prf_output[PAD_TO_8(MAX_N_PARTIES)];
auto gate = gates.begin();
vector< GC::Secret<EvalRegister> >& S = *this->S;
const vector<int>& args = *this->args;
int i_gate = 0;
for (size_t i = start; i < end; i += 4)
{
GC::Secret<EvalRegister>& dest = S[args[i + 1]];
for (int j = 0; j < args[i]; j++)
{
i_gate++;
gate->init_inputs(gate_id + i_gate,
ProgramParty::s().get_n_parties());
dest.get_reg(j).eval(S[args[i + 2]].get_reg(j),
S[args[i + 3]].get_reg(0), *gate,
ProgramParty::s().get_id(), (char*) prf_output, 0, 0, 0);
gate++;
}
}
return i_gate;
}

44
BMR/AndJob.h Normal file
View File

@@ -0,0 +1,44 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* AndJob.h
*
*/
#ifndef BMR_ANDJOB_H_
#define BMR_ANDJOB_H_
#include "GC/Secret.h"
#include "Register.h"
#include <vector>
using namespace std;
class AndJob
{
vector< GC::Secret<EvalRegister> >* S;
const vector<int>* args;
public:
vector<GarbledGate> gates;
size_t start, end;
gate_id_t gate_id;
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
{
this->S = &S;
this->args = &args;
this->start = start;
this->end = start;
this->gate_id = gate_id;
if (gates.size() < n_gates)
gates.resize(n_gates, {n_parties});
}
int run();
};
#endif /* BMR_ANDJOB_H_ */

View File

@@ -1,14 +1,16 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "BooleanCircuit.h"
#include "prf.h"
static void throw_party_exists(pid_t pid, unsigned int pos) {
std::cout << "ERROR: in circuit description" << std::endl
<< "\tPosition: " << pos << std::endl
<< "\tPlayer id " << pid << " already exists" << std::endl;
throw std::invalid_argument( "player id error" );
}
//static void throw_party_exists(pid_t pid, unsigned int pos) {
// std::cout << "ERROR: in circuit description" << std::endl
// << "\tPosition: " << pos << std::endl
// << "\tPlayer id " << pid << " already exists" << std::endl;
// throw std::invalid_argument( "player id error" );
//}
static void throw_bad_circuit_file() {
std::cout << "ERROR: could not read circuit file" << std::endl;
@@ -27,7 +29,7 @@ void BooleanCircuit::_parse_circuit(const char* desc_file)
unsigned int total_input_wires;
circuit_file >> total_input_wires;
for (int idx_party = 1; idx_party <= _num_parties; idx_party++) {
for (size_t idx_party = 1; idx_party <= _num_parties; idx_party++) {
unsigned int num_input_wires = total_input_wires/_num_parties;
if (idx_party == _num_parties) {
num_input_wires = total_input_wires-(total_input_wires/_num_parties)*(_num_parties-1);
@@ -53,7 +55,7 @@ void BooleanCircuit::_parse_circuit(const char* desc_file)
/* Parse gates */
for (gate_id_t gate_id = 1; gate_id <= _num_gates; gate_id++)
{
int fan_in, fan_out, left, right, out;
size_t fan_in, fan_out, left, right, out;
std::string func;
circuit_file >> fan_in >> fan_out >> left >> right >> out >> func;
@@ -163,11 +165,11 @@ int BooleanCircuit::__make_layers(gate_id_t g)
return 0;
}
int layer_left, layer_right;
if(_gates[g]._left == NULL)
if(_gates[g]._left == 0)
layer_left = -1;
else
layer_left = __make_layers(_wires[_gates[g]._left]._out_from);
if(_gates[g]._right == NULL)
if(_gates[g]._right == 0)
layer_right = -1;
else
layer_right = __make_layers(_wires[_gates[g]._right]._out_from);
@@ -183,7 +185,7 @@ int BooleanCircuit::__make_layers(gate_id_t g)
void BooleanCircuit::_add_to_layer(int layer, gate_id_t g)
{
// printf("adding %u to layer %d\n", g, layer);
if (_layers.size()<layer+1) {
if (_layers.size()<(size_t)layer+1) {
// printf("layer doesn't exist, creating it\n");
_layers.resize(layer+1);
}
@@ -192,11 +194,11 @@ void BooleanCircuit::_add_to_layer(int layer, gate_id_t g)
void BooleanCircuit::_print_layers()
{
printf("num layers = %d\n",_layers.size());
for(int i=0; i<_layers.size(); i++) {
printf ("\nlayer %d size=%d\n",i,_layers[i].size());
for(int j=0; j<_layers[i].size(); j++) {
printf("%u ", _layers[i][j]);
printf("num layers = %lu\n",_layers.size());
for(size_t i=0; i<_layers.size(); i++) {
printf ("\nlayer %lu size=%lu\n",i,_layers[i].size());
for(size_t j=0; j<_layers[i].size(); j++) {
printf("%lu ", _layers[i][j]);
}
printf ("\n");
}
@@ -207,11 +209,11 @@ void BooleanCircuit::_validate_layers()
_max_layer_sz = 0;
for (gate_id_t g=1; g<_num_gates; g++){
if(_gates[g]._layer == NO_LAYER)
printf("gate %d has no layer!\n",g);
printf("gate %lu has no layer!\n",g);
assert(_gates[g]._layer != NO_LAYER);
}
gate_id_t sum_gates_in_layers = 0;
for(int i=0; i<_layers.size(); i++) {
for(size_t i=0; i<_layers.size(); i++) {
sum_gates_in_layers += _layers[i].size();
if(_layers[i].size() > _max_layer_sz) {
_max_layer_sz = _layers[i].size();
@@ -245,7 +247,7 @@ void BooleanCircuit::Inputs(const char* inputs_file_path)
}
BooleanCircuit::BooleanCircuit(const char* desc_file)
:_num_evaluated_out_wires(0),_num_input_wires(0), _keys(NULL), _masks(NULL), _prf_inputs(NULL)
:_num_evaluated_out_wires(0)
{
_parse_circuit(desc_file);
_make_layers();
@@ -259,18 +261,17 @@ void BooleanCircuit::EvaluateByLayerLinearly(party_id_t my_id) {
mpz_t temp_mpz;
init_temp_mpz_t(temp_mpz);
#endif
for(int i=0; i<_layers.size(); i++) {
for (int j=0; j<_layers[i].size(); j++) {
for(size_t i=0; i<_layers.size(); i++) {
for (size_t j=0; j<_layers[i].size(); j++) {
gate_id_t gid = _layers[i][j];
#ifdef __PURE_SHE__
signal_t s = _eval_gate(gid, my_id, prf_output, temp_mpz);
_eval_gate(gid, my_id, prf_output, temp_mpz);
#else
signal_t s = _eval_gate(gid, my_id, prf_output);
_eval_gate(gid, my_id, prf_output);
#endif
// printf("%d ", s);
_externals[_gates[gid]._out] = s;
}
}
delete[] prf_output;
}
void BooleanCircuit::EvaluateByLayer(int num_threads, party_id_t my_id)
@@ -291,7 +292,7 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
mpz_t temp_mpz;
init_temp_mpz_t(temp_mpz);
#endif
for(int l=0; l<_layers.size(); l++) {
for(size_t l=0; l<_layers.size(); l++) {
int layer_sz = _layers[l].size();
int start_idx = (layer_sz/num_threads)*i;
int end_idx = (layer_sz/num_threads)*(i+1)-1;
@@ -301,11 +302,10 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
for(int g=start_idx; g<=end_idx; g++) {
gate_id_t gid = _layers[l][g];
#ifdef __PURE_SHE__
signal_t s = _eval_gate(gid, my_id, prf_output, temp_mpz);
_eval_gate(gid, my_id, prf_output, temp_mpz);
#else
signal_t s = _eval_gate(gid, my_id, prf_output);
_eval_gate(gid, my_id, prf_output);
#endif
_externals[_gates[gid]._out] = s;
}
// printf("done eval layer %d\n", l);
@@ -375,75 +375,17 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
//}
#ifdef __PURE_SHE__
signal_t BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz)
void BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz)
#else
signal_t BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output)
void BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output)
#endif
{
// std::cout << std::endl << "evaluate gate " << g << std::endl;
wire_id_t w_l = _gates[g]._left;
wire_id_t w_r = _gates[g]._right;
wire_id_t w_o = _gates[g]._out;
int sig_l = _externals[w_l];
int sig_r = _externals[w_r];
int entry = 2 * sig_l + sig_r;
Key* garbled_entry = _garbled_entry(g, entry);
int ext_l = entry%2 ? 1 : 0 ;
int ext_r = entry<2 ? 0 : 1 ;
Key k;
for(party_id_t i=1; i<=_num_parties; i++) {
// std::cout << "using key: " << *_key(i,w_l,sig_l) << ": ";
PRF_chunk(_key(i,w_l,sig_l), _input(ext_l,g,1), prf_output, PAD_TO_8(_num_parties));
for(party_id_t j=1; j<=_num_parties; j++) {
k = *(Key*)(prf_output+16*(j-1));
#ifdef __PRIME_FIELD__
k.adjust();
#endif
// printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_l, sig_l,ext_l,g,j);
// std::cout << k << std::endl;
garbled_entry[j-1] -= k;
}
PRF_chunk(_key(i,w_r,sig_r), _input(ext_r,g,1) , prf_output, PAD_TO_8(_num_parties));
for(party_id_t j=1; j<=_num_parties; j++) {
// std::cout << "using key: " << *_key(i,w_r,sig_r) << ": ";
k = *(Key*)(prf_output+16*(j-1));
#ifdef __PRIME_FIELD__
k.adjust();
#endif
// printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_r, sig_r,ext_r,g,j);
// std::cout << k << std::endl;
garbled_entry[j-1] -= k;
}
}
#if __PURE_SHE__
for(party_id_t j=1; j<=_num_parties; j++) {
garbled_entry[j-1].sqr_in_place(tmp_mpz);
}
#endif
// for(party_id_t i=1; i<=_num_parties; i++) {
// std::cout << garbled_entry[i-1] << " ";
// }
// std::cout << std::endl;
if(garbled_entry[my_id-1] == *_key(my_id, w_o, 0)) {
// std::cout << "k^"<<my_id<<"_{"<<w_o<<",0} = " << *_key(my_id, w_o, 0) << std::endl;
memcpy(_key(1,w_o,0), garbled_entry, sizeof(Key)*_num_parties);
return 0;
} else if (garbled_entry[my_id-1] == *_key(my_id, w_o, 1)) {
// std::cout << "k^"<<my_id<<"_{"<<w_o<<",1} = " << *_key(my_id, w_o, 1) << std::endl;
memcpy(_key(1,w_o,1), garbled_entry, sizeof(Key)*_num_parties);
return 1;
} else {
printf("\nERROR!!!\n");
// throw std::invalid_argument("result key doesn't fit any of my keys");
// return NO_SIGNAL;
}
party->registers[w_o].eval(party->registers[w_l], party->registers[w_r],
party->_garbled_tbl[g-1], my_id, prf_output, w_o, w_l, w_r);
}
@@ -522,12 +464,11 @@ std::string BooleanCircuit::Output()
{
std::cout << "output:" <<std::endl;
std::stringstream ss;
// printf("masks\n");
// phex(_masks, _wires.size());
// printf("externals\n");
// phex(_externals, _wires.size());
for( int i=0; i<_num_output_wires; i++ ) {
int output = _externals[_output_start+i] ^ _masks[_output_start+i];
printf("masks/externals\n");
for( size_t i=0; i<_num_output_wires; i++ ) {
cout << "mask " << i << ": " << (int)party->registers[_output_start+i].mask << endl;
cout << "external " << i << ": " << (int)party->registers[_output_start+i].get_external() << endl;
int output = party->registers[_output_start+i].get_output();
ss << output;
}
std::cout << ss.str();

View File

@@ -1,3 +1,5 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef __BOOLEAN_CIRCUIT__
#define __BOOLEAN_CIRCUIT__
@@ -22,10 +24,10 @@
#include "Gate.h"
#include "Wire.h"
#include "Key.h"
#include "Register.h"
#include "GarbledGate.h"
typedef unsigned int party_id_t;
//#include "Party.h"
#include "Party.h"
#define INIT_PARTY(W,N) {.wires=W, .n_wires=N }
@@ -41,13 +43,11 @@ typedef struct party_t{
#define GARBLED_GATE_SIZE(N) (4*N)
#define MSG_KEYS_HEADER_SZ (16)
//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
class BooleanCircuit
{
friend class Party;
friend class TrustedParty;
friend class CommonCircuitParty;
public:
BooleanCircuit(const char* desc_file);
// void RawInputs(std::string raw_inputs);
@@ -63,15 +63,7 @@ public:
inline wire_id_t OutWiresStart() { return _output_start; }
inline gate_id_t NumGates() {return _num_gates; }
inline wire_id_t NumParties() { return _num_parties; }
inline party_t Party(party_id_t id) {return _parties[id];}
inline void Keys(Key* keys) {_keys = keys;}
inline Key* Keys() { return _keys; }
inline void Masks(char* masks) {_masks = masks;}
inline char* Masks() { return _masks; }
inline char* PrfInputs() {return _prf_inputs;}
inline void PrfInputs(char* prf_inputs) { _prf_inputs = prf_inputs; }
inline char* Prfs(){return _prfs;}
inline void Prfs(char* prfs){_prfs=prfs;}
inline party_t get_party(party_id_t id) {return _parties[id];}
private:
@@ -88,7 +80,7 @@ private:
std::vector<std::vector<gate_id_t>> _layers;
int _max_layer_sz;
size_t _max_layer_sz;
void _make_layers();
int __make_layers(gate_id_t g);
void _add_to_layer(int layer, gate_id_t g);
@@ -111,13 +103,13 @@ private:
void _parse_circuit(const char* desc_file);
void _eval_thread(party_id_t my_id);
#ifdef __PURE_SHE__
signal_t _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz);
void _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz);
#else
signal_t _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output);
void _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output);
#endif
inline bool is_wire_ready(wire_id_t w) {
return _externals[w] != NO_SIGNAL;
return party->registers[w].get_external() != NO_SIGNAL;
}
inline bool is_gate_ready(gate_id_t g) {
bool ready = is_wire_ready(_gates[g]._left)
@@ -125,73 +117,7 @@ private:
return ready;
}
/* Additional data stored per per party per wire: */
Key* _keys; /* Total of n*W*2 keys
* For every w={0,...,W}
* For every b={0,1}
* For every i={1...n}
* k^i_{w,b}
* This is helpful that the keys for specific w and b are adjacent
* for pipelining matters.
*/
inline Key* _key(party_id_t i, wire_id_t w,int b) {return _keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
#ifdef __PURE_SHE__
Key* _sqr_keys;
inline Key* _sqr_key(party_id_t i, wire_id_t w,int b) {return _sqr_keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
#endif
char* _masks; /* There are W masks, one per wire. beginning with 0 */
char* _externals; /* Same as _masks */
Key* _garbled_tbl; /* will be allocated 4*n*G keys;
* (n keys for each of A,B,C,D entries).
* For each g in G:
* A1, A2, ... , An
* B1, B2, ... , Bn
* C1, C2, ... , Cn
* D1, D2, ... , Dn
*/
Key* _garbled_tbl_copy;
inline Key* _garbled_entry(gate_id_t g, int entry) {return _garbled_tbl+(g-1)*4*_num_parties+entry*_num_parties;}
char* _prfs; /*
* Total of n*(G*2*2*2*n+4) = n*(8*G*n+RESERVE_FOR_MSG_TYPE) = n*(PRFS_PER_PARTY+RESERVE_FOR_MSG_TYPE)
*
* For every party i={1...n}
* <Message_type> = saves us copying memory to another location - only for Party.
* For every gate g={1...G}
* For input wires x={left,right}
* For b={0,1} (0-key/1-key)
* For e={0,1} (extension)
* for every party j={1...n}
* F_{k^i_{x,b}}(e,j,g,)
*/
char* _prf_inputs; /*
* These are all possible inputs to the prf,
* This is not efficient in terms of storage but increase
* performance in terms of speed since no need to generate
* (new allocation plus filling) those inputs every time
* we compute the prf.
* Structure:
* - Total of G*n*2 inputs. (i.e. for every gate g, for every
* party j and for every extension e (from {0,1}).
* - We want to be able to set key once and use it to encrypt
* several inputs, so we want those inputs to be adjacent in
* memory in order to save time of building the block of
* inputs. So, the structure of the inputs is as follows:
* - First half of inputs (G*n inputs) are:
* - For every gate g=1,...,G we store the inputs:
* (0||g||1),(0||g||2),...,(0||g||n)
* - Second half of the inputs are:
* - For every gate g=1,...,G we store the inputs:
* (1||g||1),(1||g||2),...,(1||g||n)
*/
inline char* _input(int e, gate_id_t g, party_id_t j)
{ return _prf_inputs + (e*_num_gates*_num_parties + (g-1)*_num_parties + j-1) * 16 ;}
Party* party;
};

219
BMR/CommonParty.cpp Normal file
View File

@@ -0,0 +1,219 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* CommonParty.cpp
*
*/
#include "CommonParty.h"
#include "BooleanCircuit.h"
#include "Tools/benchmarking.h"
CommonParty* CommonParty::singleton = 0;
CommonParty::CommonParty() :
_node(0), gate_counter(0), gate_counter2(0), garbled_tbl_size(0),
cpu_timer(CLOCK_PROCESS_CPUTIME_ID), buffers(TYPE_MAX)
{
insecure("MPC emulation");
if (singleton != 0)
throw runtime_error("there can only be one");
singleton = this;
prng.ReSeed();
#ifdef DEBUG_PRNG
octet seed[SEED_SIZE];
memset(seed, 0, sizeof(seed));
prng.SetSeed(seed);
#endif
cpu_timer.start();
timer.start();
gf2n::init_field(128);
mac_key.randomize(prng);
}
CommonParty::~CommonParty()
{
if (_node)
delete _node;
cout << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl;
cout << "CPU time: " << cpu_timer.elapsed() << endl;
cout << "Total time: " << timer.elapsed() << endl;
cout << "First phase time: " << timers[0].elapsed() << endl;
cout << "Second phase time: " << timers[1].elapsed() << endl;
cout << "Number of gates: " << gate_counter << endl;
}
void CommonParty::init(const char* netmap_file, int id, int n_parties)
{
#ifdef N_PARTIES
if (n_parties != N_PARTIES)
throw runtime_error("wrong number of parties");
#else
#ifdef MAX_N_PARTIES
if (n_parties > MAX_N_PARTIES)
throw runtime_error("too many parties");
#endif
_N = n_parties;
#endif // N_PARTIES
printf("netmap_file: %s\n", netmap_file);
if (0 == strcmp(netmap_file, LOOPBACK_STR)) {
_node = new Node( NULL, id, this, _N + 1);
} else {
_node = new Node(netmap_file, id, this);
}
}
int CommonParty::init(const char* netmap_file, int id)
{
int n_parties;
if (string(netmap_file) != string(LOOPBACK_STR))
{
ifstream(netmap_file) >> n_parties;
n_parties--;
}
else
n_parties = 2;
init(netmap_file, id, n_parties);
return n_parties;
}
void CommonParty::reset()
{
garbled_tbl_size = 0;
}
gate_id_t CommonParty::new_gate()
{
gate_counter++;
garbled_tbl_size++;
return gate_counter;
}
void CommonParty::next_gate(GarbledGate& gate)
{
gate_counter2++;
gate.init_inputs(gate_counter2, _N);
}
void CommonParty::input(Register& reg, party_id_t from)
{
(void)reg;
(void)from;
throw not_implemented();
}
SendBuffer& CommonParty::get_buffer(MSG_TYPE type)
{
SendBuffer& buffer = buffers[type];
buffer.clear();
fill_message_type(buffer, type);
#ifdef DEBUG_BUFFER
cout << type << " buffer:";
phex(buffers.data(), 4);
#endif
return buffer;
}
void CommonCircuitParty::print_masks(const vector<int>& indices)
{
vector<char> bits;
for (auto i = indices.begin(); i != indices.end(); i++)
bits.push_back(registers[*i].get_mask_no_check());
print_bit_array(bits);
}
void CommonCircuitParty::print_outputs(const vector<int>& indices)
{
vector<char> bits;
for (auto i = indices.begin(); i != indices.end(); i++)
bits.push_back(registers[*i].get_output_no_check());
print_bit_array(bits);
}
template <class T, class U>
GC::BreakType CommonParty::first_phase(GC::Program<U>& program,
GC::Processor<T>& processor, GC::Machine<T>& machine)
{
(void)machine;
timers[0].start();
reset();
wires.clear();
GC::BreakType next = (reinterpret_cast<GC::Program<T>*>(&program))->execute(processor);
#ifdef DEBUG_ROUNDS
cout << "finished first phase at pc " << processor.PC
<< " reason " << next << endl;
#endif
timers[0].stop();
cout << "First round time: " << timers[0].elapsed() << " / "
<< timer.elapsed() << endl;
#ifdef DEBUG_WIRES
cout << "Storing wires with " << 1e-9 * wires.size() << " GB on disk" << endl;
#endif
wire_storage.push(wires);
return next;
}
template<class T>
GC::BreakType CommonParty::second_phase(GC::Program<T>& program,
GC::Processor<T>& processor, GC::Machine<T>& machine)
{
(void)machine;
wire_storage.pop(wires);
wires.reset_head();
timers[1].start();
GC::BreakType next = GC::TIME_BREAK;
next = program.execute(processor);
#ifdef DEBUG_ROUNDS
cout << "finished second phase at " << processor.PC
<< " reason " << next << endl;
#endif
timers[1].stop();
// cout << "Second round time: " << timers[1].elapsed() << ", ";
// cout << "total time: " << timer.elapsed() << endl;
if (false)
return GC::CAP_BREAK;
else
return next;
}
void CommonCircuitParty::prepare_input_regs(party_id_t from)
{
party_t sender = _circuit->_parties[from];
wire_id_t s = sender.wires; //the beginning of my wires
wire_id_t n = sender.n_wires; // number of my wires
input_regs_queue.clear();
input_regs_queue.push_back(_N + 1);
(*input_regs)[from].clear();
for (wire_id_t i = 0; i < n; i++) {
wire_id_t w = s + i;
(*input_regs)[from].push_back(w);
}
}
void CommonCircuitParty::prepare_output_regs()
{
output_regs.clear();
for (size_t i = 0; i < _OW; i++)
output_regs.push_back(_circuit->OutWiresStart()+i);
}
template GC::BreakType CommonParty::first_phase(
GC::Program<GC::Secret<GarbleRegister> >& program,
GC::Processor<GC::Secret<RandomRegister> >& processor,
GC::Machine<GC::Secret<RandomRegister> >& machine);
template GC::BreakType CommonParty::first_phase(
GC::Program<GC::Secret<EvalRegister> >& program,
GC::Processor<GC::Secret<PRFRegister> >& processor,
GC::Machine<GC::Secret<PRFRegister> >& machine);
template GC::BreakType CommonParty::second_phase(
GC::Program<GC::Secret<GarbleRegister> >& program,
GC::Processor<GC::Secret<GarbleRegister> >& processor,
GC::Machine<GC::Secret<GarbleRegister> >& machine);
template GC::BreakType CommonParty::second_phase(
GC::Program<GC::Secret<EvalRegister> >& program,
GC::Processor<GC::Secret<EvalRegister> >& processor,
GC::Machine<GC::Secret<EvalRegister> >& machine);

163
BMR/CommonParty.h Normal file
View File

@@ -0,0 +1,163 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* CommonParty.h
*
*/
#ifndef BMR_COMMONPARTY_H_
#define BMR_COMMONPARTY_H_
#include <vector>
using namespace std;
#include "GarbledGate.h"
#include "Register.h"
#include "proto_utils.h"
#include "network/Node.h"
#include "Tools/random.h"
#include "Auth/MAC_Check.h"
#include "Tools/time-func.h"
#include "GC/Program.h"
#include "Tools/FlexBuffer.h"
#if (defined(DEBUG) || defined(DEBUG_COMM)) && !defined(DEBUG_STEPS)
#define DEBUG_STEPS
#endif
enum SpdzOp
{
SPDZ_LOAD,
SPDZ_STORE,
SPDZ_MAC,
SPDZ_OP_N,
};
template <class T>
class PersistentFront
{
deque<T>& container;
int i;
public:
PersistentFront(deque<T>& container) : container(container), i(0) {}
T& operator*() { return container.front(); }
void operator++(int) { i++; }
void reset() {}
int get_i() { return i; }
};
class CommonParty : public NodeUpdatable
{
protected:
friend class Register;
#ifdef N_PARTIES
const party_id_t _N = N_PARTIES;
#else
party_id_t _N;
#endif
Node* _node;
int gate_counter, gate_counter2;
int garbled_tbl_size;
Timer cpu_timer;
Timer timers[2];
Timer timer;
gf2n mac_key;
mutex global_lock;
LocalBuffer wires;
ReceivedMsgStore wire_storage;
template<class T, class U>
GC::BreakType first_phase(GC::Program<U>& program, GC::Processor<T>& processor,
GC::Machine<T>& machine);
template<class T>
GC::BreakType second_phase(GC::Program<T>& program, GC::Processor<T>& processor,
GC::Machine<T>& machine);
public:
static CommonParty* singleton;
static CommonParty& s();
vector<SendBuffer> buffers;
PRNG prng;
CommonParty();
virtual ~CommonParty();
#ifdef N_PARTIES
static int get_n_parties() { return N_PARTIES; }
#else
static int get_n_parties() { return s()._N; }
#endif
void init(const char* netmap_file, int id, int n_parties);
int init(const char* netmap_file, int id);
virtual void reset();
gate_id_t new_gate();
void next_gate(GarbledGate& gate);
gate_id_t next_gate(int skip) { return gate_counter2 += skip; }
size_t get_garbled_tbl_size() { return garbled_tbl_size; }
void input(Register& reg, party_id_t from);
SendBuffer& get_buffer(MSG_TYPE type);
gf2n get_mac_key() { return mac_key; }
};
class BooleanCircuit;
class CommonCircuitParty : virtual public CommonParty
{
protected:
BooleanCircuit* _circuit;
gate_id_t _G;
wire_id_t _W;
wire_id_t _OW;
CheckVector<Register> registers;
deque< CheckVector< CheckVector<int> > > input_regs_queue;
PersistentFront< CheckVector< CheckVector<int> > > input_regs;
vector<int> output_regs;
CommonCircuitParty() : input_regs(input_regs_queue) {}
void prepare_input_regs(party_id_t id);
void prepare_output_regs();
void resize_registers() { registers.resize(_W, {(int)_N}); }
Register& get_reg(int reg);
void print_masks(const vector<int>& indices);
void print_outputs(const vector<int>& indices);
void print_round_regs();
};
inline Register& CommonCircuitParty::get_reg(int reg)
{
#ifdef DEBUG_REGS
cout << "get reg " << reg << endl << registers.at(reg).keys[0] << endl
<< registers.at(reg).keys[1] << endl;
#endif
return registers.at(reg);
}
inline CommonParty& CommonParty::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
#endif /* BMR_COMMONPARTY_H_ */

97
BMR/GarbledGate.cpp Normal file
View File

@@ -0,0 +1,97 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* GarbledGate.cpp
*
*/
#include "GarbledGate.h"
#include "prf.h"
#include "CommonParty.h"
GarbledGate::~GarbledGate() {
}
void GarbledGate::init_inputs(gate_id_t g, int n_parties)
{
n_parties = CommonParty::get_n_parties();
id = g;
for (unsigned int e = 0; e <= 1; e++) {
prf_inputs[e].resize(n_parties);
for (unsigned int j = 0; j < (size_t)n_parties; j++) {
prf_inputs[e][j] = 0;
/* fill out this buffer s.t. first 4 bytes are the extension (0/1),
* next 4 bytes are gate_id and next 4 bytes are party id.
* For the first half we dont need to fill the extension because
* it is zero anyway.
*/
unsigned int* prf_input_index =
(unsigned int*) &prf_inputs[e][j]; //easier to refer as integers
// printf("e,g,j=%u,%u,%u\n",e,g,j);
*prf_input_index = e;
*(prf_input_index + 1) = g;
*(prf_input_index + 2) = j + 1;
}
}
}
void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id,
SendBuffer& buffer, gate_id_t g)
{
int n_parties = CommonParty::get_n_parties();
init_inputs(g, n_parties);
PRFOutputs prf_output(n_parties);
for(int w=0; w<=1; w++) {
for (int b=0; b<=1; b++) {
const Key& key = in_wires[w]->key(my_id, b);
AES_KEY aes_key;
AES_128_Key_Expansion((unsigned char*)&key.r, &aes_key);
#ifdef DEBUG
cout << "using key " << key << endl;
#endif
for (int e=0; e<=1; e++) {
for (int j=1; j<= n_parties; j++) {
prf_output[my_id-1][j-1].outputs[w][b][e][0] =
aes_128_encrypt(*(__m128i*)input(e, j), (octet*)aes_key.rd_key);
#ifdef __PRIME_FIELD__
((Key*)prf_outputs_index)->adjust();
#endif
}
}
}
}
for (int i = 0; i < n_parties; i++)
buffer.serialize(prf_output[my_id - 1][i]);
#ifdef DEBUG
wire_id_t wire_ids[] = { (wire_id_t)in_wires[0]->get_id(), (wire_id_t)in_wires[1]->get_id() };
prf_output.print_prfs(g, wire_ids, my_id, n_parties);
#endif
}
void PRFOutputs::print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties)
{
for(int w=0; w<=1; w++) {
for (int b=0; b<=1; b++) {
for (int e=0; e<=1; e++) {
for(party_id_t j=1; j<=(size_t)n_parties; j++) {
printf("F_k^%d_{%lu,%u}(%d,%lu,%u) = ", my_id, in_wires[w], b, e, g, j);
Key k = *((Key*)(*this)[my_id-1][j-1].outputs[w][b][e]);
std::cout << k << std::endl;
}
}
}
}
}
void GarbledGate::print()
{
cout << "garbled gate " << id << endl;
for (int i = 0; i < 4; i++)
{
for (size_t j = 0; j < keys[i].size(); j++)
cout << keys[i][j] << " ";
cout << endl;
}
}

90
BMR/GarbledGate.h Normal file
View File

@@ -0,0 +1,90 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* GarbledGate.h
*
*/
#ifndef BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_
#define BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_
#include "Register.h"
#include "common.h"
struct PRFTuple {
Key outputs[2][2][2][1];
};
/*
* Total of n*(G*2*2*2*n+4) = n*(8*G*n+RESERVE_FOR_MSG_TYPE) = n*(PRFS_PER_PARTY+RESERVE_FOR_MSG_TYPE)
*
* For every party i={1...n}
* <Message_type> = saves us copying memory to another location - only for Party.
* For every gate g={1...G}
* For input wires x={left,right}
* For b={0,1} (0-key/1-key)
* For e={0,1} (extension)
* for every party j={1...n}
* F_{k^i_{x,b}}(e,j,g,)
*/
struct PRFOutputs {
#ifdef MAX_N_PARTIES
PRFTuple tuples[MAX_N_PARTIES][MAX_N_PARTIES];
PRFOutputs(int n_parties) { (void)n_parties; }
PRFTuple* operator[](int i) { return tuples[i]; }
#else
vector<PRFTuple> tuples;
int n_parties;
PRFOutputs(int n_parties) : n_parties(n_parties), tuples(n_parties * n_parties) {}
PRFTuple* operator[](int i) { return &tuples[i*n_parties]; }
#endif
void print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties);
};
class GarbledGate : public KeyTuple<4> {
/* will be allocated 4*n keys;
* (n keys for each of A,B,C,D entries):
* A1, A2, ... , An
* B1, B2, ... , Bn
* C1, C2, ... , Cn
* D1, D2, ... , Dn
*/
public:
KeyVector prf_inputs[2]; /*
* These are all possible inputs to the prf,
* This is not efficient in terms of storage but increase
* performance in terms of speed since no need to generate
* (new allocation plus filling) those inputs every time
* we compute the prf.
* Structure:
* - Total of G*n*2 inputs. (i.e. for every gate g, for every
* party j and for every extension e (from {0,1}).
* - We want to be able to set key once and use it to encrypt
* several inputs, so we want those inputs to be adjacent in
* memory in order to save time of building the block of
* inputs. So, the structure of the inputs is as follows:
* - First half of inputs (G*n inputs) are:
* - For every gate g=1,...,G we store the inputs:
* (0||g||1),(0||g||2),...,(0||../circuit_bmr/inc/GarbledGate.h:43:19: error: gate_id_t has not been declared
* g||n)
* - Second half of the inputs are:
* - For every gate g=1,...,G we store the inputs:
* (1||g||1),(1||g||2),...,(1||g||n)
*/
gate_id_t id;
GarbledGate(int n_parties) : KeyTuple<4>(n_parties), id(-1) {}
virtual ~GarbledGate();
void init_inputs(gate_id_t g, int n_parties);
char* input(int e, party_id_t j) { return (char*)&prf_inputs[e][j-1]; }
void compute_prfs_outputs(const Register** in_wires, int my_id, SendBuffer& buffer, gate_id_t g);
void print();
};
#endif /* BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_ */

View File

@@ -1,32 +1,33 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef __GATE_H__
#define __GATE_H__
#include <string>
#include <iostream>
#include <boost/thread/mutex.hpp>
#include <stdio.h>
#include "Key.h"
#define NO_LAYER (-1)
typedef struct Gate {
class Gate {
public:
wire_id_t _left;
wire_id_t _right;
wire_id_t _out;
uint8_t _func[4];
Function _func;
int _layer;
Gate() : _left(-1), _right(-1), _out(-1), _layer(-1) {}
inline void init(wire_id_t left, wire_id_t right, wire_id_t out,std::string func) {
_left = left;
_right = right;
_out = out;
_func[0] = (func[0]=='0')?0:1;//TODO change to func[0]-0x30
_func[1] = (func[1]=='0')?0:1;
_func[2] = (func[2]=='0')?0:1;
_func[3] = (func[3]=='0')?0:1;
_func = func;
_layer = NO_LAYER;
}
inline uint8_t func(uint8_t left, uint8_t right) {
@@ -34,9 +35,9 @@ typedef struct Gate {
}
void print(int id) {
printf ("gate %d: l:%d, r:%d, o:%d, func:%d%d%d%d\n", id, _left, _right, _out, _func[0], _func[1], _func[2], _func[3] );
printf ("gate %d: l:%lu, r:%lu, o:%lu, func:%d%d%d%d\n", id, _left, _right, _out, _func[0], _func[1], _func[2], _func[3] );
}
} Gate;
};
#endif

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Key.cpp
*
* Created on: Oct 27, 2015
* Author: marcel
*/
@@ -17,7 +17,7 @@ ostream& operator<<(ostream& o, const Key& key)
ostream& operator<<(ostream& o, const __m128i& x) {
o.fill('0');
o << hex;
o << hex << noshowbase;
for (int i = 0; i < 2; i++)
{
o.width(16);
@@ -27,26 +27,11 @@ ostream& operator<<(ostream& o, const __m128i& x) {
return o;
}
Key& Key::operator=(const Key& other) {
r= other.r;
// memcpy(&r, &other.r, sizeof(r));
return *this;
}
bool Key::operator==(const Key& other) {
__m128i neq = _mm_xor_si128(r, other.r);
return _mm_test_all_zeros(neq,neq);
}
Key& Key::operator-=(const Key& other) {
r ^= other.r;
return *this;
}
Key& Key::operator+=(const Key& other) {
r ^= other.r;
return *this;
}
//Key& Key::operator=(const Key& other) {
// r= other.r;
//// memcpy(&r, &other.r, sizeof(r));
// return *this;
//}
#else //__PRIME_FIELD__ is defined

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Key.h
*
* Created on: Oct 27, 2015
* Author: marcel
*/
#ifndef COMMON_INC_KEY_H_
@@ -13,7 +13,7 @@
#include <smmintrin.h>
#include <string.h>
#include "proto_utils.h"
#include "Tools/FlexBuffer.h"
using namespace std;
@@ -23,21 +23,63 @@ class Key {
public:
__m128i r;
Key() : r(_mm_set1_epi64x(0)) {}
Key(long long a) : r(_mm_set1_epi64x(a)) {}
Key(const Key& other) {r= other.r;}
Key() {}
Key(long long a) : r(_mm_cvtsi64_si128(a)) {}
Key(long long a, long long b) : r(_mm_set_epi64x(a, b)) {}
Key(__m128i r) : r(r) {}
// Key(const Key& other) {r= other.r;}
Key& operator=(const Key& other);
// Key& operator=(const Key& other);
bool operator==(const Key& other);
bool operator!=(const Key& other) { return !(*this == other); }
Key& operator-=(const Key& other);
Key& operator+=(const Key& other);
Key operator^(const Key& other) const { return r ^ other.r; }
Key operator^=(const Key& other) { r ^= other.r; return *this; }
void serialize(SendBuffer& output) const { output.serialize(r); }
void serialize_no_allocate(SendBuffer& output) const { output.serialize_no_allocate(r); }
bool get_signal() { return _mm_cvtsi128_si64(r) & 1; }
template <class T>
T get() const;
};
ostream& operator<<(ostream& o, const Key& key);
ostream& operator<<(ostream& o, const __m128i& x);
inline bool Key::operator==(const Key& other) {
__m128i neq = _mm_xor_si128(r, other.r);
return _mm_test_all_zeros(neq,neq);
}
inline Key& Key::operator-=(const Key& other) {
r ^= other.r;
return *this;
}
inline Key& Key::operator+=(const Key& other) {
r ^= other.r;
return *this;
}
template <>
inline unsigned long Key::get() const
{
return _mm_cvtsi128_si64(r);
}
template <>
inline __m128i Key::get() const
{
return r;
}
#else //__PRIME_FIELD__ is defined
const __uint128_t MODULUS= 0xffffffffffffffffffffffffffffff61;

File diff suppressed because it is too large Load Diff

View File

@@ -1,94 +1,253 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Party.h
*
* Created on: Feb 15, 2016
* Author: bush
*/
#ifndef PROTOCOL_PARTY_H_
#define PROTOCOL_PARTY_H_
#include "BooleanCircuit.h"
#include <mutex>
#include <boost/atomic.hpp>
#include "Node.h"
#include "Register.h"
#include "GarbledGate.h"
#include "network/Node.h"
#include "CommonParty.h"
#include "SpdzWire.h"
#include "AndJob.h"
#include "GC/Machine.h"
#include "GC/Program.h"
#include "GC/Processor.h"
#include "GC/Secret.h"
#include "Tools/Worker.h"
class BooleanCircuit;
#define SERVER_ID (0)
#define INPUT_KEYS_MSG_TYPE_SIZE (16) // so memory will by alligned
#ifndef N_EVAL_THREADS
// Default Intel desktop processor has 8 half cores.
// This is beneficial if only one AES available per full core.
#define N_EVAL_THREADS (8)
#endif
typedef struct {
unsigned long min=0;
unsigned long long acc=0;
} exec_props_t;
class Party : public NodeUpdatable {
class BaseParty : virtual public CommonParty {
public:
BaseParty(party_id_t id);
virtual ~BaseParty();
/* From NodeUpdatable class */
void NodeReady();
void NewMessage(int from, ReceivedMsg& msg);
void NodeAborted(struct sockaddr_in* from) { (void)from; }
void Start();
party_id_t get_id() { return _id; }
Key get_delta() { return delta; }
protected:
party_id_t _id;
// int _num_evaluation_threads;
struct timeval _start_online_net, _end_online_net;
vector<char> input_masks;
vector<char>::iterator input_mask;
Timer online_timer;
Key delta;
virtual void _compute_prfs_outputs(Key* keys) = 0;
void _send_prfs();
virtual void _process_external_received(char* externals,
party_id_t from) = 0;
virtual void _process_all_external_received(char* externals) = 0;
virtual void _process_input_keys(Key* keys, party_id_t from) = 0;
virtual void _process_all_input_keys(char* keys) = 0;
virtual void store_garbled_circuit(ReceivedMsg& msg) = 0;
virtual void _check_evaluate() = 0;
virtual void mask_output(ReceivedMsg& msg) = 0;
void done();
virtual void start_online_round() = 0;
virtual void receive_spdz_wires(ReceivedMsg& msg) = 0;
};
class Party : public BaseParty, public CommonCircuitParty {
friend class BooleanCircuit;
public:
Party(const char* netmap_file, const char* circuit_file, party_id_t id, const std::string input, int numthreads=5, int numtries=2);
virtual ~Party();
/* From NodeUpdatable class */
void NodeReady();
void NewMessage(int from, char* message, unsigned int len);
void NodeAborted(struct sockaddr_in* from) {}
void Start();
/* TEST methods */
private:
party_id_t _id;
Node* _node;
BooleanCircuit* _circuit;
gate_id_t _G;
party_id_t _N;
wire_id_t _W;
wire_id_t _OW;
wire_id_t _IO;
wire_id_t _IO;
std::string _all_input;
char* _input;
char* _external_values_msg;
unsigned int _external_values_msg_sz;
int _NUMTHREADS;
int _NUMTRIES;
vector<GarbledGate> _garbled_tbl;
vector<char> _input;
SendBuffer _external_values_msg;
int _num_externals_msg_received;
std::mutex _process_externals_mx;
char* _input_wire_keys_msg;
unsigned int _input_wire_keys_msg_sz;
SendBuffer _input_wire_keys_msg;
int _num_inputkeys_msg_received;
std::mutex _process_keys_mx;
std::mutex _sync_mx;
// int _num_evaluation_threads;
struct timeval* _start_online_net, *_end_online_net;
int _NUMTHREADS;
int _NUMTRIES;
#ifdef __PURE_SHE__
Key* _sqr_keys;
inline Key* _sqr_key(party_id_t i, wire_id_t w,int b) {return _sqr_keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
#endif
inline Key& _key(party_id_t i, wire_id_t w,int b) {return registers[w][b][i-1] ; }
inline KeyVector& _garbled_entry(gate_id_t g, int entry) {return _garbled_tbl[g-1][entry];}
vector<GarbledGate>::iterator get_garbled_tbl_end() { return _garbled_tbl.begin() + garbled_tbl_size; }
void resize_garbled_tbl() { _garbled_tbl.resize(_G, _N); garbled_tbl_size = _G; }
void _initialize_input();
void _generate_prf_inputs();
void _allocate_prf_outputs();
void _compute_prfs_outputs();
void _print_prfs();
void _send_prfs();
void _compute_prfs_outputs(Key* keys);
void _print_keys();
void _printf_garbled_table();
void _allocate_external_values();
void _generate_external_values_msg(char * masks);
void _process_external_received(char* externals, party_id_t from);
void _generate_external_values_msg();
void _process_external_received(char* externals,
party_id_t from);
void _process_all_external_received(char* externals);
inline void _allocate_input_wire_keys();
void _print_input_keys_msg();
void _print_keys_of_party(Key *keys, int id);
void _print_input_keys_checksum();
void _process_input_keys(Key* keys, party_id_t from);
void _process_all_input_keys(char* keys);
void _print_input_keys_msg();
void _print_keys_of_party(Key *keys, int id);
void _printf_garbled_table();
void store_garbled_circuit(ReceivedMsg& msg);
void load_garbled_circuit() {}
void _check_evaluate();
void receive_keys(Key* keys);
void receive_spdz_wires(ReceivedMsg& msg) { (void)msg; }
void start_online_round();
void mask_output(ReceivedMsg& msg);
int get_n_inputs();
};
class ProgramParty : public BaseParty
{
friend class PRFRegister;
friend class EvalRegister;
friend class Register;
char* prf_output;
Key* keys_for_prf;
deque<octetStream> spdz_wires[SPDZ_OP_N];
size_t spdz_storage;
size_t garbled_storage;
vector<size_t> spdz_counters;
Worker<AndJob> eval_threads[N_EVAL_THREADS];
AndJob and_jobs[N_EVAL_THREADS];
ReceivedMsgStore output_masks_store;
GC::Memory< GC::Secret<EvalRegister>::DynamicType > dynamic_memory;
GC::Machine< GC::Secret<EvalRegister> > machine;
GC::Processor<GC::Secret<EvalRegister> > processor;
GC::Program<GC::Secret<EvalRegister> > program;
GC::Machine< GC::Secret<PRFRegister> > prf_machine;
GC::Processor<GC::Secret<PRFRegister> > prf_processor;
void _compute_prfs_outputs(Key* keys);
void _process_external_received(char* externals,
party_id_t from) { (void)externals; (void)from; }
void _process_all_external_received(char* externals) { (void)externals; }
void _process_input_keys(Key* keys, party_id_t from)
{ (void)keys; (void)from; }
void _process_all_input_keys(char* keys) { (void)keys; }
void store_garbled_circuit(ReceivedMsg& msg);
void load_garbled_circuit();
void _check_evaluate();
void receive_keys(Register& reg);
void receive_all_keys(Register& reg, bool external);
void receive_spdz_wires(ReceivedMsg& msg);
void start_online_round();
void mask_output(ReceivedMsg& msg) { output_masks_store.push(msg); }
public:
static ProgramParty* singleton;
ReceivedMsg garbled_circuit;
ReceivedMsgStore garbled_circuits;
ReceivedMsg output_masks;
MAC_Check<gf2n>* MC;
Player* P;
Names N;
int threshold;
static ProgramParty& s();
ProgramParty(int argc, char** argv);
~ProgramParty();
void reset();
void input_value(party_id_t from, char value);
void get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire);
void store_wire(const Register& reg);
void load_wire(Register& reg);
};
inline ProgramParty& ProgramParty::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
#endif /* PROTOCOL_PARTY_H_ */

1094
BMR/Register.cpp Normal file

File diff suppressed because it is too large Load Diff

394
BMR/Register.h Normal file
View File

@@ -0,0 +1,394 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Register.h
*
*/
#ifndef PROTOCOL_SRC_REGISTER_H_
#define PROTOCOL_SRC_REGISTER_H_
#include <vector>
#include <utility>
#include <stdint.h>
using namespace std;
#include "Key.h"
#include "Wire.h"
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "Math/gf2n.h"
#include "Tools/FlexBuffer.h"
#ifndef FREE_XOR
#warning not using free XOR has not been tested in a while
#endif
typedef unsigned int party_id_t;
//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
#ifdef N_PARTIES
#define MAX_N_PARTIES N_PARTIES
#endif
#ifdef MAX_N_PARTIES
class BaseKeyVector
{
Key keys[MAX_N_PARTIES];
public:
Key& operator[](int i) { return keys[i]; }
const Key& operator[](int i) const { return keys[i]; }
Key* data() { return keys; }
const Key* data() const { return keys; }
#ifdef N_PARTIES
BaseKeyVector(int n_parties = 0) { (void)n_parties; avx_memzero(keys, sizeof(keys)); }
size_t size() const { return N_PARTIES; }
void resize(int size) { (void)size; }
#else
BaseKeyVector(int n_parties = 0) : n_parties(n_parties) { memset(keys, 0, sizeof(keys)); }
size_t size() const { return n_parties; }
void resize(int size) { n_parties = size; }
private:
int n_parties;
#endif
};
#else
typedef vector<Key> BaseKeyVector;
#endif
class KeyVector : public BaseKeyVector
{
public:
KeyVector(int size = 0) : BaseKeyVector(size) {}
size_t byte_size() const { return size() * sizeof(Key); }
void operator=(const KeyVector& source);
KeyVector operator^(const KeyVector& other) const;
template <class T>
void serialize_no_allocate(T& output) const { output.serialize_no_allocate(data(), byte_size()); }
template <class T>
void serialize(T& output) const { output.serialize(data(), byte_size()); }
void unserialize(ReceivedMsg& source, int n_parties);
friend ostream& operator<<(ostream& os, const KeyVector& kv);
};
class GarbledGate;
class CommonParty;
template <int I>
class KeyTuple {
friend class Register;
static long counter;
protected:
KeyVector keys[I];
int part_size() { return keys[0].size() * sizeof(Key); }
public:
KeyTuple() {}
KeyTuple(int n_parties) { init(n_parties); }
void init(int n_parties);
int byte_size() { return I * keys[0].byte_size(); }
KeyVector& operator[](int i) { return keys[i]; }
const KeyVector& operator[](int i) const { return keys[i]; }
KeyTuple<I> operator^(const KeyTuple<I>& other) const;
void copy_to(Key* dest);
void unserialize(ReceivedMsg& source, int n_parties);
void copy_from(Key* source, int n_parties, int except);
template <class T>
void serialize_no_allocate(T& output) const;
template <class T>
void serialize(T& output) const;
template <class T>
void serialize(T& output, party_id_t pid) const;
void unserialize(vector<char>& output);
template <class T>
void unserialize(T& output);
void randomize();
void reset();
void print(int wire_id) const;
void print(int wire_id, party_id_t pid);
};
namespace GC
{
class AuthValue;
class Mask;
class SpdzShare;
template <class T>
class Secret;
}
class Register {
protected:
static int counter;
KeyVector garbled_entry;
char external;
public:
char mask;
KeyTuple<2> keys; /* Additional data stored per per party per wire: */
/* Total of n*W*2 keys
* For every w={0,...,W}
* For every b={0,1}
* For every i={1...n}
* k^i_{w,b}
* This is helpful that the keys for specific w and b are adjacent
* for pipelining matters.
*/
Register(int n_parties);
void init(int n_parties);
void init(int rfd, int n_parties);
KeyVector& operator[](int i) { return keys[i]; }
const Key& key(party_id_t i, int b) const { return keys[b][i-1]; }
Key& key(party_id_t i, int b) { return keys[b][i-1]; }
void set_eval_keys();
void set_eval_keys(Key* keys, int n_parties, int except);
const Key& external_key(party_id_t i) const { return garbled_entry[i-1]; }
void set_external_key(party_id_t i, const Key& key);
void reset_non_external_key(party_id_t i);
void set_external(char ext);
char get_external() const { check_external(); return external; }
char get_external_no_check() const { return external; }
void set_mask(char mask);
int get_mask() const { check_mask(); return mask; }
char get_mask_no_check() { return mask; }
char get_output() { check_external(); check_mask(); return mask ^ external; }
char get_output_no_check() { return mask ^ external; }
const KeyVector& get_garbled_entry() const { return garbled_entry; }
void print_input(int id);
void print() const { keys.print(get_id()); }
void check_external() const;
void check_mask() const;
void check_signal_key(int my_id, KeyVector& garbled_entry);
void eval(const Register& left, const Register& right, GarbledGate& gate,
party_id_t my_id, char* prf_output, int, int, int);
void garble(const Register& left, const Register& right, Function func,
Gate* gate, int g, vector<ReceivedMsg>& prf_outputs, SendBuffer& buffer);
size_t get_id() const { return (size_t)this; }
template <class T>
void set_trace();
};
// this is to fake a "cout" that does nothing
class BlackHole
{
public:
template <typename T>
BlackHole& operator<<(T) { return *this; }
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
};
inline BlackHole& endl(BlackHole& b) { return b; }
inline BlackHole& flush(BlackHole& b) { return b; }
class ProgramRegister : public Register
{
public:
typedef BlackHole out_type;
static const BlackHole out;
static Register new_reg();
static Register tmp_reg() { return new_reg(); }
static Register and_reg() { return new_reg(); }
static void check(const int128& value, word share, int128 mac)
{ (void)value; (void)share; (void)mac; }
static void get_dyn_mask(GC::Mask& mask, int length, int mac_length)
{ (void)mask; (void)length; (void)mac_length; }
template <class T>
static void store_clear_in_dynamic(T& mem, const vector<GC::ClearWriteAccess>& accesses)
{ (void)mem; (void)accesses; }
static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share,
word masked, int128 masked_mac)
{ (void)dest; (void)mask_share; (void)mac_mask_share; (void)masked; (void)masked_mac; }
template<class T>
static void store(GC::Memory<GC::SpdzShare>& dest,
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
template<class T>
static void load(vector<GC::ReadAccess<T> >& accesses,
const GC::Memory<GC::SpdzShare>& source) { (void)accesses; (void)source; }
template <class T>
static void andrs(T& processor, const vector<int>& args) { processor.andrs(args); }
void input(party_id_t from, char value = -1) { (void)from; (void)value; }
void public_input(bool value) { (void)value; }
void random() {}
char get_output() { return 0; }
};
class FirstRoundRegister : public ProgramRegister
{
public:
};
class SecondRoundRegister : public ProgramRegister
{
public:
};
class PRFRegister : public FirstRoundRegister
{
public:
static string name() { return "PRF"; }
template<class T>
static void load(vector<GC::ReadAccess<T> >& accesses,
const GC::Memory<GC::SpdzShare>& source);
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char input = -1);
void public_input(bool value);
void random();
void output();
};
class EvalRegister : public SecondRoundRegister
{
public:
static string name() { return "Evaluation"; }
typedef ostream& out_type;
static ostream& out;
static void check(const int128& value, word share, int128 mac);
static void get_dyn_mask(GC::Mask& mask, int length, int mac_length);
static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share,
word masked, int128 masked_mac);
template<class T>
static void store(GC::Memory<GC::SpdzShare>& dest,
const vector<GC::WriteAccess<T> >& accesses);
template<class T>
static void load(vector<GC::ReadAccess<T> >& accesses,
const GC::Memory<GC::SpdzShare>& source);
template <class T>
static void andrs(T& processor, const vector<int>& args);
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
void XOR(const Register& left, const Register& right);
void public_input(bool value);
void random();
void output();
unsigned long long get_output() { return Register::get_output(); }
template <class T>
static void store_clear_in_dynamic(GC::Memory<T>& mem,
const vector<GC::ClearWriteAccess>& accesses);
void input(party_id_t from, char value = -1);
};
class GarbleRegister : public SecondRoundRegister
{
public:
static string name() { return "Garbling"; }
template<class T>
static void load(vector<GC::ReadAccess<T> >& accesses,
const GC::Memory<GC::SpdzShare>& source);
void op(const Register& left, const Register& right, Function func);
void XOR(const Register& left, const Register& right);
void public_input(bool value);
void random();
void output() {}
};
class RandomRegister : public FirstRoundRegister
{
public:
static string name() { return "Randomization"; }
template<class T>
static void store(GC::Memory<GC::SpdzShare>& dest,
const vector<GC::WriteAccess<T> >& accesses);
template<class T>
static void load(vector<GC::ReadAccess<T> >& accesses,
const GC::Memory<GC::SpdzShare>& source);
void randomize();
void op(const Register& left, const Register& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char value = -1);
void public_input(bool value);
void random();
void output();
};
inline Register::Register(int n_parties) :
garbled_entry(n_parties), external(NO_SIGNAL),
mask(NO_SIGNAL), keys(n_parties)
{
}
inline void KeyVector::unserialize(ReceivedMsg& source, int n_parties)
{
resize(n_parties);
source.unserialize(data(), size() * sizeof(Key));
}
template <int I>
inline void KeyTuple<I>::init(int n_parties) {
for (int i = 0; i < I; i++)
keys[i].resize(n_parties);
}
template<int I>
inline void KeyTuple<I>::reset()
{
for (int i = 0; i < I; i++)
for (size_t j = 0; j < keys[i].size(); j++)
keys[i][j] = 0;
}
template <int I>
inline void KeyTuple<I>::unserialize(ReceivedMsg& source, int n_parties) {
for (int b = 0; b < I; b++)
keys[b].unserialize(source, n_parties);
}
template<int I> template <class T>
void KeyTuple<I>::serialize_no_allocate(T& output) const {
for (int i = 0; i < I; i++)
keys[i].serialize_no_allocate(output);
}
template<int I> template <class T>
void KeyTuple<I>::serialize(T& output) const {
for (int i = 0; i < I; i++)
for (unsigned int j = 0; j < keys[i].size(); j++)
keys[i][j].serialize(output);
}
template<int I> template <class T>
void KeyTuple<I>::serialize(T& output, party_id_t pid) const {
for (int i = 0; i < I; i++)
keys[i][pid - 1].serialize(output);
}
template<int I> template <class T>
void KeyTuple<I>::unserialize(T& output) {
for (int i = 0; i < I; i++)
for (unsigned int j = 0; j < keys[i].size(); j++)
output.unserialize(keys[i][j]);
}
#endif /* PROTOCOL_SRC_REGISTER_H_ */

20
BMR/Register_inline.h Normal file
View File

@@ -0,0 +1,20 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Register_inline.h
*
*/
#ifndef BMR_REGISTER_INLINE_H_
#define BMR_REGISTER_INLINE_H_
#include "CommonParty.h"
#include "Party.h"
inline Register ProgramRegister::new_reg()
{
return Register(CommonParty::s().get_n_parties());
}
#endif /* BMR_REGISTER_INLINE_H_ */

26
BMR/SpdzWire.cpp Normal file
View File

@@ -0,0 +1,26 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* SpdzWire.cpp
*
*/
#include "SpdzWire.h"
SpdzWire::SpdzWire()
{
}
void SpdzWire::pack(octetStream& os) const
{
mask.pack(os);
os.serialize(my_keys);
}
void SpdzWire::unpack(octetStream& os, size_t wanted_size)
{
(void)wanted_size;
mask.unpack(os);
os.unserialize(my_keys);
}

25
BMR/SpdzWire.h Normal file
View File

@@ -0,0 +1,25 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* SpdzWire.h
*
*/
#ifndef BMR_SPDZWIRE_H_
#define BMR_SPDZWIRE_H_
#include "Math/Share.h"
#include "Key.h"
class SpdzWire
{
public:
Share<gf2n> mask;
Key my_keys[2];
SpdzWire();
void pack(octetStream& os) const;
void unpack(octetStream& os, size_t wanted_size);
};
#endif /* BMR_SPDZWIRE_H_ */

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* TrustedParty.cpp
*
* Created on: Feb 15, 2016
* Author: bush
*/
#include "TrustedParty.h"
@@ -16,9 +16,25 @@
#include "proto_utils.h"
#include "msg_types.h"
#include "SpdzWire.h"
#include "Auth/fake-stuff.h"
TrustedProgramParty* TrustedProgramParty::singleton = 0;
BaseTrustedParty::BaseTrustedParty()
{
#ifdef __PURE_SHE__
init_modulos();
init_temp_mpz_t(_temp_mpz);
std::cout << "_temp_mpz: " << _temp_mpz << std::endl;
#endif
_num_prf_received = 0;
_received_gc_received = 0;
n_received = 0;
randomfd = open("/dev/urandom", O_RDONLY);
}
TrustedParty::TrustedParty(const char* netmap_file, // required to init Node
const char* circuit_file // required to init BooleanCircuit
@@ -26,126 +42,231 @@ TrustedParty::TrustedParty(const char* netmap_file, // required to init Node
{
_circuit = new BooleanCircuit( circuit_file );
_G = _circuit->NumGates();
#ifndef N_PARTIES
_N = _circuit->NumParties();
#endif
_W = _circuit->NumWires();
_OW = _circuit->NumOutWires();
#ifdef __PURE_SHE__
init_modulos();
init_temp_mpz_t(_temp_mpz);
std::cout << "_temp_mpz: " << _temp_mpz << std::endl;
#endif
_allocate_prf_outputs();
_allocate_garbled_table();
_num_prf_received = 0;
_received_gc_received = 0;
if (0 == strcmp(netmap_file, LOOPBACK_STR)) {
_node = new Node( NULL, 0, this , _N+1);
} else {
_node = new Node( netmap_file, 0, this );
}
reset();
garbled_tbl_size = _G;
init(netmap_file, 0, _N);
}
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :
machine(dynamic_memory), processor(machine),
random_machine(dynamic_memory), random_processor(random_machine)
{
if (argc < 2)
{
cerr << "Usage: " << argv[0] << " <program> [netmap]" << endl;
exit(1);
}
ifstream file((string("Programs/Bytecode/") + argv[1] + "-0.bc").c_str());
program.parse(file);
processor.reset(program);
machine.reset(program);
random_processor.reset(program.cast< GC::Secret<RandomRegister> >());
random_machine.reset(program.cast< GC::Secret<RandomRegister> >());
if (singleton)
throw runtime_error("there can only be one");
singleton = this;
if (argc == 3)
init(argv[2], 0);
else
init("LOOPBACK", 0);
#ifdef FREE_XOR
deltas.resize(_N);
for (size_t i = 0; i < _N; i++)
{
deltas[i] = prng.get_doubleword();
#ifdef DEBUG
deltas[i] = Key(i + 1, 0);
#endif
#ifdef KEY_SIGNAL
if (deltas[i].get_signal() == 0)
deltas[i] ^= Key(1);
#endif
cout << "Delta " << i << ": " << deltas[i] << endl;
}
#endif
}
TrustedProgramParty::~TrustedProgramParty()
{
cout << "Random timer: " << random_timer.elapsed() << endl;
}
TrustedParty::~TrustedParty() {
// TODO Auto-generated destructor stub
}
void TrustedParty::NodeReady()
void BaseTrustedParty::NodeReady()
{
#ifdef DEBUG_STEPS
printf("\n\nNode ready \n\n");
sleep(1);
#endif
//sleep(1);
prepare_randomness();
send_randomness();
prf_outputs.resize(get_n_parties());
}
_generate_masks();
void BaseTrustedParty::prepare_randomness()
{
msg_keys.resize(_N);
for (size_t i = 0; i < msg_keys.size(); i++)
{
msg_keys[i].clear();
fill_message_type(msg_keys[i], TYPE_KEYS);
msg_keys[i].resize(MSG_KEYS_HEADER_SZ);
}
unsigned int number_of_keys = 2* _W * _N;
unsigned int size_of_keys = number_of_keys*sizeof(Key);
unsigned int msg_keys_size = size_of_keys + MSG_KEYS_HEADER_SZ;
Key* all_keys = new Key[number_of_keys];
memset(all_keys, 0, size_of_keys);
#ifdef __PURE_SHE__
_circuit->_sqr_keys = new Key[number_of_keys];
memset(_circuit->_sqr_keys, 0, size_of_keys);
#endif
for(party_id_t pid=1; pid<=_N; pid++)
{
/* generating and sending keys */
char* msg_keys = new char[msg_keys_size];
memset(msg_keys, 0, msg_keys_size);
fill_message_type(msg_keys, TYPE_KEYS);
Key* party_keys = (Key*)(msg_keys + MSG_KEYS_HEADER_SZ);
#ifdef __PURE_SHE__
_fill_keys_for_party(_circuit->_sqr_keys, party_keys, pid);
#else
_fill_keys_for_party(party_keys, pid);
done_filling = _fill_keys();
#endif
// printf("keys for party %u\n", pid);
// phex(party_keys, size_of_keys);
_merge_keys(all_keys, party_keys);
}
void BaseTrustedParty::send_randomness()
{
for(party_id_t pid=1; pid<=_N; pid++)
{
// printf("all keys\n");
// phex(all_keys, size_of_keys);
_node->Send(pid, msg_keys, msg_keys_size);
/* sending masks for input wires */
party_t party = _circuit->Party(pid);
char* msg_input_masks = new char[sizeof(MSG_TYPE) + party.n_wires];
fill_message_type(msg_input_masks, TYPE_MASK_INPUTS);
memcpy(msg_input_masks+sizeof(MSG_TYPE), _circuit->Masks()+party.wires, party.n_wires);
_node->Send(pid, msg_input_masks, sizeof(MSG_TYPE) + party.n_wires);
// TODO: test only
// printf("input masks for party %d\n", pid);
// phex(msg_input_masks + sizeof(MSG_TYPE) , party.n_wires);
_node->Send(pid, msg_keys[pid - 1]);
// printf("msg keys\n");
// phex(msg_keys, msg_keys_size);
send_input_masks(pid);
}
_circuit->Keys(all_keys);
send_output_masks();
}
void TrustedParty::send_input_masks(party_id_t pid)
{
prepare_input_regs(pid);
/* sending masks for input wires */
msg_input_masks.resize(get_n_parties());
SendBuffer& buffer = msg_input_masks[pid-1];
buffer.clear();
fill_message_type(buffer, TYPE_MASK_INPUTS);
for (auto input_regs = input_regs_queue.begin();
input_regs != input_regs_queue.end(); input_regs++)
{
int n_wires = (*input_regs)[pid].size();
#ifdef DEBUG_ROUNDS
cout << dec << n_wires << " inputs from " << pid << endl;
#endif
for (int i = 0; i < n_wires; i++)
buffer.push_back(registers[(*input_regs)[pid][i]].get_mask());
#ifdef DEBUG2
printf("input masks for party %d\n", pid);
phex(buffer);
#endif
#ifdef DEBUG_VALUES
printf("input masks for party %d:\t", pid);
print_masks((*input_regs)[pid]);
cout << "on registers:" << endl;
print_indices((*input_regs)[pid]);
#endif
}
#ifdef DEBUG_ROUNDS
cout << "sending " << dec << buffer.size() - 4 << " input masks" << endl;
#endif
_node->Send(pid, buffer);
}
void TrustedParty::send_output_masks()
{
prepare_output_regs();
/* output wires' masks are the same for all players */
/* sending masks for output wires */
char* msg_output_masks = new char[sizeof(MSG_TYPE) + _OW];
fill_message_type(msg_output_masks, TYPE_MASK_OUTPUT);
memcpy(msg_output_masks + sizeof(MSG_TYPE), _circuit->Masks()+_circuit->OutWiresStart(), _OW);
_node->Broadcast(msg_output_masks, sizeof(MSG_TYPE) + _OW);
// TODO: test only
// printf("output masks\n");
// phex(msg_output_masks+ sizeof(MSG_TYPE) , _OW);
int _OW = output_regs.size();
SendBuffer& msg_output_masks = get_buffer(TYPE_MASK_OUTPUT);
for (int i = 0; i < _OW; i++)
msg_output_masks.push_back(get_reg(output_regs[i]).get_mask());
_node->Broadcast(msg_output_masks);
#ifdef DEBUG2
printf("output masks\n");
phex(msg_output_masks);
#endif
#ifdef DEBUG_VALUES
printf("output masks:\t\t\t");
print_masks(output_regs);
cout << "on registers:" << endl;
print_indices(output_regs);
#endif
}
void TrustedParty::NewMessage(int from, char* message, unsigned int len)
void TrustedProgramParty::send_output_masks()
{
#ifdef DEBUG_OUTPUT_MASKS
cout << "sending " << msg_output_masks.size() - 4 << " output masks" << endl;
#endif
_node->Broadcast(msg_output_masks);
}
void BaseTrustedParty::NewMessage(int from, ReceivedMsg& msg)
{
char* message = msg.data();
int len = msg.size();
MSG_TYPE message_type;
memcpy(&message_type, message, sizeof(MSG_TYPE));
msg.unserialize(message_type);
unique_lock<mutex> locker(global_lock);
switch(message_type) {
case TYPE_PRF_OUTPUTS:
{
#ifdef DEBUG
cout << "TYPE_PRF_OUTPUTS" << endl;
#endif
_print_mx.lock();
// printf("got message of len %u from %d\n", len, from);
#ifdef DEBUG2
printf("got message of len %u from %d\n", len, from);
phex(message, len);
cout << "garbled table size " << get_garbled_tbl_size() << endl;
#endif
#ifdef DEBUG_STEPS
printf("\n Got prfs from %d\n",from);
#endif
char* party_prfs = _circuit->Prfs() + (PRFS_PER_PARTY(_G, _N)*sizeof(Key)) *(from-1) ;
memcpy(party_prfs, message + sizeof(MSG_TYPE), PRFS_PER_PARTY(_G, _N)*sizeof(Key));
// phex(party_prfs, PRFS_PER_PARTY(G, N)*sizeof(Key));
prf_outputs[from-1] = msg;
_print_mx.unlock();
_num_prf_received ++;
if(_num_prf_received == _N) {
if(++_num_prf_received == _N) {
_num_prf_received = 0;
_compute_send_garbled_circuit();
}
break;
}
case TYPE_RECEIVED_GC:
{
_received_gc_received++;
if(_received_gc_received == _N) {
_launch_online();
if(++_received_gc_received == _N) {
_received_gc_received = 0;
if (done_filling)
_launch_online();
else
NodeReady();
}
break;
}
case TYPE_NEXT:
if (++n_received == _N)
{
n_received = 0;
send_randomness();
}
break;
case TYPE_DONE:
if (++n_received == _N)
_node->Stop();
break;
default:
{
_print_mx.lock();
@@ -155,6 +276,7 @@ void TrustedParty::NewMessage(int from, char* message, unsigned int len)
phex(message, len);
_print_mx.unlock();
}
break;
}
}
@@ -163,155 +285,57 @@ void TrustedParty::_launch_online()
{
printf("press to launch online\n");
getchar();
char* launch_msg = new char[sizeof(MSG_TYPE)];
fill_message_type(launch_msg, TYPE_LAUNCH_ONLINE);
_node->Broadcast(launch_msg, sizeof(MSG_TYPE));
_node->Broadcast(get_buffer(TYPE_LAUNCH_ONLINE));
printf("launched\n");
}
void TrustedParty::_allocate_garbled_table()
void TrustedProgramParty::_launch_online()
{
unsigned int garbled_table_sz = _G*4*_N*sizeof(Key)+RESERVE_FOR_MSG_TYPE;
_circuit->_garbled_tbl = (Key*)malloc(garbled_table_sz);
memset(_circuit->_garbled_tbl,0,garbled_table_sz);
_node->Broadcast(get_buffer(TYPE_LAUNCH_ONLINE));
}
void TrustedParty::_compute_send_garbled_circuit()
void TrustedParty::garble()
{
Key* tbl_start = _circuit->_garbled_tbl+(RESERVE_FOR_MSG_TYPE/sizeof(Key));
unsigned int prfs_per_party_sz = PRFS_PER_PARTY(_G,_N)*sizeof(Key);
std::vector<Gate>* gates = &_circuit->_gates;
char* masks = _circuit->_masks;
for(gate_id_t g=1; g<=_G; g++) {
// std::cout << "garbling gate " << g << std::endl ;
Key* gg_A = tbl_start + GARBLED_GATE_SIZE(_N)*(g-1);
Key* gg_B = gg_A + _N;
Key* gg_C = gg_A + 2*_N;
Key* gg_D = gg_A + 3*_N;
unsigned int prfs_left_offset = (g-1)*8*_N*sizeof(Key);
unsigned int prfs_right_offset = prfs_left_offset + 4*_N*sizeof(Key);
for(party_id_t i=1; i<=_N; i++) {
// std::cout << "adding prfs of party " << i << std::endl ;
char* party_prfs = _circuit->Prfs() + (i-1)*prfs_per_party_sz;
Key left_i_j, right_i_j;
for (party_id_t j=1; j<=_N; j++) {
//A
// std::cout << "A" << std::endl;
left_i_j = *(Key*)(party_prfs+prfs_left_offset + (j-1)*sizeof(Key));
right_i_j = *(Key*)(party_prfs+prfs_right_offset + (j-1)*sizeof(Key));
// cout << *(gg_A+j-1) << std::endl;
// cout << left_i_j << std::endl;
// cout << right_i_j << std::endl;
*(gg_A+j-1) += left_i_j; //left wire of party i in part j
*(gg_A+j-1) += right_i_j; //right wire of party i in part j
// cout << *(gg_A+j-1) << std::endl<< std::endl;
//B
// std::cout << "B" << std::endl;
left_i_j = *(Key*)(party_prfs+prfs_left_offset + _N*sizeof(Key) + (j-1)*sizeof(Key));
right_i_j = *(Key*)(party_prfs+prfs_right_offset + 2*_N*sizeof(Key) + (j-1)*sizeof(Key));
// cout << *(gg_B+j-1) << std::endl;
// cout << left_i_j << std::endl;
// cout << right_i_j << std::endl;
*(gg_B+j-1) += left_i_j; //left wire of party i in part j
*(gg_B+j-1) += right_i_j; //right wire of party i in part j
// cout << *(gg_B+j-1) << std::endl<< std::endl;
//C
// std::cout << "C" << std::endl;
left_i_j = *(Key*)(party_prfs+prfs_left_offset + 2*_N*sizeof(Key) + (j-1)*sizeof(Key));
right_i_j = *(Key*)(party_prfs+prfs_right_offset + _N*sizeof(Key) + (j-1)*sizeof(Key));
// cout << *(gg_C+j-1) << std::endl;
// cout << left_i_j << std::endl;
// cout << right_i_j << std::endl;
*(gg_C+j-1) += left_i_j; //left wire of party i in part j
*(gg_C+j-1) += right_i_j; //right wire of party i in part j
// cout << *(gg_C+j-1) << std::endl<< std::endl;
//D
// std::cout << "D" << std::endl;
left_i_j = *(Key*)(party_prfs+prfs_left_offset + 3*_N*sizeof(Key) + (j-1)*sizeof(Key));
right_i_j = *(Key*)(party_prfs+prfs_right_offset + 3*_N*sizeof(Key) + (j-1)*sizeof(Key));
// cout << *(gg_D+j-1) << std::endl;
// cout << left_i_j << std::endl;
// cout << right_i_j << std::endl;
*(gg_D+j-1) += left_i_j; //left wire of party i in part j
*(gg_D+j-1) += right_i_j; //right wire of party i in part j
// cout << *(gg_D+j-1) << std::endl<< std::endl;
}
}
//Adding the hidden keys
Gate* gate = &gates->at(g);
char maskl = masks[gate->_left];
char maskr = masks[gate->_right];
char masko = masks[gate->_out];
// printf("\ngate %u, leftwire=%u, rightwire=%u, outwire=%u: func=%d%d%d%d, msk_l=%d, msk_r=%d, msk_o=%d\n"
// , g,gate->_left, gate->_right, gate->_out
// ,gate->_func[0],gate->_func[1],gate->_func[2],gate->_func[3], maskl, maskr, masko);
// printf("\n");
// printf("maskl=%d, maskr=%d, masko=%d\n",maskl,maskr,masko);
// printf("gate func = %d%d%d%d\n",gate->_func[0],gate->_func[1],gate->_func[2],gate->_func[3]);
bool xa = gate->_func[2*maskl+maskr] != masko;
bool xb = gate->_func[2*maskl+(1-maskr)] != masko;
bool xc = gate->_func[2*(1-maskl)+maskr] != masko;
bool xd = gate->_func[2*(1-maskl)+(1-maskr)] != masko;
// printf("xa=%d, xb=%d, xc=%d, xd=%d\n", xa,xb,xc,xd);
// these are the 0-keys
#ifdef __PURE_SHE__
Key* outwire_start = _circuit->_sqr_keys + gate->_out*2*_N;
#else
Key* outwire_start = _circuit->_keys + gate->_out*2*_N;
#ifdef DEBUG
std::cout << "garbling gate " << g << std::endl ;
#endif
Key* keyxa = outwire_start + (xa?_N:0);
Key* keyxb = outwire_start + (xb?_N:0);
Key* keyxc = outwire_start + (xc?_N:0);
Key* keyxd = outwire_start + (xd?_N:0);
for(party_id_t i=1; i<=_N; i++) {
// std::cout << "adding to A = " << keyxa[i-1] << std::endl;
// std::cout << "adding to B = " << keyxb[i-1] << std::endl;
// std::cout << "adding to C = " << keyxc[i-1] << std::endl;
// std::cout << "adding to D = " << keyxd[i-1] << std::endl;
*(gg_A+i-1) += keyxa[i-1];
*(gg_B+i-1) += keyxb[i-1];
*(gg_C+i-1) += keyxc[i-1];
*(gg_D+i-1) += keyxd[i-1];
}
Gate& gate = _circuit->_gates[g];
registers[gate._out].garble(registers[gate._left], registers[gate._right],
gate._func, &gate, g, prf_outputs, buffers[TYPE_GARBLED_CIRCUIT]);
}
//sending to parties:
fill_message_type(((char*)tbl_start)-4, TYPE_GARBLED_CIRCUIT );
_node->Broadcast( ((char*)tbl_start)-4 , _G*4*_N*sizeof(Key)+sizeof(MSG_TYPE));
}
void TrustedParty::Start()
void BaseTrustedParty::_compute_send_garbled_circuit()
{
SendBuffer& buffer = get_buffer(TYPE_GARBLED_CIRCUIT );
buffer.allocate(get_garbled_tbl_size() * 4 * get_n_parties() * sizeof(Key));
garble();
//sending to parties:
#ifdef DEBUG
cout << "sending garbled circuit" << endl;
#endif
#ifdef DEBUG2
phex(buffer);
#endif
_node->Broadcast(buffer);
//prepare_randomness();
}
void BaseTrustedParty::Start()
{
_node->Start();
}
/* keys - a 2*W*n keys buffer to be filled only in the places belong to party pid */
void TrustedParty::_fill_keys_for_party(Key* keys, party_id_t pid)
bool TrustedParty::_fill_keys()
{
int nullfd = open("/dev/urandom", O_RDONLY);
resize_registers();
for (wire_id_t w=0; w<_W; w++) {
read(nullfd, (char*)(keys+pid-1), sizeof(Key));
read(nullfd, (char*)(keys+_N+pid-1), sizeof(Key));
#ifdef __PRIME_FIELD__
keys[pid-1].adjust();
keys[pid+_N-1].adjust();
#endif
keys = keys + 2*_N;
registers[w].init(randomfd, _N);
add_keys(registers[w]);
}
close(nullfd);
return true;
}
#ifdef __PURE_SHE__
@@ -335,56 +359,98 @@ void TrustedParty::_fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid
}
#endif
/* Merge the two Key buffers into the dest buffer */
void TrustedParty::_merge_keys(Key* dest, Key* src)
{
for (wire_id_t w=0; w<_W; w++) {
for (int i=0; i<2; i++) {
for (party_id_t p=0; p<_N; p++) {
int offset = w*2*_N+i*_N+p;
Key kk;
kk = *(src + offset);
dest[offset] += kk;
}
}
}
}
void TrustedParty::_generate_masks ()
{
char* masks = new char[_W];
fill_random(masks, _W);
//need to convert from total random to 0/1
for (unsigned int i=0 ; i<_W; i++)
{
//because it is a SIGNED char ~half the chars whould be negetives
masks[i] = masks[i]>0 ? 1 : 0;
}
_circuit->Masks(masks);
// printf("masks\n");
// phex(_circuit->Masks(), _W);
}
void TrustedParty::_allocate_prf_outputs()
{
unsigned int prf_output_size = (PRFS_PER_PARTY(_G, _N)*sizeof(Key)) *_N ;
void* prf_outputs = malloc(prf_output_size);
memset(prf_outputs,0,prf_output_size);
_circuit->Prfs((char*)prf_outputs);
}
void TrustedParty::_print_keys()
{
Key* key_idx = _circuit->_key(1,0,0);
for (wire_id_t w=0; w<_W; w++) {
for(int b=0; b<=1; b++) {
for (party_id_t i=1; i<=_N; i++) {
printf("k^%d_{%u,%d}: ",i,w,b);
std::cout << *key_idx << std::endl;
key_idx++;
}
registers[w].keys.print(w);
}
}
void TrustedProgramParty::NodeReady()
{
#ifdef FREE_XOR
for (int i = 0; i < get_n_parties(); i++)
{
SendBuffer& buffer = get_buffer(TYPE_DELTA);
buffer.serialize(deltas[i]);
_node->Send(i + 1, buffer);
}
#endif
this->BaseTrustedParty::NodeReady();
}
bool TrustedProgramParty::_fill_keys()
{
for (int i = 0; i < SPDZ_OP_N; i++)
{
spdz_wires[i].clear();
spdz_wires[i].resize(get_n_parties());
}
msg_output_masks = get_buffer(TYPE_MASK_OUTPUT);
return GC::DONE_BREAK == first_phase(program, random_processor, random_machine);
}
void TrustedProgramParty::garble()
{
second_phase(program, processor, machine);
vector< Share<gf2n> > tmp;
make_share<gf2n>(tmp, 1, get_n_parties(), mac_key, prng);
for (int i = 0; i < get_n_parties(); i++)
tmp[i].get_mac().pack(spdz_wires[SPDZ_MAC][i]);
for (int i = 0; i < get_n_parties(); i++)
{
for (int j = 0; j < SPDZ_OP_N; j++)
{
SendBuffer buffer;
fill_message_type(buffer, TYPE_SPDZ_WIRES);
buffer.serialize(j);
buffer.serialize(spdz_wires[j][i].get_data(), spdz_wires[j][i].get_length());
#ifdef DEBUG_SPDZ_WIRE
cout << "send " << spdz_wires[j][i].get_length() << "/" << buffer.size()
<< " bytes for type " << j << " to " << i << endl;
#endif
_node->Send(i + 1, buffer);
}
}
}
void TrustedProgramParty::store_spdz_wire(SpdzOp op, const Register& reg)
{
make_share(mask_shares, gf2n(reg.get_mask()), get_n_parties(), gf2n(get_mac_key()), prng);
for (int i = 0; i < get_n_parties(); i++)
{
SpdzWire wire;
wire.mask = mask_shares[i];
for (int j = 0; j < 2; j++)
{
wire.my_keys[j] = reg.keys[j][i];
}
wire.pack(spdz_wires[op][i]);
}
#ifdef DEBUG_SPDZ_WIRE
cout << "stored SPDZ wire of type " << op << ":" << endl;
reg.keys.print(reg.get_id());
#endif
}
void TrustedProgramParty::store_wire(const Register& reg)
{
wires.serialize(reg.mask);
reg.keys.serialize(wires);
#ifdef DEBUG
cout << "storing wire" << endl;
reg.print();
#endif
}
void TrustedProgramParty::load_wire(Register& reg)
{
wires.unserialize(reg.mask);
reg.keys.unserialize(wires);
#ifdef DEBUG
cout << "loading wire" << endl;
reg.print();
#endif
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* TrustedParty.h
*
* Created on: Feb 15, 2016
* Author: bush
*/
#ifndef PROTOCOL_TRUSTEDPARTY_H_
@@ -13,48 +13,150 @@
#include "network/Node.h"
#include <atomic>
#include "Register.h"
#include "CommonParty.h"
class TrustedParty : public NodeUpdatable {
class BaseTrustedParty : virtual public CommonParty {
public:
vector<ReceivedMsg> prf_outputs;
BaseTrustedParty();
virtual ~BaseTrustedParty() {}
/* From NodeUpdatable class */
virtual void NodeReady();
void NewMessage(int from, ReceivedMsg& msg);
void NodeAborted(struct sockaddr_in* from) { (void)from; }
void Start();
protected:
boost::mutex _print_mx;
std::atomic_uint _num_prf_received;
std::atomic_uint _received_gc_received;
std::atomic_uint n_received;
vector<SendBuffer> msg_keys, msg_input_masks;
int randomfd;
bool done_filling;
#ifdef __PURE_SHE__
mpz_t _temp_mpz;
void _fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid);
#endif
virtual bool _fill_keys() = 0;
void _compute_send_garbled_circuit();
virtual void _launch_online() = 0;
void prepare_randomness();
void send_randomness();
virtual void send_input_masks(party_id_t pid) = 0;
virtual void send_output_masks() = 0;
virtual void garble() = 0;
void add_keys(const Register& reg);
};
class TrustedParty : public BaseTrustedParty, public CommonCircuitParty {
public:
TrustedParty(const char* netmap_file, const char* circuit_file );
virtual ~TrustedParty();
/* From NodeUpdatable class */
void NodeReady();
void NewMessage(int from, char* message, unsigned int len);
void NodeAborted(struct sockaddr_in* from) {}
void Start();
/* TEST methods */
private:
Node* _node;
BooleanCircuit* _circuit;
gate_id_t _G;
party_id_t _N;
wire_id_t _W;
wire_id_t _OW;
boost::mutex _print_mx;
std::atomic_int _num_prf_received;
std::atomic_int _received_gc_received;
#ifdef __PURE_SHE__
mpz_t _temp_mpz;
void _fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid);
#endif
void _fill_keys_for_party(Key* keys, party_id_t pid);
void _merge_keys(Key* dest, Key* src);
void _generate_masks();
void _allocate_prf_outputs();
void _allocate_garbled_table();
void _compute_send_garbled_circuit();
bool _fill_keys();
void _launch_online();
void _print_keys();
void send_input_masks(party_id_t pid);
void send_output_masks();
void garble();
};
class TrustedProgramParty : public BaseTrustedParty {
public:
SendBuffer msg_output_masks;
TrustedProgramParty(int argc, char** argv);
~TrustedProgramParty();
void NodeReady();
void store_spdz_wire(SpdzOp op, const Register& reg);
void store_wire(const Register& reg);
void load_wire(Register& reg);
#ifdef FREE_XOR
const Key& delta(int i) { return deltas[i]; }
const KeyVector& get_deltas() { return deltas; }
#endif
private:
friend class GarbleRegister;
friend class RandomRegister;
static TrustedProgramParty* singleton;
static TrustedProgramParty& s();
GC::Memory< GC::Secret<GarbleRegister>::DynamicType > dynamic_memory;
GC::Machine< GC::Secret<GarbleRegister> > machine;
GC::Processor< GC::Secret<GarbleRegister> > processor;
GC::Program< GC::Secret<GarbleRegister> > program;
GC::Machine< GC::Secret<RandomRegister> > random_machine;
GC::Processor< GC::Secret<RandomRegister> > random_processor;
#ifdef FREE_XOR
KeyVector deltas;
#endif
vector<octetStream> spdz_wires[SPDZ_OP_N];
vector< Share<gf2n> > mask_shares;
Timer random_timer;
bool _fill_keys();
void _launch_online();
void send_input_masks(party_id_t pid) { (void)pid; }
void send_output_masks();
void garble();
void add_all_keys(const Register& reg, bool external);
};
inline void BaseTrustedParty::add_keys(const Register& reg)
{
for(int p = 0; p < get_n_parties(); p++)
reg.keys.serialize(msg_keys[p], p + 1);
}
inline void TrustedProgramParty::add_all_keys(const Register& reg, bool external)
{
for(int p = 0; p < get_n_parties(); p++)
for (int i = 0; i < get_n_parties(); i++)
reg.keys[external][i].serialize(msg_keys[p]);
}
inline TrustedProgramParty& TrustedProgramParty::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
#endif /* PROTOCOL_TRUSTEDPARTY_H_ */

View File

@@ -1,12 +1,13 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef __WIRE_H__
#define __WIRE_H__
#include <vector>
#include <boost/thread/mutex.hpp>
#include <boost/thread.hpp>
#include "Key.h"
#include "common.h"
class Gate;
@@ -16,10 +17,10 @@ typedef char signal_t;
typedef struct Wire {
// signal_t _sig; // the actual value that is passing through the wire (after inputs have been set)
std::vector<gate_id_t> _enters_to; //TODO make it a regular c array
gate_id_t _out_from;
bool _is_output;
gate_id_t _out_from;
Wire(bool out):/*_sig(NO_SIGNAL),*/ _is_output(out),_out_from(NULL){}
Wire(bool out):/*_sig(NO_SIGNAL),*/ _is_output(out),_out_from(0){}
// inline signal_t Sig() { return _sig; }
// inline void Sig(signal_t s) { _sig = s; }
// void print(int w_id) {

View File

@@ -1,3 +1,5 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "aes.h"
#ifdef _WIN32

View File

@@ -1,15 +1,53 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* common.h
*
* Created on: Jan 21, 2016
* Author: bush
*/
#ifndef CIRCUIT_INC_COMMON_H_
#define CIRCUIT_INC_COMMON_H_
#include <string>
typedef unsigned long wire_id_t;
typedef unsigned long gate_id_t;
class Function {
bool rep[4];
int shift(int i) { return 4 * (3 - i); }
public:
Function() { memset(rep, 0, sizeof(rep)); }
Function(std::string& func)
{
for (int i = 0; i < 4; i++)
if (func[i] != '0')
rep[i] = 1;
else
rep[i] = 0;
}
Function(int int_rep)
{
for (int i = 0; i < 4; i++)
rep[i] = (int_rep << shift(i)) & 1;
}
uint8_t operator[](int i) { return rep[i]; }
};
template <class T>
class CheckVector : public vector<T>
{
public:
CheckVector() : vector<T>() {}
CheckVector(size_t size) : vector<T>(size) {}
CheckVector(size_t size, const T& def) : vector<T>(size, def) {}
#ifdef CHECK_SIZE
T& operator[](size_t i) { return this->at(i); }
const T& operator[](size_t i) const { return this->at(i); }
#else
T& at(size_t i) { return (*this)[i]; }
const T& at(size_t i) const { return (*this)[i]; }
#endif
};
#endif /* CIRCUIT_INC_COMMON_H_ */

16
BMR/msg_types.cpp Normal file
View File

@@ -0,0 +1,16 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* msg_types.cpp
*
*/
#include "msg_types.h"
#define X(NAME) #NAME,
const char* message_type_names[] = {
MESSAGE_TYPES
};
#undef X

View File

@@ -1,30 +1,42 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* msg_types.h
*
* Created on: Feb 16, 2016
* Author: bush
*/
#ifndef PROTOCOL_INC_MSG_TYPES_H_
#define PROTOCOL_INC_MSG_TYPES_H_
typedef enum {
TYPE_KEYS = 0,
TYPE_MASK_INPUTS,
TYPE_MASK_OUTPUT,
TYPE_PRF_OUTPUTS,
TYPE_GARBLED_CIRCUIT,
TYPE_EXTERNAL_VALUES,
TYPE_ALL_EXTERNAL_VALUES,
TYPE_KEY_PER_IN_WIRE,
TYPE_ALL_KEYS_PER_IN_WIRE,
TYPE_LAUNCH_ONLINE,
TYPE_RECEIVED_GC,
#define MESSAGE_TYPES \
X(TYPE_KEYS) \
X(TYPE_MASK_INPUTS) \
X(TYPE_MASK_OUTPUT) \
X(TYPE_PRF_OUTPUTS) \
X(TYPE_GARBLED_CIRCUIT) \
X(TYPE_EXTERNAL_VALUES) \
X(TYPE_ALL_EXTERNAL_VALUES) \
X(TYPE_KEY_PER_IN_WIRE) \
X(TYPE_ALL_KEYS_PER_IN_WIRE) \
X(TYPE_LAUNCH_ONLINE) \
X(TYPE_RECEIVED_GC) \
X(TYPE_SPDZ_WIRES) \
X(TYPE_DELTA) \
X(TYPE_CHECKSUM) \
X(TYPE_NEXT) \
X(TYPE_DONE) \
X(TYPE_MAX) \
TYPE_CHECKSUM,
TYPE_MAX
#define X(NAME) NAME,
typedef enum {
MESSAGE_TYPES
} MSG_TYPE;
#undef X
extern const char* message_type_names[];
#endif /* PROTOCOL_INC_MSG_TYPES_H_ */

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Client.cpp
*
* Created on: Jan 27, 2016
* Author: bush
*/
#include "Client.h"
@@ -28,20 +28,15 @@ static void throw_bad_ip(const char* ip) {
}
Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable, unsigned int max_message_size)
:_numservers(numservers),
_max_msg_sz(max_message_size),
_updatable(updatable),
_new_message(false)
:_max_msg_sz(max_message_size),
_numservers(numservers),
_updatable(updatable)
{
_sockets = new int[_numservers](); // 0 initialized
_servers = new struct sockaddr_in[_numservers];
_msg_queues = new Queue<Msg>[_numservers]();
_msg_queues = new WaitQueue< shared_ptr<SendBuffer> >[_numservers];
_lockqueue = new std::mutex[_numservers];
_queuecheck = new std::condition_variable[_numservers];
_new_message = new bool[_numservers]();
memset(_servers, 0, sizeof(_servers));
memset(_servers, 0, sizeof(*_servers));
for (int i=0; i<_numservers; i++) {
_sockets[i] = socket(AF_INET, SOCK_STREAM, 0);
@@ -56,20 +51,32 @@ Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable
}
Client::~Client() {
Stop();
for (int i=0; i<_numservers; i++)
close(_sockets[i]);
delete[] _sockets;
delete[] _servers;
delete[] _msg_queues;
delete[] _lockqueue;
delete[] _queuecheck;
delete[] _new_message;
#ifdef DEBUG_COMM
printf("Client:: Client deleted\n");
#endif
}
void Client::Connect() {
for (int i=0; i<_numservers; i++)
new boost::thread(&Client::_send_thread, this, i);
new boost::thread(&Client::_connect, this);
threads.add_thread(new boost::thread(&Client::_send_thread, this, i));
threads.add_thread(new boost::thread(&Client::_connect, this));
}
void Client::Stop() {
for (int i=0; i<_numservers; i++)
_msg_queues[i].stop();
threads.join_all();
for (int i=0; i<_numservers; i++)
shutdown(_sockets[i], SHUT_RDWR);
#ifdef DEBUG_COMM
printf("Stopped client\n");
#endif
}
void Client::_connect() {
@@ -89,95 +96,100 @@ void Client::_connect_to_server(int i) {
int port = ntohs(_servers[i].sin_port);
ip = inet_ntoa(_servers[i].sin_addr);
int error = 0;
int interval = 10000;
int total_wait = 0;
while (true ) {
error = connect(_sockets[i], (struct sockaddr *)&_servers[i], sizeof(struct sockaddr));
if (interval < CONNECT_INTERVAL)
interval *= 2;
if(!error)
break;
if (errno == 111) {
fprintf(stderr,".");
} else {
fprintf(stderr,"Client:: Error (%d): connect to %s:%d: \"%s\"\n",errno, ip,port,strerror(errno));
fprintf(stderr,"Client:: socket %d sleeping for %u usecs\n",i, CONNECT_INTERVAL);
fprintf(stderr,"Client:: socket %d sleeping for %u usecs\n",i, interval);
}
usleep(CONNECT_INTERVAL);
usleep(interval);
total_wait += interval;
if (total_wait > 60e6)
throw runtime_error("waiting for too long");
}
printf("\nClient:: connected to %s:%d\n", ip,port);
setsockopt(_sockets[i], SOL_SOCKET, SO_SNDBUF, &BUFFER_SIZE, sizeof(BUFFER_SIZE));
// Using the following disables the automatic buffer size (ipv4.tcp_wmem)
// in favour of the core.wmem_max, which is worse.
//setsockopt(_sockets[i], SOL_SOCKET, SO_SNDBUF, &NETWORK_BUFFER_SIZE, sizeof(NETWORK_BUFFER_SIZE));
}
void Client::Send(int id, const char* message, unsigned int len) {
Msg new_msg = {message, len};
void Client::Send(int id, SendBuffer& buffer) {
{
std::unique_lock<std::mutex> locker(_lockqueue[id]);
// printf ("Client:: queued %u bytes to %d\n", len, id);
_msg_queues[id].Enqueue(new_msg);
_new_message[id] = true;
_queuecheck[id].notify_one();
#ifdef DEBUG_COMM
printf ("Client:: queued %u bytes to %d\n", buffer.size(), id);
phex(buffer.data(), 4);
#endif
SendBuffer* tmp = new SendBuffer;
*tmp = buffer;
shared_ptr<SendBuffer> new_msg(tmp);
_msg_queues[id].push(new_msg);
}
}
void Client::Broadcast(const char* message, unsigned int len) {
void Client::Broadcast(SendBuffer& buffer) {
#ifdef DEBUG_COMM
printf ("Client:: queued %u bytes to broadcast\n", buffer.size());
phex(buffer.data(), 4);
#endif
SendBuffer* tmp = new SendBuffer;
*tmp = buffer;
shared_ptr<SendBuffer> new_msg(tmp);
for(int i=0;i<_numservers; i++) {
std::unique_lock<std::mutex> locker(_lockqueue[i]);
Msg new_msg = {message, len};
_msg_queues[i].Enqueue(new_msg);
_new_message[i] = true;
_queuecheck[i].notify_one();
_msg_queues[i].push(new_msg);
}
}
void Client::Broadcast2(const char* message, unsigned int len) {
void Client::Broadcast2(SendBuffer& buffer) {
#ifdef DEBUG_COMM
printf ("Client:: queued %u bytes to broadcast to all but the server\n", buffer.size());
phex(buffer.data(), 4);
#endif
SendBuffer* tmp = new SendBuffer;
*tmp = buffer;
shared_ptr<SendBuffer> new_msg(tmp);
// first server is always the trusted party so we start with i=1
for(int i=1;i<_numservers; i++) {
std::unique_lock<std::mutex> locker(_lockqueue[i]);
Msg new_msg = {message, len};
_msg_queues[i].Enqueue(new_msg);
_new_message[i] = true;
_queuecheck[i].notify_one();
_msg_queues[i].push(new_msg);
}
}
void Client::_send_thread(int i) {
while(true)
{
{
std::unique_lock<std::mutex> locker(_lockqueue[i]);
//printf("Client:: waiting for a notification to send to %d\n", i);
_queuecheck[i].wait(locker);
if (!_new_message[i]) {
// printf("Client:: Spurious notification!\n");
continue;
}
//printf("Client:: notified!!\n");
}
while (true)
{
Msg msg = {0};
{
std::unique_lock<std::mutex> locker(_lockqueue[i]);
if(_msg_queues[i].Empty()) {
//printf("Client:: no more messages in queue\n");
break; // out of the inner while
}
msg = _msg_queues[i].Dequeue();
}
_send_blocking(msg, i);
}
_new_message[i] = false;
}
shared_ptr<SendBuffer> msg;
while(_msg_queues[i].pop_dont_stop(msg))
_send_blocking(*msg, i);
#ifdef DEBUG_COMM
printf("Shutting down sender thread %d\n", i);
#endif
}
void Client::_send_blocking(Msg msg, int id) {
// printf ("Client:: sending %u bytes to %d\n", msg.len, id);
void Client::_send_blocking(SendBuffer& msg, int id) {
#ifdef DEBUG_COMM
printf ("Client:: sending %llu bytes at 0x%x to %d\n", msg.size(), msg.data(), id);
fflush(0);
#ifdef DEBUG2
phex(msg.data(), msg.size());
#else
phex(msg.data(), 4);
#endif
#endif
int cur_sent = 0;
cur_sent = send(_sockets[id], &msg.len, LENGTH_FIELD, 0);
if(LENGTH_FIELD == cur_sent) {
size_t len = msg.size();
cur_sent = send(_sockets[id], &len, sizeof(len), 0);
if(sizeof(len) == cur_sent) {
unsigned int total_sent = 0;
unsigned int remaining = 0;
while(total_sent != msg.len) {
remaining = (msg.len-total_sent)>_max_msg_sz ? _max_msg_sz : (msg.len-total_sent);
cur_sent = send(_sockets[id], msg.msg+total_sent, remaining, 0);
while(total_sent != msg.size()) {
remaining = (msg.size()-total_sent)>_max_msg_sz ? _max_msg_sz : (msg.size()-total_sent);
cur_sent = send(_sockets[id], msg.data()+total_sent, remaining, 0);
//printf("Client:: msg.len=%u, remaining=%u, total_sent=%u, cur_sent = %d\n",msg.len, remaining, total_sent,cur_sent);
if(cur_sent == -1) {
fprintf(stderr,"Client:: Error: send msg failed: %s\n",strerror(errno));
@@ -188,4 +200,10 @@ void Client::_send_blocking(Msg msg, int id) {
} else if (-1 == cur_sent){
fprintf(stderr,"Client:: Error: send header failed: %s\n",strerror(errno));
}
#ifdef DEBUG_COMM
printf ("Client:: sent %u bytes at 0x%x to %d\n", msg.size(), msg.data(), id);
fflush(0);
phex(msg.data(), 4);
fflush(0);
#endif
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Client.h
*
* Created on: Jan 27, 2016
* Author: bush
*/
#ifndef NETWORK_INC_CLIENT_H_
@@ -17,6 +17,10 @@
#include <thread>
#include <mutex>
#include <condition_variable>
#include <boost/thread.hpp>
#include "Tools/WaitQueue.h"
#include "Tools/FlexBuffer.h"
#define CONNECT_INTERVAL (1000000)
@@ -34,22 +38,20 @@ public:
virtual ~Client();
void Connect();
void Send(int id, const char* message, unsigned int len);
void Broadcast(const char* message, unsigned int len);
void Broadcast2(const char* message, unsigned int len);
void Send(int id, SendBuffer& new_msg);
void Broadcast(SendBuffer& new_msg);
void Broadcast2(SendBuffer& new_msg);
void Stop();
private:
Queue<Msg>* _msg_queues;
WaitQueue< shared_ptr<SendBuffer> >* _msg_queues;
// std::queue<Msg>* _msg_queues;
// boost::mutex* _msg_mux;
// boost::mutex* _thd_mux;
unsigned int _max_msg_sz;
std::mutex* _lockqueue;
std::condition_variable* _queuecheck;
bool* _new_message;
void _send_thread(int i);
void _send_blocking(Msg msg, int id);
void _send_blocking(SendBuffer& msg, int id);
void _connect();
void _connect_to_server(int i);
@@ -58,6 +60,8 @@ private:
struct sockaddr_in* _servers;
int* _sockets;
ClientUpdatable* _updatable;
boost::thread_group threads;
};
#endif /* NETWORK_INC_CLIENT_H_ */

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Node.cpp
*
* Created on: Jan 27, 2016
* Author: bush
*/
#include <string.h>
@@ -26,12 +26,12 @@ static void throw_bad_id(int id) {
Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num_parties)
:_id(my_id),
_updatable(updatable),
_connected_to_servers(false),
_num_parties_identified(0)
_num_parties_identified(0),
_updatable(updatable)
{
_parse_map(netmap_file, num_parties);
unsigned int max_message_size = BUFFER_SIZE/2;
unsigned int max_message_size = NETWORK_BUFFER_SIZE/2;
if(_id < 0 || _id > _numparties)
throw_bad_id(_id);
_ready_nodes = new bool[_numparties](); //initialized to false
@@ -41,21 +41,30 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num
}
Node::~Node() {
print_waiting();
delete(_client);
delete(_server);
delete (_endpoints);
delete (_ready_nodes);
delete (_clients_connected);
delete[] (_endpoints);
delete[] (_ready_nodes);
delete[] (_clients_connected);
}
void Node::Start() {
_client->Connect();
new boost::thread(&Node::_start, this);
boost::thread(&Node::_start, this).join();
_server->starter->join();
for (unsigned int i = 0; i < _server->listeners.size(); i++)
_server->listeners[i]->join();
}
void Node::Stop() {
_client->Stop();
}
void Node::_start() {
usleep(START_INTERVAL);
int interval = 10000;
int total_wait = 0;
while(true) {
bool all_ready = true;
if (_connected_to_servers) {
@@ -72,15 +81,27 @@ void Node::_start() {
break;
fprintf(stderr,"+");
// fprintf(stderr,"Node:: waiting for all nodes to get ready ; sleeping for %u usecs\n", START_INTERVAL);
usleep(START_INTERVAL);
if (interval < CONNECT_INTERVAL)
interval *= 2;
usleep(interval);
total_wait += interval;
if (total_wait > 60e6)
throw runtime_error("waiting for too long");
}
printf("All identified\n");
_client->Broadcast(ALL_IDENTIFIED,strlen(ALL_IDENTIFIED));
SendBuffer buffer;
buffer.serialize(ALL_IDENTIFIED, strlen(ALL_IDENTIFIED));
_client->Broadcast(buffer);
}
void Node::NewMsg(char* msg, unsigned int len, struct sockaddr_in* from) {
//printf("Node:: got message of length %d ",len);
// printf("from %s:%d\n", inet_ntoa(from->sin_addr), ntohs(from->sin_port));
void Node::NewMsg(ReceivedMsg& message, struct sockaddr_in* from) {
char* msg = message.data();
size_t len = message.size();
#ifdef DEBUG_COMM
printf("Node:: got message of length %u at 0x%x\n",len,msg);
printf("from %s:%d\n", inet_ntoa(from->sin_addr), ntohs(from->sin_port));
phex(msg, 4);
#endif
if(len == strlen(ID_HDR)+sizeof(_id) && 0==strncmp(msg, ID_HDR, strlen(ID_HDR))) {
int *id = (int*)(msg+strlen(ID_HDR));
printf("Node:: identified as party: %d\n", *id);
@@ -94,12 +115,16 @@ void Node::NewMsg(char* msg, unsigned int len, struct sockaddr_in* from) {
printf("Node:: received ALL_IDENTIFIED from %d\n",_clientsmap[from]);
_num_parties_identified++;
if(_num_parties_identified == _numparties-1) {
printf("Node:: received ALL_IDENTIFIED from ALL\n",_clientsmap[from]);
printf("Node:: received ALL_IDENTIFIED from ALL\n");
_updatable->NodeReady();
}
return;
}
_updatable->NewMessage(_clientsmap[from], msg, len );
_updatable->NewMessage(_clientsmap[from], message);
#ifdef DEBUG_COMM
printf("finished with %d bytes at 0x%x\n", len, msg);
phex(msg, 4);
#endif
}
void Node::ClientsConnected() {
@@ -109,6 +134,7 @@ void Node::ClientsConnected() {
void Node::NodeAborted(struct sockaddr_in* from)
{
printf("Node:: party %d has aborted\n",_clientsmap[from]);
Stop();
}
void Node::ConnectedToServers() {
@@ -117,25 +143,43 @@ void Node::ConnectedToServers() {
_identify();
}
void Node::Send(int to, const char* msg, unsigned int len) {
void Node::Send(int to, SendBuffer& msg) {
int new_recipient = to>_id?to-1:to;
//printf("Node:: new_recipient=%d\n",new_recipient);
_client->Send(new_recipient, msg, len);
#ifdef DEBUG_COMM
printf("Node:: new_recipient=%d\n",new_recipient);
printf("Send %d bytes at 0x%x\n", msg.size(), msg.data());
phex(msg.data(), 4);
#endif
_client->Send(new_recipient, msg);
}
void Node::Broadcast(const char* msg, unsigned int len) {
_client->Broadcast(msg, len);
void Node::Broadcast(SendBuffer& msg) {
#ifdef DEBUG_COMM
printf("Broadcast %d bytes at 0x%x\n", msg.size(), msg.data());
phex(msg.data(), 4);
#endif
_client->Broadcast(msg);
}
void Node::Broadcast2(const char* msg, unsigned int len) {
_client->Broadcast2(msg, len);
void Node::Broadcast2(SendBuffer& msg) {
#ifdef DEBUG_COMM
printf("Broadcast2 %d bytes at 0x%x\n", msg.size(), msg.data());
phex(msg.data(), 4);
#endif
_client->Broadcast2(msg);
}
void Node::_identify() {
char* msg = new char[strlen(ID_HDR)+sizeof(_id)];
char* msg = id_msg;
strncpy(msg, ID_HDR, strlen(ID_HDR));
strncpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
//printf("Node:: identifying myself:\n");
_client->Broadcast(msg,strlen(ID_HDR)+4);
SendBuffer buffer;
buffer.serialize(msg, strlen(ID_HDR)+4);
#ifdef DEBUG_COMM
cout << "message for identification:";
phex(buffer.data(), 4);
#endif
_client->Broadcast(buffer);
}
void Node::_parse_map(const char* netmap_file, int num_parties) {
@@ -165,11 +209,27 @@ void Node::_parse_map(const char* netmap_file, int num_parties) {
for(int i=0; i<_numparties; i++) {
if(_id == i) {
netmap >> _ip >> _port;
//printf("Node:: my address: %s:%d\n", _ip.c_str(),_port);
#ifdef DEBUG_NETMAP
printf("Node:: my address: %s:%d\n", _ip.c_str(),_port);
#endif
continue;
}
netmap >> _endpoints[j].ip >> _endpoints[j].port;
#ifdef DEBUG_NETMAP
printf("Node:: other address (%d): %s:%d\n", j,
_endpoints[j].ip.c_str(), _endpoints[j].port);
#endif
j++;
}
}
}
void Node::print_waiting()
{
for (unsigned i = 0; i < _server->timers.size(); i++)
{
cout << "Waited " << _server->timers[i].elapsed()
<< " seconds for client "
<< _clientsmap[_server->get_client_addr(i)] << endl;
}
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Node.h
*
* Created on: Jan 27, 2016
* Author: bush
*/
#ifndef NETWORK_NODE_H_
@@ -11,11 +11,14 @@
#include <string>
#include <map>
#include <atomic>
#include <vector>
#include "common.h"
#include "Client.h"
#include "Server.h"
#include "Tools/FlexBuffer.h"
class Server;
class Client;
class ServerUpdatable;
@@ -24,7 +27,7 @@ class ClientUpdatable;
class NodeUpdatable {
public:
virtual void NodeReady()=0;
virtual void NewMessage(int from, char* message, unsigned int len) =0;
virtual void NewMessage(int from, ReceivedMsg& msg) =0;
virtual void NodeAborted(struct sockaddr_in* from) =0;
};
@@ -34,26 +37,29 @@ typedef void (*msg_id_cb_t)(int from, char* msg, unsigned int len);
const char ALL_IDENTIFIED[] = "ALID";
#define LOOPBACK NULL
#define LOCALHOST_IP "127.0.0.1"
#define PORT_BASE (4000)
#define PORT_BASE (14000)
class Node : public ServerUpdatable, public ClientUpdatable {
public:
Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num_parties=0);
virtual ~Node();
void Send(int to, const char* msg, unsigned int len);
void Broadcast(const char* msg, unsigned int len);
void Broadcast2(const char* msg, unsigned int len);
void Send(int to, SendBuffer& msg);
void Broadcast(SendBuffer& msg);
void Broadcast2(SendBuffer& msg);
inline int NumParties(){return _numparties;}
void Start();
void Stop();
// void Close();
//derived from ServerUpdateable
void NewMsg(char* msg, unsigned int len, struct sockaddr_in* from);
void NewMsg(ReceivedMsg& msg, struct sockaddr_in* from);
void ClientsConnected();
void NodeAborted(struct sockaddr_in* from);
//derived from ClientUpdatable
void ConnectedToServers();
void print_waiting();
private:
void _parse_map(const char* netmap_file, int num_parties);
void _identify();
@@ -74,6 +80,8 @@ private:
std::map<struct sockaddr_in*,int> _clientsmap;
bool* _clients_connected;
NodeUpdatable* _updatable;
char id_msg[strlen(ID_HDR)+sizeof(_id)];
};
#endif /* NETWORK_NODE_H_ */

View File

@@ -1,3 +1,5 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include <stdio.h>
#include <errno.h>
@@ -9,19 +11,21 @@
#include <boost/thread.hpp>
#include <iostream>
#include <fstream>
#include <vector>
#include "Server.h"
/* Opens server socket for listening - not yet accepting */
Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size)
:_port(port),
:starter(0),
_expected_clients(expected_clients),
_port(port),
_updatable(updatable),
_max_msg_sz(max_message_size)
{
_clients = new int[expected_clients]();
_clients_addr = new struct sockaddr_in[expected_clients]();
timers.resize(expected_clients);
_servfd = socket(AF_INET, SOCK_STREAM, 0);
if (-1 == _servfd)
@@ -41,16 +45,21 @@ Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsig
printf("Server:: Error listen: \n%s\n",strerror(errno));
new boost::thread(&Server::_start_server, this);
starter = new boost::thread(&Server::_start_server, this);
}
Server::~Server() {
#ifdef DEBUG_COMM
printf("Server:: Server being deleted\n");
#endif
close(_servfd);
for (int i=0; i<_expected_clients; i++)
close(_clients[i]);
delete (_clients);
delete (_clients_addr);
delete[] (_clients);
delete[] (_clients_addr);
delete starter;
for (unsigned i = 0; i < listeners.size(); i++)
delete listeners[i];
}
void Server::_start_server() {
@@ -64,8 +73,10 @@ void Server::_start_server() {
printf("Server:: accept: error in connecting socket\n%s\n",strerror(errno));
} else {
printf("Server:: Incoming connection from %s:%d\n",inet_ntoa(_clients_addr[i].sin_addr), ntohs(_clients_addr[i].sin_port));
setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &BUFFER_SIZE, sizeof(BUFFER_SIZE));
boost::thread* listener = new boost::thread(&Server::_listen_to_client, this, i);
// Using the following disables the automatic buffer size (ipv4.tcp_rmem)
// in favour of the core.rmem_max, which is worse.
//setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &NETWORK_BUFFER_SIZE, sizeof(NETWORK_BUFFER_SIZE));
listeners.push_back(new boost::thread(&Server::_listen_to_client, this, i));
}
}
@@ -73,23 +84,32 @@ void Server::_start_server() {
}
void Server::_listen_to_client(int id){
int msg_len = 0;
size_t msg_len = 0;
int n_recv = 0;
unsigned int total_received;
unsigned int remaining;
char *msg;
size_t total_received;
size_t remaining;
ReceivedMsg msg;
#ifdef DEBUG_FLEXBUF
cout << "msg from " << id << " stored at " << &msg << endl;
#endif
while (true) {
n_recv = recv(_clients[id], &msg_len, LENGTH_FIELD, MSG_WAITALL);
if (!_handle_recv_len(id, n_recv,LENGTH_FIELD))
timers[id].start();
n_recv = recv(_clients[id], &msg_len, sizeof(msg_len), MSG_WAITALL);
timers[id].stop();
#ifdef DEBUG_COMM
cout << "message of size " << msg_len << " coming in from " << id << endl;
#endif
msg.resize(msg_len);
if (!_handle_recv_len(id, n_recv, sizeof(msg_len)))
return;
// printf("Server:: waiting for a message of len = %d\n", msg_len);
msg = new char[msg_len];
assert(msg != NULL);
total_received = 0;
remaining = 0;
while (total_received != msg_len) {
remaining = (msg_len-total_received)>_max_msg_sz ? _max_msg_sz : (msg_len-total_received);
n_recv = recv(_clients[id], msg+total_received, remaining, NULL /* MSG_WAITALL*/);
timers[id].start();
n_recv = recv(_clients[id], msg.data()+total_received, remaining, 0 /* MSG_WAITALL*/);
timers[id].stop();
// printf("n_recv = %d\n", n_recv);
if (!_handle_recv_len(id, n_recv,remaining)) {
printf("returning\n");
@@ -99,19 +119,21 @@ void Server::_listen_to_client(int id){
// printf("total_received = %d\n", total_received);
}
// printf("Server:: received %d: \n", msg_len);
_updatable->NewMsg(msg, msg_len, &_clients_addr[id]);
_updatable->NewMsg(msg, &_clients_addr[id]);
}
printf("stop listenning to %d\n", id);
}
bool Server::_handle_recv_len(int id, unsigned int actual_len, unsigned int expected_len) {
bool Server::_handle_recv_len(int id, size_t actual_len, size_t expected_len) {
// printf("Server:: received msg from %d len = %u\n",id, actual_len);
if (actual_len == 0) {
#ifdef DEBUG_COMM
printf("Server:: [%d]: Error: n_recv==0 Connection closed\n", id);
#endif
_updatable->NodeAborted(&_clients_addr[id]);
return false;
// exit(1);
} else if (actual_len == -1) {
} else if (actual_len == -1U) {
printf("Server:: [%d]: Error: n_recv==-1. \"%s\"\n",id, strerror(errno));
_updatable->NodeAborted(&_clients_addr[id]);
return false;
@@ -120,4 +142,5 @@ bool Server::_handle_recv_len(int id, unsigned int actual_len, unsigned int expe
// printf("Server:: [%d]: Error: n_recv < %d; n_recv=%d; \"%d\"\n",id, expected_len, actual_len,strerror(errno));
return true;
}
return true;
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Server.h
*
* Created on: Jan 24, 2016
* Author: bush
*/
#ifndef NETWORK_INC_SERVER_H_
@@ -10,13 +10,15 @@
#include <sys/socket.h>
#include <netinet/in.h>
#include <vector>
#include "common.h"
#include "Tools/FlexBuffer.h"
class ServerUpdatable {
public:
virtual void ClientsConnected()=0;
virtual void NewMsg(char* msg, unsigned int len, struct sockaddr_in* from)=0;
virtual void NewMsg(ReceivedMsg& msg, struct sockaddr_in* from)=0;
virtual void NodeAborted(struct sockaddr_in* from) =0;
};
@@ -25,6 +27,13 @@ public:
Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size);
~Server();
sockaddr_in* get_client_addr(int id) { return &_clients_addr[id]; }
boost::thread* starter;
std::vector<boost::thread*> listeners;
vector<Timer> timers;
private:
int _expected_clients;
int *_clients;
@@ -38,7 +47,7 @@ private:
void _start_server();
void _listen_to_client(int id);
bool _handle_recv_len(int id, unsigned int actual_len,unsigned int expected_len);
bool _handle_recv_len(int id, size_t actual_len, size_t expected_len);
};

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* common.h
*
* Created on: Jan 27, 2016
* Author: bush
*/
#ifndef NETWORK_INC_COMMON_H_
@@ -13,23 +13,16 @@
//#include "utils.h"
#define LENGTH_FIELD (4)
/*
* To change the buffer sizes in the kernel
# echo 'net.core.wmem_max=12582912' >> /etc/sysctl.conf
# echo 'net.core.rmem_max=12582912' >> /etc/sysctl.conf
*/
const int BUFFER_SIZE = 20000000;
const int NETWORK_BUFFER_SIZE = 20000000;
typedef struct {
std::string ip;
int port;
} endpoint_t;
typedef struct {
const char* msg;
unsigned int len;
} Msg;
#endif /* NETWORK_INC_COMMON_H_ */

View File

@@ -1,43 +0,0 @@
CPPFLAGS=-g -O3 -c -w --std=c++11 \
-I/usr/include \
-Iinc \
$(LOCAL_CPPFLAGS)
LDFLAGS=-g -L/usr/lib/x86_64-linux-gnu/ \
-lboost_system -lpthread -lboost_thread $(LOCAL_LDFLAGS)
OBJECTS=build/test.o \
build/utils.o \
build/Server.o \
build/Client.o \
build/Node.o
all:bin/net
bin/net:$(OBJECTS)
g++ -Wall $^ -o bin/net $(LDFLAGS)
build/Server.o:src/Server.cpp
$(CC) $(CPPFLAGS) $^ -o $@
build/Client.o:src/Client.cpp
$(CC) $(CPPFLAGS) $^ -o $@
build/Node.o:src/Node.cpp
$(CC) $(CPPFLAGS) $^ -o $@
build/test.o:test/test.cpp
$(CC) $(CPPFLAGS) $^ -o $@
build/utils.o:src/utils.cpp
$(CC) $(CPPFLAGS) $^ -o $@
clean:
rm build/* bin/net

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* mq.h
*
* Created on: Feb 14, 2016
* Author: bush
*/
#ifndef NETWORK_INC_MQ_H_

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* utils.cpp
*
* Created on: Jan 31, 2016
* Author: bush
*/
@@ -11,6 +11,7 @@
#include <unistd.h>
#include <fcntl.h>
#include <string.h>
#include <stdlib.h>
#include "utils.h"
@@ -22,7 +23,7 @@ void fill_random(void* buffer, unsigned int length)
}
char cs(char* msg, unsigned int len, char result) {
for(int i = 0; i < len; i++)
for(size_t i = 0; i < len; i++)
result += msg[i];
return result;
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* utils.h
*
* Created on: Jan 31, 2016
* Author: bush
*/
#ifndef NETWORK_TEST_UTILS_H_

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* prf.cpp
*
* Created on: Feb 18, 2016
* Author: bush
*/
@@ -10,22 +10,14 @@
#include "aes.h"
#include "proto_utils.h"
void PRF_single(const Key* key, char* input, char* output)
void PRF_single(const Key& key, char* input, char* output)
{
// printf("prf_single\n");
// std::cout << *key;
// phex(input, 16);
AES_KEY aes_key;
AES_128_Key_Expansion((const unsigned char*)(&(key->r)), &aes_key);
AES_128_Key_Expansion((const unsigned char*)(&(key.r)), &aes_key);
aes_key.rounds=10;
AES_encryptC((block*)input, (block*)output, &aes_key);
// phex(output, 16);
}
void PRF_chunk(const Key* key, char* input, char* output, int number)
{
AES_KEY aes_key;
AES_128_Key_Expansion((const unsigned char*)(&(key->r)), &aes_key);
aes_key.rounds=10;
AES_ecb_encrypt_chunk_in_out((block*)input, (block*)output, number, &aes_key);
}

View File

@@ -1,16 +1,37 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* prf.h
*
* Created on: Feb 18, 2016
* Author: bush
*/
#ifndef PROTOCOL_INC_PRF_H_
#define PROTOCOL_INC_PRF_H_
#include "Key.h"
#include "aes.h"
void PRF_single(const Key* key, char* input, char* output);
void PRF_chunk(const Key* key, char* input, char* output, int number);
#include "Tools/aes.h"
void PRF_single(const Key& key, char* input, char* output);
inline void PRF_chunk(const Key& key, char* input, char* output, int number)
{
__m128i* in = (__m128i*)input;
__m128i* out = (__m128i*)output;
AES_KEY aes_key;
AES_128_Key_Expansion((unsigned char*)&key.r, &aes_key);
switch (number)
{
case 2:
ecb_aes_128_encrypt<2>(out, in, (octet*)aes_key.rd_key);
break;
case 3:
ecb_aes_128_encrypt<3>(out, in, (octet*)aes_key.rd_key);
break;
default:
throw not_implemented();
}
}
#endif /* PROTOCOL_INC_PRF_H_ */

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* utils.cpp
*
* Created on: Jan 31, 2016
* Author: bush
*/
@@ -21,12 +21,9 @@ void fill_message_type(void* buffer, MSG_TYPE type)
memcpy(buffer, &type, sizeof(MSG_TYPE));
}
void fill_message_type(vector<char>& buffer, MSG_TYPE type)
void fill_message_type(SendBuffer& buffer, MSG_TYPE type)
{
// compatibility
buffer.resize(buffer.size() + 4);
char* start = (char*)&type;
copy(start, start + sizeof(MSG_TYPE), buffer.end() - 4);
buffer.serialize(type);
}

View File

@@ -1,8 +1,8 @@
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* utils.h
*
* Created on: Jan 31, 2016
* Author: bush
*/
#ifndef PROTO_UTILS_H_
@@ -11,23 +11,30 @@
#include "msg_types.h"
#include <time.h>
#include <sys/time.h>
#include <vector>
#include <iostream>
using namespace std;
#include "Tools/avx_memcpy.h"
#include "Tools/FlexBuffer.h"
#define LOOPBACK_STR "LOOPBACK"
void fill_random(void* buffer, unsigned int length);
void fill_message_type(void* buffer, MSG_TYPE type);
class SendBuffer;
char cs(char* msg, unsigned int len, char result=0);
void fill_message_type(void* buffer, MSG_TYPE type);
void fill_message_type(SendBuffer& buffer, MSG_TYPE type);
void phex (const void *addr, int len);
//inline void xor_big(const char* input1, const char* input2, char* output);
inline timeval* GET_TIME() {
struct timeval* now = new struct timeval();
int rc = gettimeofday(now, 0);
inline timeval GET_TIME() {
struct timeval now;
int rc = gettimeofday(&now, 0);
if (rc != 0) {
perror("gettimeofday");
}
@@ -46,6 +53,29 @@ inline unsigned long PRINT_DIFF(struct timeval* before, struct timeval* after) {
return diff;
}
inline void phex(const FlexBuffer& buffer) { phex(buffer.data(), buffer.size()); }
inline void print_bit_array(const char* bits, int len)
{
for (int i = 0; i < len; i++)
{
if (i % 8 == 0)
cout << " ";
cout << (int)bits[i];
}
cout << endl;
}
inline void print_bit_array(const vector<char>& bits)
{
print_bit_array(bits.data(), bits.size());
}
inline void print_indices(const vector<int>& indices)
{
for (unsigned i = 0; i < indices.size(); i++)
cout << indices[i] << " ";
cout << endl;
}
#endif /* NETWORK_TEST_UTILS_H_ */

View File

@@ -1,62 +0,0 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here.
## 0.0.3 (Mar 2, 2018)
- Added offline phases based on homomorphic encryption, used in the [SPDZ-2 paper](https://eprint.iacr.org/2012/642) and the [Overdrive paper](https://eprint.iacr.org/2017/1230).
- On macOS, the minimum requirement is now Sierra.
- Compilation with LLVM/clang is now possible (tested with 3.8).
## 0.0.2 (Sep 13, 2017)
### Support sockets based external client input and output to a SPDZ MPC program.
See the [ExternalIO directory](./ExternalIO/README.md) for more details and examples.
Note that [libsodium](https://download.libsodium.org/doc/) is now a dependency on the SPDZ build.
Added compiler instructions:
* LISTEN
* ACCEPTCLIENTCONNECTION
* CONNECTIPV4
* WRITESOCKETSHARE
* WRITESOCKETINT
Removed instructions:
* OPENSOCKET
* CLOSESOCKET
Modified instructions:
* READSOCKETC
* READSOCKETS
* READSOCKETINT
* WRITESOCKETC
* WRITESOCKETS
Support secure external client input and output with new instructions:
* READCLIENTPUBLICKEY
* INITSECURESOCKET
* RESPSECURESOCKET
### Read/Write secret shares to disk to support persistence in a SPDZ MPC program.
Added compiler instructions:
* READFILESHARE
* WRITEFILESHARE
### Other instructions
Added compiler instructions:
* DIGESTC - Clear truncated hash computation
* PRINTINT - Print register value
## 0.0.1 (Sep 2, 2016)
### Initial Release
* See `README.md` and `tutorial.md`.

10
CONFIG
View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
ROOT = .
@@ -17,6 +17,10 @@ USE_GF2N_LONG = 0
# AVX2 support (Haswell or later) changes the bit matrix transpose
ARCH = -mtune=native -mavx
# defaults for BMR, change number of parties here
CFLAGS = -DN_PARTIES=2 -DFREE_XOR -DKEY_SIGNAL -DSPDZ_AUTH -DNO_INPUT -DMAX_INLINE
USE_GF2N_LONG = 1
#use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine
@@ -39,8 +43,10 @@ ifeq ($(OS), Linux)
LDLIBS += -lrt
endif
BOOST = -lboost_system -lboost_thread $(MY_BOOST)
CXX = g++
CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror
CPPFLAGS = $(CFLAGS)
LD = g++

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Check-Offline.cpp

2
Compiler/GC/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt

166
Compiler/GC/instructions.py Normal file
View File

@@ -0,0 +1,166 @@
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import Compiler.instructions_base as base
import Compiler.instructions as spdz
import Compiler.tools as tools
import collections
import itertools
class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
class ClearBitsAF(base.RegisterArgFormat):
reg_type = 'cb'
base.ArgFormats['sb'] = SecretBitsAF
base.ArgFormats['sbw'] = SecretBitsAF
base.ArgFormats['cb'] = ClearBitsAF
base.ArgFormats['cbw'] = ClearBitsAF
opcodes = dict(
XORS = 0x200,
XORM = 0x201,
ANDRS = 0x202,
BITDECS = 0x203,
BITCOMS = 0x204,
CONVSINT = 0x205,
LDMSDI = 0x206,
STMSDI = 0x207,
LDMSD = 0x208,
STMSD = 0x209,
LDBITS = 0x20a,
XORCI = 0x210,
BITDECC = 0x211,
CONVCINT = 0x213,
REVEAL = 0x214,
STMSDCI = 0x215,
)
class xors(base.Instruction):
code = opcodes['XORS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
class xorm(base.Instruction):
code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb']
class xorc(base.Instruction):
code = base.opcodes['XORC']
arg_format = ['cbw','cb','cb']
class xorci(base.Instruction):
code = opcodes['XORCI']
arg_format = ['cbw','cb','int']
class andrs(base.Instruction):
code = opcodes['ANDRS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
class addc(base.Instruction):
code = base.opcodes['ADDC']
arg_format = ['cbw','cb','cb']
class addci(base.Instruction):
code = base.opcodes['ADDCI']
arg_format = ['cbw','cb','int']
class mulci(base.Instruction):
code = base.opcodes['MULCI']
arg_format = ['cbw','cb','int']
class bitdecs(base.VarArgsInstruction):
code = opcodes['BITDECS']
arg_format = tools.chain(['sb'], itertools.repeat('sbw'))
class bitcoms(base.VarArgsInstruction):
code = opcodes['BITCOMS']
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
class bitdecc(base.VarArgsInstruction):
code = opcodes['BITDECC']
arg_format = tools.chain(['cb'], itertools.repeat('cbw'))
class shrci(base.Instruction):
code = base.opcodes['SHRCI']
arg_format = ['cbw','cb','int']
class ldbits(base.Instruction):
code = opcodes['LDBITS']
arg_format = ['sbw','i','i']
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
code = base.opcodes['LDMS']
arg_format = ['sbw','int']
class stms(base.DirectMemoryWriteInstruction):
code = base.opcodes['STMS']
arg_format = ['sb','int']
# def __init__(self, *args, **kwargs):
# super(type(self), self).__init__(*args, **kwargs)
# import inspect
# self.caller = [frame[1:] for frame in inspect.stack()[1:]]
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
code = base.opcodes['LDMC']
arg_format = ['cbw','int']
class stmc(base.DirectMemoryWriteInstruction):
code = base.opcodes['STMC']
arg_format = ['cb','int']
class ldmsi(base.ReadMemoryInstruction):
code = base.opcodes['LDMSI']
arg_format = ['sbw','ci']
class stmsi(base.WriteMemoryInstruction):
code = base.opcodes['STMSI']
arg_format = ['sb','ci']
class ldmsdi(base.ReadMemoryInstruction):
code = opcodes['LDMSDI']
arg_format = tools.cycle(['sbw','cb','int'])
class stmsdi(base.WriteMemoryInstruction):
code = opcodes['STMSDI']
arg_format = tools.cycle(['sb','cb'])
class ldmsd(base.ReadMemoryInstruction):
code = opcodes['LDMSD']
arg_format = tools.cycle(['sbw','int','int'])
class stmsd(base.WriteMemoryInstruction):
code = opcodes['STMSD']
arg_format = tools.cycle(['sb','int'])
class stmsdci(base.WriteMemoryInstruction):
code = opcodes['STMSDCI']
arg_format = tools.cycle(['cb','cb'])
class convsint(base.Instruction):
code = opcodes['CONVSINT']
arg_format = ['int','sbw','ci']
class convcint(base.Instruction):
code = opcodes['CONVCINT']
arg_format = ['cbw','ci']
class movs(base.Instruction):
code = base.opcodes['MOVS']
arg_format = ['sbw','sb']
class bit(base.Instruction):
code = base.opcodes['BIT']
arg_format = ['sbw']
class reveal(base.Instruction):
code = opcodes['REVEAL']
arg_format = ['int','cbw','sb']
class print_reg(base.IOInstruction):
code = base.opcodes['PRINTREG']
arg_format = ['cb','i']
def __init__(self, reg, comment=''):
super(print_reg, self).__init__(reg, self.str_to_int(comment))
class print_reg_plain(base.IOInstruction):
code = base.opcodes['PRINTREGPLAIN']
arg_format = ['cb']

12
Compiler/GC/program.py Normal file
View File

@@ -0,0 +1,12 @@
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler import types, instructions
class Program(object):
def __init__(self, progname):
types.program = self
instructions.program = self
self.curr_tape = None
execfile(progname)
def malloc(self, *args):
pass

352
Compiler/GC/types.py Normal file
View File

@@ -0,0 +1,352 @@
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.types import MemValue, read_mem_value, regint, Array
from Compiler.program import Tape, Program
from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint
import Compiler.GC.instructions as inst
import operator
class bits(Tape.Register):
n = 40
size = 1
PreOp = staticmethod(floatingpoint.PreOpN)
MemValue = staticmethod(lambda value: MemValue(value))
@staticmethod
def PreOR(l):
return [1 - x for x in \
floatingpoint.PreOpN(operator.mul, \
[1 - x for x in l])]
@classmethod
def get_type(cls, length):
if length is None:
return cls
elif length == 1:
return cls.bit_type
if length not in cls.types:
class bitsn(cls):
n = length
cls.types[length] = bitsn
bitsn.__name__ = cls.__name__ + str(length)
return cls.types[length]
@classmethod
def conv(cls, other):
if isinstance(other, cls):
return other
elif isinstance(other, MemValue):
return cls.conv(other.read())
else:
res = cls()
res.load_other(other)
return res
hard_conv = conv
@classmethod
def compose(cls, items, bit_length):
return cls.bit_compose(sum([item.bit_decompose(bit_length) for item in items], []))
@classmethod
def bit_compose(cls, bits):
if len(bits) == 1:
return bits[0]
bits = list(bits)
res = cls.new(n=len(bits))
cls.bitcom(res, *bits)
res.decomposed = bits
return res
def bit_decompose(self, bit_length=None):
n = bit_length or self.n
if n > self.n:
raise Exception('wanted %d bits, only got %d' % (n, self.n))
if n == 1:
return [self]
if self.decomposed is None or len(self.decomposed) < n:
res = [self.bit_type() for i in range(n)]
self.bitdec(self, *res)
self.decomposed = res
return res
else:
return self.decomposed[:n]
@classmethod
def load_mem(cls, address, mem_type=None):
res = cls()
if mem_type == 'sd':
return cls.load_dynamic_mem(address)
else:
cls.load_inst[util.is_constant(address)](res, address)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, (int, long))](self, address)
def __init__(self, value=None, n=None):
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape)
self.set_length(n or self.n)
if value is not None:
self.load_other(value)
self.decomposed = None
def set_length(self, n):
if n > self.max_length:
print self.max_length
raise Exception('too long: %d' % n)
self.n = n
def load_other(self, other):
if isinstance(other, (int, long)):
self.set_length(self.n or util.int_len(other))
self.load_int(other)
elif isinstance(other, regint):
self.conv_regint(self.n, self, other)
elif isinstance(self, type(other)) or isinstance(other, type(self)):
self.mov(self, other)
else:
raise CompilerError('cannot convert from %s to %s' % \
(type(other), type(self)))
def __repr__(self):
return '%s(%d/%d)' % \
(super(bits, self).__repr__(), self.n, type(self).n)
class cbits(bits):
max_length = 64
reg_type = 'cb'
is_clear = True
load_inst = (None, inst.ldmc)
store_inst = (None, inst.stmc)
bitdec = inst.bitdecc
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
types = {}
def load_int(self, value):
self.load_other(regint(value))
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
if isinstance(other, cbits):
res = cbits(n=max(self.n, other.n))
c_inst(res, self, other)
return res
else:
if util.is_constant(other):
if other >= 2**31 or other < -2**31:
return op(self, cbits(other))
res = cbits(n=max(self.n, len(bin(other)) - 2))
ci_inst(res, self, other)
return res
else:
return op(self, cbits(other))
__add__ = lambda self, other: \
self.clear_op(other, inst.addc, inst.addci, operator.add)
__xor__ = lambda self, other: \
self.clear_op(other, inst.xorc, inst.xorci, operator.xor)
__radd__ = __add__
__rxor__ = __xor__
def __mul__(self, other):
if isinstance(other, cbits):
return NotImplemented
else:
try:
res = cbits(n=min(self.max_length, self.n+util.int_len(other)))
inst.mulci(res, self, other)
return res
except TypeError:
return NotImplemented
def __rshift__(self, other):
res = cbits(n=self.n-other)
inst.shrci(res, self, other)
return res
def print_reg(self, desc=''):
inst.print_reg(self, desc)
def print_reg_plain(self):
inst.print_reg_plain(self)
output = print_reg_plain
def reveal(self):
return self
class sbits(bits):
max_length = 128
reg_type = 'sb'
is_clear = False
clear_type = cbits
default_type = cbits
load_inst = (inst.ldmsi, inst.ldms)
store_inst = (inst.stmsi, inst.stms)
bitdec = inst.bitdecs
bitcom = inst.bitcoms
conv_regint = inst.convsint
mov = inst.movs
types = {}
def __init__(self, *args, **kwargs):
bits.__init__(self, *args, **kwargs)
@staticmethod
def new(value=None, n=None):
if n == 1:
return sbit(value)
else:
return sbits.get_type(n)(value)
@staticmethod
def get_random_bit():
res = sbit()
inst.bit(res)
return res
@classmethod
def load_dynamic_mem(cls, address):
res = cls()
if isinstance(address, (long, int)):
inst.ldmsd(res, address, cls.n)
else:
inst.ldmsdi(res, address, cls.n)
return res
def store_in_dynamic_mem(self, address):
if isinstance(address, (long, int)):
inst.stmsd(self, address)
else:
inst.stmsdi(self, cbits.conv(address))
def load_int(self, value):
if abs(value) < 2**31:
if (abs(value) > (1 << self.n)):
raise Exception('public value %d longer than %d bits' \
% (value, self.n))
inst.ldbits(self, self.n, value)
else:
value %= 2**self.n
if value >> 64 != 0:
raise NotImplementedError('public value too large')
self.load_other(regint(value))
@read_mem_value
def __add__(self, other):
if isinstance(other, int):
return self.xor_int(other)
else:
if not isinstance(other, sbits):
other = sbits(other)
n = self.n
else:
n = max(self.n, other.n)
res = self.new(n=n)
inst.xors(n, res, self, other)
return res
__radd__ = __add__
__sub__ = __add__
__xor__ = __add__
__rxor__ = __add__
@read_mem_value
def __rsub__(self, other):
if isinstance(other, cbits):
return other + self
else:
return self.xor_int(other)
@read_mem_value
def __mul__(self, other):
if isinstance(other, int):
return self.mul_int(other)
try:
if min(self.n, other.n) != 1:
raise NotImplementedError('high order multiplication')
n = max(self.n, other.n)
res = self.new(n=max(self.n, other.n))
order = (self, other) if self.n != 1 else (other, self)
inst.andrs(n, res, *order)
return res
except AttributeError:
return NotImplemented
@read_mem_value
def __rmul__(self, other):
if isinstance(other, cbits):
return other * self
else:
return self.mul_int(other)
def xor_int(self, other):
if other == 0:
return self
self_bits = self.bit_decompose()
other_bits = util.bit_decompose(other, max(self.n, util.int_len(other)))
extra_bits = [self.new(b, n=1) for b in other_bits[self.n:]]
return self.bit_compose([~x if y else x \
for x,y in zip(self_bits, other_bits)] \
+ extra_bits)
def mul_int(self, other):
if other == 0:
return 0
elif other == 1:
return self
elif self.n == 1:
bits = util.bit_decompose(other, util.int_len(other))
zero = sbit(0)
mul_bits = [self if b else zero for b in bits]
return self.bit_compose(mul_bits)
else:
print self.n, other
return NotImplemented
def __lshift__(self, i):
return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i])
def __invert__(self):
# res = type(self)(n=self.n)
# inst.nots(res, self)
# return res
one = self.new(value=1, n=1)
bits = [one + bit for bit in self.bit_decompose()]
return self.bit_compose(bits)
def __neg__(self):
return self
def reveal(self):
if self.n > self.clear_type.max_length:
raise Exception('too long to reveal')
res = self.clear_type(n=self.n)
inst.reveal(self.n, res, self)
return res
def equal(self, other, n=None):
bits = (~(self + other)).bit_decompose()
return reduce(operator.mul, bits)
class bit(object):
n = 1
class sbit(bit, sbits):
def if_else(self, x, y):
return self * (x ^ y) ^ y
class cbit(bit, cbits):
pass
sbits.bit_type = sbit
cbits.bit_type = cbit
class bitsBlock(oram.Block):
value_type = sbits
def __init__(self, value, start, lengths, entries_per_block):
oram.Block.__init__(self, value, lengths)
length = sum(self.lengths)
used_bits = entries_per_block * length
self.value_bits = self.value.bit_decompose(used_bits)
start_length = util.log2(entries_per_block)
self.start_bits = util.bit_decompose(start, start_length)
self.start_demux = oram.demux_list(self.start_bits)
self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \
for i in range(entries_per_block)]
self.mul_entries = map(operator.mul, self.start_demux, self.entries)
self.bits = sum(self.mul_entries).bit_decompose()
self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths))
self.anti_value = self.mul_value + self.value
def set_slice(self, value):
value = sbits.compose(util.tuplify(value), sum(self.lengths))
for i,b in enumerate(self.start_bits):
value = b.if_else(value << (2**i * sum(self.lengths)), value)
self.value = value + self.anti_value
return self
oram.block_types[sbits] = bitsBlock
class dyn_sbits(sbits):
pass
class DynamicArray(Array):
def __init__(self, *args):
Array.__init__(self, *args)
def _malloc(self):
return Program.prog.malloc(self.length, 'sd', self.value_type)
def _load(self, address):
return self.value_type.load_dynamic_mem(cbits.conv(address))
def _store(self, value, address):
if isinstance(value, MemValue):
value = value.read()
if isinstance(value, sbits):
self.value_type.conv(value).store_in_dynamic_mem(address)
else:
cbits.conv(value).store_in_dynamic_mem(address)
sbits.dynamic_array = DynamicArray
cbits.dynamic_array = Array

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import compilerLib, program, instructions, types, library, floatingpoint
import inspect

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import itertools, time
from collections import defaultdict, deque
@@ -151,7 +151,7 @@ def determine_scope(block, options):
block.defined_registers = set(last_def.iterkeys())
class Merger:
def __init__(self, block, options):
def __init__(self, block, options, merge_classes, stop_class=stopopen_class):
self.block = block
self.instructions = block.instructions
self.options = options
@@ -159,7 +159,7 @@ class Merger:
self.max_parallel_open = int(options.max_parallel_open)
else:
self.max_parallel_open = float('inf')
self.dependency_graph()
self.dependency_graph(merge_classes, stop_class)
def do_merge(self, merges_iter):
""" Merge an iterable of nodes in G, returning the number of merged
@@ -341,8 +341,7 @@ class Merger:
preorder.extend(reversed(startinputs))
return preorder
def longest_paths_merge(self, instruction_type=startopen_class,
merge_stopopens=True):
def longest_paths_merge(self, merge_stopopens=True):
""" Attempt to merge instructions of type instruction_type (which are given in
merge_nodes) using longest paths algorithm.
@@ -357,8 +356,6 @@ class Merger:
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
if instruction_type is not startopen_class and merge_stopopens:
raise CompilerError('Cannot merge stopopens whilst merging %s instructions' % instruction_type)
if not merge_nodes and not self.input_nodes:
return 0
@@ -379,8 +376,7 @@ class Merger:
my_merge = (m for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b)
if merge_stopopens:
my_stopopen = [G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b]
my_stopopen = filter(lambda x: x != -1, (G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b))
mc, nodes[0,b] = self.do_merge(iter(my_merge))
if merge_stopopens:
@@ -420,8 +416,16 @@ class Merger:
return len(merges)
def dependency_graph(self, merge_class=startopen_class):
def dependency_graph(self, merge_classes, stop_class):
""" Create the program dependency graph. """
if len(merge_classes) != 1:
if stop_class is not type(None):
raise NotImplementedError('stop merging only implemented ' \
'for single instruction')
if int(self.options.max_parallel_open):
raise NotImplementedError('parallel limit only implemented ' \
'for single instruction')
block = self.block
options = self.options
open_nodes = set()
@@ -453,13 +457,10 @@ class Merger:
next_available_depth = {}
self.sources = []
self.real_depths = [0] * len(block.instructions)
round_type = {}
def add_edge(i, j):
from_merge = isinstance(block.instructions[i], merge_class)
to_merge = isinstance(block.instructions[j], merge_class)
G.add_edge(i, j)
is_source = G.get_attr(i, 'is_source') and G.get_attr(j, 'is_source') and not from_merge
G.set_attr(j, 'is_source', is_source)
for d in (self.depths, self.real_depths):
if d[j] < d[i]:
d[j] = d[i]
@@ -515,7 +516,7 @@ class Merger:
for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used()
G.add_node(n, is_source=True)
G.add_node(n)
# if options.debug:
# col = colordict[instr.__class__.__name__]
@@ -534,13 +535,19 @@ class Merger:
else:
write(reg, n)
if isinstance(instr, merge_class):
if isinstance(instr, merge_classes):
open_nodes.add(n)
last_open.append(n)
G.add_node(n, merges=[])
if stop_class != type(None):
last_open.append(n)
# the following must happen after adding the edge
self.real_depths[n] += 1
depth = depths[n] + 1
while depth in round_type:
if round_type[depth] == type(instr):
break
depth += 1
round_type[depth] = type(instr)
if int(options.max_parallel_open):
skipped_depths = set()
while parallel_open[depth] >= int(options.max_parallel_open):
@@ -548,10 +555,12 @@ class Merger:
depth = next_available_depth.get(depth, depth + 1)
for d in skipped_depths:
next_available_depth[d] = depth
else:
self.real_depths[n] = depth
parallel_open[depth] += len(instr.args) * instr.get_size()
depths[n] = depth
if isinstance(instr, stopopen_class):
if isinstance(instr, stop_class):
startopen = last_open.popleft()
add_edge(startopen, n)
G.set_attr(startopen, 'stop', n)
@@ -609,7 +618,7 @@ class Merger:
(n, len(block.instructions)), time.asctime()
if len(open_nodes) > 1000:
print "Program has %d %s instructions" % (len(open_nodes), merge_class)
print "Program has %d %s instructions" % (len(open_nodes), merge_classes)
def merge_nodes(self, i, j):
""" Merge node j into i, removing node j """

295
Compiler/circuit_oram.py Normal file
View File

@@ -0,0 +1,295 @@
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.path_oram import *
from Compiler.util import bit_compose
def first_diff(a_bits, b_bits):
length = len(a_bits)
level_bits = [None] * length
not_found = 1
for i in range(length):
# find first position where bits differ (i.e. first 0 in 1 - a XOR b)
t = 1 - XOR(a_bits[i], b_bits[i])
prev_nf = not_found
not_found *= t
level_bits[i] = prev_nf - not_found
return level_bits, not_found
def find_deeper(a, b, path, start, length, compute_level=True):
a_bits = a.value.bit_decompose(length)
b_bits = b.value.bit_decompose(length)
path_bits = [type(a_bits[0])(x) for x in path.bit_decompose(length)]
a_bits.reverse()
b_bits.reverse()
path_bits.reverse()
level_bits = [0] * length
# make sure that winner is set at start if one input is empty
any_empty = OR(a.empty, b.empty)
a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)]
b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)]
diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]]
diff_preor = type(a.value).PreOR([any_empty] + diff)
diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])]
winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty)
winner_bits = [if_else(winner, bd, ad) for ad,bd in zip(a_diff, b_diff)]
winner_preor = type(a.value).PreOR(winner_bits)
level_bits = [x - y for x,y in zip(winner_preor, [0] + winner_preor)]
return [0] * start + level_bits + [1 - sum(level_bits)], winner
def find_deepest(paths, search_path, start, length, compute_level=True):
if len(paths) == 1:
return None, paths[0], 1
l = len(paths) / 2
_, a, a_index = find_deepest(paths[:l], search_path, start, length, False)
_, b, b_index = find_deepest(paths[l:], search_path, start, length, False)
level, winner = find_deeper(a, b, search_path, start, length, compute_level)
return level, if_else(winner, b, a), if_else(winner, b_index << l, a_index)
def ge_unary_public(a, b):
return sum(a[b-1:])
def gu_step(high, low):
greater = high[0] * (1 - high[1])
not_greater = high[1]
return if_else(not_greater, 0, high[0] + low[0]), \
if_else(greater, 0, high[1] + low[1])
def greater_unary(a, b):
if len(a) == 1:
return a[0], b[0]
else:
l = len(a) / 2
return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l]))
def comp_step(high, low):
prod = high[0] * high[1]
greater = high[0] - prod
smaller = high[1] - prod
deferred = 1 - greater - smaller
indicator = greater, smaller, deferred
return sum(map(operator.mul, indicator, (1, 0, low[0]))), \
sum(map(operator.mul, indicator, (0, 1, low[1])))
def comp_binary(a, b):
if len(a) != len(b):
raise CompilerError('Arguments must have same length: %s %s' % (str(a), str(b)))
if len(a) == 1:
return a[0], b[0]
else:
l = len(a) / 2
return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l]))
def unary_to_binary(l):
return sum(x * (i + 1) for i,x in enumerate(l)).bit_decompose(log2(len(l) + 1))
class CircuitORAM(PathORAM):
def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \
stash_size=None, bucket_size=2, init_rounds=-1):
self.bucket_oram = TrivialORAM
self.bucket_size = bucket_size
self.D = log2(size)
self.logD = log2(self.D)
self.L = self.D + 1
print 'create oram of size %d with depth %d and %d buckets' \
% (size, self.D, self.n_buckets())
self.value_type = value_type
self.index_type = value_type.get_type(self.D)
if entry_size is not None:
self.value_length = len(tuplify(entry_size))
self.entry_size = tuplify(entry_size)
else:
self.value_length = value_length
self.entry_size = [None] * value_length
self.index_size = log2(size)
self.size = size
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
index_size=self.index_size)
self.entry_type = empty_entry.types()
self.ram = RAM(self.bucket_size * 2**(self.D+1), self.entry_type, \
self.get_array)
self.buckets = self.ram
if init_rounds != -1:
# put memory initialization in different timer
stop_timer()
start_timer(1)
self.ram.init_mem(self.empty_entry(apply_type=False))
if init_rounds != -1:
stop_timer(1)
start_timer()
self.root = RefBucket(1, self)
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
stash_size = 20
vt, es = self.internal_entry_size()
self.stash = TrivialORAM(stash_size, vt, entry_size=es, \
index_size=self.index_size)
self.t = MemValue(regint(0))
self.state = MemValue(self.value_type.get_type(self.D)(0))
self.read_path = MemValue(value_type.clear_type(0))
# these include a read value from the stash
self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
for l in self.entry_size]
self.read_empty = Array(self.D + 2, self.value_type.bit_type)
def get_ram_index(self, path, level):
clear_type = self.value_type.clear_type
return ((2**(self.D) + clear_type.conv(path)) >> (self.D - (level - 1)))
def get_bucket_ram(self, path, level):
if level == 0:
return self.stash.ram
else:
return RefRAM(self.get_ram_index(path, level), self)
def get_bucket_oram(self, path, level):
if level == 0:
return self.stash
else:
return RefTrivialORAM(self.get_ram_index(path, level), self)
def prepare_deepest(self, path):
deepest = [None] * (self.D + 2)
deepest_index = [None] * (self.D + 2)
src = Value()
stash_empty = self.stash.ram.is_empty()
level, _, index = find_deepest(self.stash.ram.get_value_array(0), path, 0, self.D)
goal = if_else(stash_empty, ValueTuple([0] * len(level)), unary_to_binary(level))
src = if_else(stash_empty, src, Value(0))
src_index = if_else(stash_empty, 0, index)
buckets = [self.get_bucket_ram(path, i) for i in range(self.L + 1)]
bucket_deepest = [(goal, src, src_index, None)]
for i in range(1, self.L):
l, _, index = find_deepest(buckets[i].get_value_array(0), path, i - 1, self.D)
bucket_deepest.append((unary_to_binary(l), Value(i), index, i))
def op(left, right, void=None):
goal, src, src_index, _ = left
l, secret_i, index, i = right
high, low = comp_binary(l, goal)
replace = high * (1 - low) * (1 - buckets[i].is_empty())
goal = if_else(replace, bit_compose(l), \
bit_compose(goal)).bit_decompose(len(goal))
src = if_else(replace, secret_i, src)
src_index = if_else(replace, index, src_index)
return goal, src, src_index, i
preop_bucket_deepest = self.value_type.PreOp(op, bucket_deepest)
for i in range(1, self.L + 1):
goal, src, src_index, _ = preop_bucket_deepest[i-1]
high, low = comp_binary(goal, bit_decompose(i, len(goal)))
cond = 1 - low * (1 - high)
deepest[i] = if_else(cond, src, Value())
deepest_index[i] = if_else(cond, src_index, 0)
return deepest, deepest_index
def prepare_target(self, path, deepest):
deepest, deepest_index = deepest
dest = Value()
src = Value()
src_index = 0
target = [None] * (self.L + 1)
target_index = [None] * (self.L + 1)
for i in range(self.L, -1 , -1):
i_eq_src = src.equal(i, self.logD + 1)
target[i] = if_else(i_eq_src, dest, Value())
target_index[i] = if_else(i_eq_src, src_index, 0)
dest = if_else(i_eq_src, Value(), dest)
src = if_else(i_eq_src, Value(), src)
if i == 0:
break
cond = or_op(dest.empty * self.get_bucket_ram(path, i).has_empty_entry(), \
(1 - target[i].empty)) * (1 - deepest[i].empty)
src = if_else(cond, deepest[i], src)
src_index = if_else(cond, deepest_index[i], src_index)
dest = if_else(cond, Value(i), dest)
return target, target_index
def evict_once(self, path):
deepest = self.prepare_deepest(path)
target = self.prepare_target(path, deepest)
evictor = self.evict_once_fast(path, target)
next(evictor)
towrite = next(evictor)
yield
self.add_evicted(path, towrite)
yield
def evict_once_fast(self, path, target):
target, target_index = target
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
index_size=self.index_size)
hold = empty_entry
dest = Value()
towrite = [None] * (self.L + 1)
for i in range(self.L + 1):
cond = (1 - hold.is_empty) * (dest.equal(i, self.logD + 1))
towrite[i] = if_else(cond, hold, empty_entry)
hold = if_else(cond, empty_entry, hold)
dest = if_else(cond, Value(), dest)
cond = 1 - target[i].empty
bucket = self.get_bucket_oram(path, i)
if i != self.L:
index = target_index[i].bit_decompose(bucket.size)
hold = if_else(cond, bucket.read_and_remove_by_public(index), hold)
dest = if_else(cond, target[i], dest)
if i == 1:
yield
yield towrite
def add_evicted(self, path, towrite):
# make sure to add after removing
for i in range(1, self.L + 1):
self.get_bucket_oram(path, i).add(towrite[i])
def evict_rounds(self):
get_path = lambda x: bit_compose(reversed(x.bit_decompose(self.D)))
paths = [get_path(2 * self.t + i) for i in range(2)]
for path in paths:
for _ in self.evict_once(path):
yield
self.t.iadd(1)
def evict(self):
raise CompilerException('Using this function is likely an error. Use recursive_evict() instead.')
Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-%d' % self.size)
for i,_ in enumerate(self.evict_rounds()):
Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-round-%d-%d' % (i, self.size))
def recursive_evict(self):
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-%d' % self.size)
for i,_ in enumerate(self.recursive_evict_rounds()):
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size))
def recursive_evict_rounds(self):
for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
yield
def bucket_indices_on_path_to(self, leaf):
# root is at 1, different to PathORAM
for level in range(self.D + 1):
base = self.get_ram_index(leaf, level + 1) * self.bucket_size
yield [base + i for i in range(self.bucket_size)]
def output(self):
print_ln('stash')
self.stash.output()
@for_range(1, 2**(self.D+1))
def f(i):
print_ln('node %s', self.value_type.clear_type(i))
RefRAM(i, self).output()
self.index.output()
def __repr__(self):
return repr(self.stash) + '\n' + repr(RefBucket(1, self))
class DebugCircuitORAM(CircuitORAM):
""" Debugging only. Tree ORAM using index revealing the access
pattern. """
index_structure = LocalIndexStructure
threshold = 2**10
def OptimalCircuitORAM(size, value_type, *args, **kwargs):
if size <= threshold:
print size, 'below threshold', threshold
return LinearORAM(size, value_type, *args, **kwargs)
else:
print size, 'above threshold', threshold
return RecursiveCircuitORAM(size, value_type, *args, **kwargs)
class RecursiveCircuitIndexStructure(PackedIndexStructure):
""" Secure index using secure tree ORAM. """
storage = staticmethod(OptimalCircuitORAM)
class RecursiveCircuitORAM(CircuitORAM):
""" Secure tree ORAM using secure index. """
index_structure = RecursiveCircuitIndexStructure
class AtLeastOneRecursionPackedCircuitORAM(PackedIndexStructure):
storage = RecursiveCircuitORAM
class AtLeastOneRecursionPackedCircuitORAMWithEmpty(PackedORAMWithEmpty):
storage = RecursiveCircuitORAM

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
"""
Functions for secure comparison of GF(p) types.

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.program import Program
from Compiler.config import *

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from collections import defaultdict

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.oram import *

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
class CompilerError(Exception):
"""Base class for compiler exceptions."""

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from math import log, floor, ceil
from Compiler.instructions import *

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import heapq
from Compiler.exceptions import *
@@ -19,7 +19,7 @@ class SparseDiGraph(object):
""" max_nodes: maximum no of nodes
default_attributes: dict of node attributes and default values """
if default_attributes is None:
default_attributes = { 'merges': None, 'stop': -1, 'start': -1, 'is_source': True }
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
self.default_attributes = default_attributes
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
self.n = max_nodes

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import sys
import math

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
""" This module is for classes of actual assembly instructions.
@@ -335,6 +335,14 @@ class crash(base.IOInstruction):
code = base.opcodes['CRASH']
arg_format = []
class start_grind(base.IOInstruction):
code = base.opcodes['STARTGRIND']
arg_format = []
class stop_grind(base.IOInstruction):
code = base.opcodes['STOPGRIND']
arg_format = []
@base.gf2n
class use_prep(base.Instruction):
r""" Input usage. """
@@ -1175,6 +1183,12 @@ class divint(base.IntegerInstruction):
__slots__ = []
code = base.opcodes['DIVINT']
@base.vectorize
class bitdecint(base.Instruction):
__slots__ = []
code = base.opcodes['BITDECINT']
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
###
### Clear comparison instructions
###

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import itertools
from random import randint
@@ -54,6 +54,8 @@ opcodes = dict(
JOIN_TAPE = 0x1A,
CRASH = 0x1B,
USE_PREP = 0x1C,
STARTGRIND = 0x1D,
STOPGRIND = 0x1E,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -132,6 +134,7 @@ opcodes = dict(
EQC = 0x97,
JMPI = 0x98,
# Integers
BITDECINT = 0x99,
LDINT = 0x9A,
ADDINT = 0x9B,
SUBINT = 0x9C,
@@ -535,7 +538,11 @@ class Instruction(object):
return ""
def has_var_args(self):
return False
try:
len(self.arg_format)
return False
except:
return True
def is_vec(self):
return False

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat
from Compiler.instructions import *

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import random
import math
@@ -187,9 +187,10 @@ def demux_list(x):
res = map(operator.mul, a, b)
return res
def demux_array(x):
def demux_array(x, res=None):
n = len(x)
res = Array(2**n, type(x[0]))
if res is None:
res = Array(2**n, type(x[0]))
if n == 1:
res[0] = 1 - x[0]
res[1] = x[0]
@@ -282,10 +283,18 @@ class ValueTuple(tuple):
class Entry(object):
""" An (O)RAM entry with empty bit, index, and value. """
@staticmethod
def get_empty(value_type, value_length, apply_type=True):
t = value_type if apply_type else lambda x: x
bt = value_type if apply_type else lambda x: x
return Entry(t(0), tuple(t(0) for i in range(value_length)), bt(True), t)
def get_empty(value_type, entry_size, apply_type=True, index_size=None):
res = {}
for i,tt in enumerate((value_type, value_type.default_type)):
if apply_type:
apply = lambda length, x: value_type.get_type(length)(x)
else:
apply = lambda length, x: x
res[i] = Entry(apply(index_size, 0), \
tuple(apply(l, 0) for l in entry_size), \
apply(1, True), value_type)
res[0].defaults = res[1]
return res[0]
def __init__(self, v, x=None, empty=None, value_type=None):
self.created_non_empty = False
if x is None:
@@ -296,7 +305,7 @@ class Entry(object):
else:
if empty is None:
self.created_non_empty = True
empty = value_type(False)
empty = value_type.bit_type(False)
self.is_empty = empty
self.v = v
if not isinstance(x, (tuple, list)):
@@ -360,14 +369,15 @@ class RefRAM(object):
crash()
self.size = oram.bucket_size
self.entry_type = oram.entry_type
self.l = [t.dynamic_array(self.size, t, array.address + \
index * oram.bucket_size) \
for t,array in zip(self.entry_type,oram.ram.l)]
self.l = [oram.get_array(self.size, t, array.address + \
index * oram.bucket_size) \
for t,array in zip(self.entry_type,oram.ram.l)]
self.index = index
def init_mem(self, empty_entry):
print 'init ram'
for a,value in zip(self.l, empty_entry.values()):
a.assign_all(value)
for a,value in zip(self.l, empty_entry.defaults.values()):
# don't use threads if n_threads explicitly set to 1
a.assign_all(value, n_threads != 1)
def get_empty_bits(self):
return self.l[0]
def get_indices(self):
@@ -401,7 +411,8 @@ class RefRAM(object):
return tree_reduce(operator.mul, list(self.get_empty_bits()))
def reveal(self):
Program.prog.curr_tape.start_new_basicblock()
res = RAM(self.size, [t.clear_type for t in self.entry_type], self.index)
res = RAM(self.size, [t.clear_type for t in self.entry_type], \
lambda *args: Array(*args), self.index)
for i,a in enumerate(self.l):
for j,x in enumerate(a):
res.l[i][j] = x.reveal()
@@ -422,33 +433,40 @@ class RefRAM(object):
class RAM(RefRAM):
""" List of entries in memory. """
def __init__(self, size, entry_type, index=0):
def __init__(self, size, entry_type, get_array, index=0):
#print_reg(cint(0), 'r in')
self.size = size
self.entry_type = entry_type
self.l = [t.dynamic_array(self.size, t) for t in entry_type]
self.l = [get_array(self.size, t) for t in entry_type]
self.index = index
class AbstractORAM(object):
""" Implements reading and writing using read_and_remove and add. """
@staticmethod
def get_array(size, t, *args, **kwargs):
return t.dynamic_array(size, t, *args, **kwargs)
def read(self, index):
return self._read(self.value_type.hard_conv(index))
def write(self, index, value):
new_value = [self.value_type.hard_conv(v) \
for v in (value if isinstance(value, (tuple, list)) \
new_value = [self.value_type.get_type(length).hard_conv(v) \
for length,v in zip(self.entry_size, value \
if isinstance(value, (tuple, list)) \
else (value,))]
return self._write(self.value_type.hard_conv(index), *new_value)
return self._write(self.index_type.hard_conv(index), *new_value)
def access(self, index, new_value, write, new_empty=False):
return self._access(self.value_type.hard_conv(index),
self.value_type.hard_conv(write),
self.value_type.hard_conv(new_empty),
*[self.value_type.hard_conv(v) for v in tuplify(new_value)])
return self._access(self.index_type.hard_conv(index),
self.value_type.bit_type.hard_conv(write),
self.value_type.bit_type.hard_conv(new_empty),
*[self.value_type.get_type(length).hard_conv(v) \
for length,v in zip(self.entry_size, \
tuplify(new_value))])
def read_and_maybe_remove(self, index):
return self.read_and_remove(self.value_type.hard_conv(index)), \
return self.read_and_remove(self.index_type.hard_conv(index)), \
self.state.read()
@method_block
def _read(self, index):
return self.access(index, (self.value_type(0),) * self.value_length, \
return self.access(index, tuple(self.value_type.get_type(l)(0) \
for l in self.entry_size), \
False)
@method_block
def _write(self, index, *value):
@@ -457,7 +475,7 @@ class AbstractORAM(object):
def _access(self, index, write, new_empty, *new_value):
Program.prog.curr_tape.\
start_new_basicblock(name='abstract-access-remove-%d' % self.size)
index = MemValue(self.value_type.hard_conv(index))
index = MemValue(self.index_type.hard_conv(index))
read_value, read_empty = self.read_and_remove(index)
if len(read_value) != self.value_length:
raise Exception('read_and_remove() of %s returns wrong length of ' \
@@ -472,9 +490,10 @@ class AbstractORAM(object):
if len(new_value) != self.value_length:
raise Exception('wrong length of new value')
value = tuple(MemValue(i) for i in if_else(write, new_value, read_value))
empty = self.value_type.hard_conv(new_empty)
empty = self.value_type.bit_type.hard_conv(new_empty)
self.add(Entry(index, value, if_else(write, empty, read_empty), \
value_type=self.value_type))
value_type=self.value_type), evict=False)
self.recursive_evict()
return read_value, read_empty
@method_block
def delete(self, index, for_real=True):
@@ -490,19 +509,25 @@ class AbstractORAM(object):
class EmptyException(Exception):
pass
class RefTrivialORAM(object):
class EndRecursiveEviction(object):
recursive_evict = lambda self: None
recursive_evict_rounds = lambda self: itertools.repeat([None])
class RefTrivialORAM(EndRecursiveEviction):
""" Trivial ORAM reference. """
contiguous = False
def empty_entry(self, apply_type=True):
return Entry.get_empty(self.value_type, self.value_length, apply_type)
return Entry.get_empty(self.value_type, self.entry_size, \
apply_type, self.index_size)
def __init__(self, index, oram):
self.ram = RefRAM(index, oram)
self.index_size = oram.index_size
self.value_type, self.value_length = oram.internal_value_type()
self.value_type, self.entry_size = oram.internal_entry_size()
self.size = oram.bucket_size
def init_mem(self):
print 'init trivial oram'
self.ram.init_mem(self.empty_entry())
self.ram.init_mem(self.empty_entry(apply_type=False))
def search(self, read_index):
if use_binary_search and self.value_type == sgf2n:
return self.binary_search(read_index)
@@ -695,7 +720,7 @@ class RefTrivialORAM(object):
Program.prog.security)
return [prefix_empty[i+1] - prefix_empty[i] \
for i in range(len(self.ram))]
def add(self, new_entry, state=None):
def add(self, new_entry, state=None, evict=None):
# if self.last_index != new_entry.v:
# raise Exception('index mismatch: %s / %s' %
# (str(self.last_index), str(new_entry.v)))
@@ -761,14 +786,17 @@ class TrivialORAM(RefTrivialORAM, AbstractORAM):
entry_size=None, contiguous=True, init_rounds=-1):
self.index_size = index_size or log2(size)
self.value_type = value_type
self.index_type = value_type.get_type(self.index_size)
if entry_size is None:
self.value_length = value_length
self.entry_size = [None] * value_length
else:
self.value_length = len(tuplify(entry_size))
self.entry_size = tuplify(entry_size)
self.contiguous = contiguous
entry_type = self.empty_entry().types()
self.size = size
self.ram = RAM(size, entry_type)
self.ram = RAM(size, entry_type, self.get_array)
if init_rounds != -1:
# put memory initialization in different timer
stop_timer()
@@ -789,9 +817,16 @@ def get_n_threads(n_loops):
class LinearORAM(TrivialORAM):
""" Contiguous ORAM that stores entries in order. """
@staticmethod
def get_array(size, t, *args, **kwargs):
return Array(size, t, *args, **kwargs)
def __init__(self, *args, **kwargs):
TrivialORAM.__init__(self, *args, **kwargs)
self.index_vector = self.get_array(2 ** self.index_size, \
self.index_type.bit_type)
def read_and_maybe_remove(self, index):
return self.read(index), 0
def add(self, entry, state=None):
def add(self, entry, state=None, evict=None):
if entry.created_non_empty is True:
self.write(entry.v, entry.x)
else:
@@ -802,14 +837,14 @@ class LinearORAM(TrivialORAM):
def _read(self, index):
maybe_start_timer(6)
empty_entry = self.empty_entry(False)
index_vector = \
demux_array(bit_decompose(index, self.index_size))
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
self.value_length + 1, [self.value_type] + \
[self.value_type] * self.value_length)
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type.get_type(l) for l in self.entry_size])
def f(i):
entry = self.ram[i]
access_here = index_vector[i]
access_here = self.index_vector[i]
return access_here * ValueTuple((entry.empty(),) + entry.x)
not_found = f()[0]
read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x
@@ -819,13 +854,13 @@ class LinearORAM(TrivialORAM):
def _write(self, index, *new_value):
maybe_start_timer(7)
empty_entry = self.empty_entry(False)
index_vector = \
demux_array(bit_decompose(index, self.index_size))
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
new_value = make_array(new_value)
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
def f(i):
entry = self.ram[i]
access_here = index_vector[i]
access_here = self.index_vector[i]
nv = ValueTuple(new_value)
delta_entry = \
Entry(0, access_here * (nv - entry.x), \
@@ -841,7 +876,7 @@ class LinearORAM(TrivialORAM):
new_empty = MemValue(new_empty)
write = MemValue(write)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
self.value_length + 1, [self.value_type] + \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
def f(i):
entry = self.ram[i]
@@ -892,16 +927,24 @@ class RefBucket(object):
child.output()
def random_block(length, value_type):
return sum(value_type.get_random_bit() << i for i in range(length))
return sum(value_type.bit_type.get_random_bit() << i for i in range(length))
class List(object):
class List(EndRecursiveEviction):
""" Debugging only. List which accepts secret values as indices
and *reveals* them. """
def __init__(self, size, value_type, value_length=1, init_rounds=None):
def __init__(self, size, value_type, value_length=1, \
init_rounds=None, entry_size=None):
self.value_type = value_type
self.index_type = value_type.get_type(log2(size))
self.value_length = value_length
self.l = [value_type.dynamic_array(size, value_type) \
if entry_size is None:
self.l = [value_type.dynamic_array(size, value_type) \
for i in range(value_length)]
else:
self.l = [value_type.dynamic_array(size, \
value_type.get_type(length)) \
for length in entry_size]
self.value_length = len(entry_size)
for l in self.l:
l.assign_all(0)
__getitem__ = lambda self,index: [self.l[i][regint(reveal(index))] \
@@ -916,8 +959,9 @@ class List(object):
read_and_remove = lambda self,i: (self[i], None)
def read_and_maybe_remove(self, *args, **kwargs):
return self.read_and_remove(*args, **kwargs), 0
add = lambda self,entry: self.__setitem__(entry.v.read(), \
[v.read() for v in entry.x])
add = lambda self,entry,**kwargs: self.__setitem__(entry.v.read(), \
[v.read() for v in entry.x])
recursive_evict = lambda *args,**kwargs: None
def batch_init(self, values):
for i,value in enumerate(values):
index = self.value_type.hard_conv(i)
@@ -939,7 +983,7 @@ class LocalIndexStructure(List):
def f(i):
self.l[0][i] = random_block(entry_size, value_type)
print 'index size:', size
def update(self, index, value):
def update(self, index, value, evict=None):
read_value = self[index]
#print 'read', index, read_value
#print self.l
@@ -977,13 +1021,18 @@ class TreeORAM(AbstractORAM):
self.value_type = value_type
if entry_size is not None:
self.value_length = len(tuplify(entry_size))
self.entry_size = tuplify(entry_size)
else:
self.value_length = value_length
self.entry_size = [None] * value_length
self.index_size = log2(size)
self.index_type = value_type.get_type(self.index_size)
self.size = size
empty_entry = Entry.get_empty(*self.internal_value_type())
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
index_size=self.D)
self.entry_type = empty_entry.types()
self.ram = RAM(self.n_buckets() * self.bucket_size, self.entry_type)
self.ram = RAM(self.n_buckets() * self.bucket_size, self.entry_type, \
self.get_array)
if init_rounds != -1:
# put memory initialization in different timer
stop_timer()
@@ -996,7 +1045,7 @@ class TreeORAM(AbstractORAM):
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
self.read_value = Array(self.value_length, value_type)
self.read_non_empty = MemValue(self.value_type(0))
self.read_non_empty = MemValue(self.value_type.bit_type(0))
self.state = MemValue(self.value_type(0))
@method_block
def add_to_root(self, state, is_empty, v, *x):
@@ -1011,7 +1060,7 @@ class TreeORAM(AbstractORAM):
entry = bucket.bucket.pop()
#print 'evict', entry
#print 'from', bucket
b = if_else(entry.empty(), self.value_type.get_random_bit(), \
b = if_else(entry.empty(), self.value_type.bit_type.get_random_bit(), \
get_bit(entry.x[0], self.D - 1 - d, self.D))
block = cond_swap(b, entry, self.root.bucket.empty_entry())
#print 'empty', entry.empty()
@@ -1042,7 +1091,7 @@ class TreeORAM(AbstractORAM):
new_path = regint.get_random(self.D)
l_star = self.value_type(new_path)
self.state.write(l_star)
return self.index.update(u, l_star).reveal()
return self.index.update(u, l_star, evict=False).reveal()
@method_block
def read_and_remove_levels(self, u, read_path):
u = MemValue(u)
@@ -1050,7 +1099,7 @@ class TreeORAM(AbstractORAM):
levels = self.D + 1
parallel = get_parallel(self.index_size, *self.internal_value_type())
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
self.value_length + 1, [self.value_type] + \
self.value_length + 1, [self.value_type.bit_type] + \
[self.value_type] * self.value_length)
def process(level):
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
@@ -1074,6 +1123,8 @@ class TreeORAM(AbstractORAM):
crash()
def internal_value_type(self):
return self.value_type, self.value_length + 1
def internal_entry_size(self):
return self.value_type, [self.D] + list(self.entry_size)
def n_buckets(self):
return 2**(self.D+1)
@method_block
@@ -1097,7 +1148,7 @@ class TreeORAM(AbstractORAM):
Program.prog.curr_tape.\
start_new_basicblock(name='read_and_remove-%d-end' % self.size)
return [MemValue(v) for v in read_value], MemValue(read_empty)
def add(self, entry, state=None):
def add(self, entry, state=None, evict=None):
if state is None:
state = self.state.read()
#print_reg(cint(0), 'add')
@@ -1141,6 +1192,9 @@ class TreeORAM(AbstractORAM):
#print 'd, 2^d', d, 1 << d
self.evict2(s1 + (1 << d), s2 + (1 << d), d)
self.check()
def recursive_evict(self):
self.evict()
self.index.recursive_evict()
def batch_init(self, values):
""" Batch initalization. Obliviously shuffles and adds N entries to
@@ -1349,11 +1403,11 @@ class PackedIndexStructure(object):
self.entry_size = tuplify(entry_size)
self.value_type = value_type
for demux_bits in range(max_demux_bits + 1):
self.log_entries_per_element = min(size, \
self.log_entries_per_element = min(log2(size), \
int(math.floor(math.log(float(get_value_size(value_type)) / \
sum(self.entry_size), 2))))
self.log_elements_per_block = \
max(0, min(demux_bits, log2(size * sum(self.entry_size)) - \
max(0, min(demux_bits, log2(size) - \
self.log_entries_per_element))
if self.log_entries_per_element < 0:
self.entries_per_block = 1
@@ -1388,14 +1442,17 @@ class PackedIndexStructure(object):
print 'log(elements per block):', self.log_elements_per_block
print 'elements per block:', self.elements_per_block
print 'used bits:', self.used_bits
entry_size = [self.used_bits] * self.elements_per_block
if real_size > 1:
# no need to init underlying ORAM, will be initialized implicitely
self.l = self.storage(real_size, value_type, self.elements_per_block, \
init_rounds=0)
self.l = self.storage(real_size, value_type, \
entry_size=entry_size, init_rounds=0)
self.small = False
else:
self.l = List(1, value_type, self.elements_per_block)
self.l = List(1, value_type, self.elements_per_block, \
entry_size=entry_size)
self.small = True
self.index_type = self.l.index_type
if init_rounds:
if init_rounds > 0:
real_init_rounds = init_rounds * real_size / size
@@ -1484,20 +1541,20 @@ class PackedIndexStructure(object):
return self.MultiSlicer(self, index)
else:
return self.Slicer(self, index)
def update(self, index, value):
def update(self, index, value, evict=True):
""" Updating index return current value. Has to be done in one
step to avoid exponential blow-up in ORAM recursion. """
return self.access(index, value, True)
def access(self, index, value, write):
return self.access(index, value, True, evict=evict)
def access(self, index, value, write, evict=True):
slicer = self.get_slicer(index)
block = self.l.read_and_maybe_remove(slicer.a)[0][0]
read_value = slicer.read(block)
value = if_else(write, ValueTuple(tuplify(value)), \
ValueTuple(read_value))
self.l.add(Entry(MemValue(self.value_type(slicer.a)), \
self.l.add(Entry(MemValue(self.l.index_type(slicer.a)), \
ValueTuple(MemValue(v) \
for v in slicer.write(value)), \
value_type=self.value_type))
value_type=self.value_type), evict=evict)
return untuplify(read_value)
def __getitem__(self, index):
slicer = self.get_slicer(index)
@@ -1507,7 +1564,9 @@ class PackedIndexStructure(object):
# no need for reading first
self.l[index] = self.get_slicer(index).write(value)
else:
self.access(index, value, True)
self.access(index, value, True, False)
self.l.recursive_evict()
recursive_evict = lambda self: self.l.recursive_evict()
def batch_init(self, values):
""" Initialize m values with indices 0, ..., m-1 """
@@ -1551,8 +1610,8 @@ class PackedORAMWithEmpty(AbstractORAM, PackedIndexStructure):
return res[1:], 1 - res[0]
def read_and_maybe_remove(self, index):
return self.read(index), 0
def add(self, entry, state=None):
self.access(entry.v, entry.x, True, entry.empty())
def add(self, entry, state=None, evict=True):
self.access(entry.v, entry.x, True, entry.empty(), evict=evict)
class LocalPackedIndexStructure(PackedIndexStructure):
""" Debugging only. Packed tree ORAM index revealing the access
@@ -1621,30 +1680,39 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
storage = OptimalORAM
def test_oram(oram_type, N, value_type=sint, iterations=100):
stop_grind()
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
value_type = value_type.get_type(32)
index_type = value_type.get_type(log2(N))
start_grind()
print 'initialized'
print_ln('initialized')
stop_timer()
# synchronize
start_timer(2)
Program.prog.curr_tape.start_new_basicblock(name='sync')
value_type(0).reveal()
Program.prog.curr_tape.start_new_basicblock(name='sync')
stop_timer(2)
start_timer()
#oram[value_type(0)] = -1
#iterations = N
@for_range(iterations)
def f(i):
oram[value_type(i % N)] = value_type(i % N)
time()
oram[index_type(i % N)] = value_type(i % N)
#value, empty = oram.read_and_remove(value_type(i))
#print 'first write'
time()
oram[value_type(i % N)].reveal().print_reg('writ')
oram[index_type(i % N)].reveal().print_reg('writ')
#print 'first read'
@for_range(iterations)
def f(i):
x = oram[value_type(i % N)]
time()
x = oram[index_type(i % N)]
x.reveal().print_reg('read')
# print 'second read'
print_ln('%s accesses', 3 * iterations)
return oram
def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100):

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
if '_Array' not in dir():
from oram import *
@@ -147,13 +147,17 @@ class PathORAM(TreeORAM):
self.value_type = value_type
if entry_size is not None:
self.value_length = len(tuplify(entry_size))
self.entry_size = tuplify(entry_size)
else:
self.value_length = value_length
self.entry_size = [None] * value_length
self.index_size = log2(size)
self.index_type = value_type.get_type(self.index_size)
self.size = size
self.entry_type = Entry.get_empty(*self.internal_value_type()).types()
self.entry_type = Entry.get_empty(*self.internal_entry_size()).types()
self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type)
self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type,
self.get_array)
if init_rounds != -1:
# put memory initialization in different timer
stop_timer()
@@ -197,15 +201,15 @@ class PathORAM(TreeORAM):
# temp storage for the path + stash in eviction
self.temp_size = stash_size + self.bucket_size*(self.D+1)
self.temp_storage = RAM(self.temp_size, self.entry_type)
self.temp_storage = RAM(self.temp_size, self.entry_type, self.get_array)
self.temp_levels = [0] * self.temp_size # Array(self.temp_size, 'c')
for i in range(self.temp_size):
self.temp_levels[i] = 0
# these include a read value from the stash
self.read_value = [Array(self.D + 2, self.value_type)
for i in range(self.value_length)]
self.read_empty = Array(self.D + 2, self.value_type)
self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
for l in self.entry_size]
self.read_empty = Array(self.D + 2, self.value_type.bit_type)
self.state = MemValue(self.value_type(0))
self.eviction_count = MemValue(cint(0))
@@ -242,7 +246,8 @@ class PathORAM(TreeORAM):
for j, ram_index in enumerate(ram_indices):
self.temp_storage[i*self.bucket_size + j] = self.buckets[ram_index]
self.temp_levels[i*self.bucket_size + j] = i
self.buckets[ram_index] = Entry.get_empty(self.value_type, 1)
ies = self.internal_entry_size()
self.buckets[ram_index] = Entry.get_empty(*ies)
# load the stash
for i in range(len(self.stash.ram)):
@@ -253,7 +258,7 @@ class PathORAM(TreeORAM):
entry = self.stash.ram[i]
self.temp_storage[i + self.bucket_size*(self.D+1)] = entry
te = Entry.get_empty(self.value_type, 1)
te = Entry.get_empty(*self.internal_entry_size())
self.stash.ram[i] = te
self.path_regs = [None] * self.bucket_size*(self.D+1)
@@ -268,11 +273,11 @@ class PathORAM(TreeORAM):
#self.sizes = [Counter(0, max_val=4) for i in range(self.D + 1)]
if self.use_shuffle_evict:
if self.bucket_size == 4:
self.size_bits = [[self.value_type(i) for i in (0, 0, 0, 1)] for j in range(self.D+1)]
self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0, 0, 1)] for j in range(self.D+1)]
elif self.bucket_size == 2 or self.bucket_size == 3:
self.size_bits = [[self.value_type(i) for i in (0, 0)] for j in range(self.D+1)]
self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0)] for j in range(self.D+1)]
else:
self.size_bits = [[self.value_type(0) for i in range(self.bucket_size)] for j in range(self.D+1)]
self.size_bits = [[self.value_type.bit_type(0) for i in range(self.bucket_size)] for j in range(self.D+1)]
self.stash_size = Counter(0, max_val=len(self.stash.ram))
leaf = self.state.read().reveal()
@@ -308,7 +313,7 @@ class PathORAM(TreeORAM):
empty_entry = self.empty_entry(False)
skip = 1
found = Array(self.bucket_size, self.value_type)
found = Array(self.bucket_size, self.value_type.bit_type)
entries = [self.buckets[j] for j in ram_indices]
indices = [e.v for e in entries]
empty_bits = [e.empty() for e in entries]
@@ -341,8 +346,8 @@ class PathORAM(TreeORAM):
self.read_empty[self.D+1] = empty
def empty_entry(self, apply_type=True):
vtype, vlength = self.internal_value_type()
return Entry.get_empty(vtype, vlength, apply_type)
vtype, entry_size = self.internal_entry_size()
return Entry.get_empty(vtype, entry_size, apply_type, self.index_size)
def shuffle_evict(self, leaf):
""" Evict using oblivious shuffling etc """
@@ -409,25 +414,25 @@ class PathORAM(TreeORAM):
if self.bucket_size == 4:
c = s[0]*s[1]
if self.value_type == sgf2n:
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] + s[2] + c), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1] + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type(c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][3] = [1 - self.value_type(s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] + s[2] + c), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
else:
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] - c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1] + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type(c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][3] = [1 - self.value_type(s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
elif self.bucket_size == 2:
if evict_debug:
print_str('%s,%s,', s[0].reveal(), s[1].reveal())
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
elif self.bucket_size == 3:
c = s[0]*s[1]
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] - c), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type(c), self.value_type.clear_type(j)]
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c), self.value_type.clear_type(j)]
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c), self.value_type.clear_type(j)]
if evict_debug:
print_ln()
@@ -471,7 +476,7 @@ class PathORAM(TreeORAM):
merged_entries = [e for e in merged_entries if e is not None]
# need to copy entries/levels to memory for re-positioning
entries_ram = RAM(self.temp_size, self.entry_type)
entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)
levels_array = Array(self.temp_size, cint)
for i,entrylev in enumerate(merged_entries):
@@ -570,10 +575,10 @@ class PathORAM(TreeORAM):
def adjust_lca(self, lca_bits, lev, not_empty, prnt=False):
""" Adjust LCA based on bucket capacities (and original clear level, lev) """
found = self.value_type(0)
assigned = self.value_type(0)
try_add_here = self.value_type(0)
new_lca = [self.value_type(0)] * (self.D + 1)
found = self.value_type.bit_type(0)
assigned = self.value_type.bit_type(0)
try_add_here = self.value_type.bit_type(0)
new_lca = [self.value_type.bit_type(0)] * (self.D + 1)
upper = min(lev + self.sigma, self.D)
lower = max(lev - self.tau, 0)
@@ -639,7 +644,7 @@ class PathORAM(TreeORAM):
a_bits = bit_decompose(a, self.D)
b_bits = bit_decompose(b, self.D)
found = [None] * self.D
not_found = self.value_type(not_empty) #1
not_found = self.value_type.bit_type(not_empty) #1
if limit is None:
limit = self.D
@@ -722,13 +727,13 @@ class PathORAM(TreeORAM):
return levstar, a
def add(self, entry, state=None):
def add(self, entry, state=None, evict=True):
if state is None:
state = self.state.read()
l = state
x = tuple((self.value_type(i.read())) for i in entry.x)
x = tuple(i.read() for i in entry.x)
e = Entry(self.value_type(entry.v.read()), (l,) + x, entry.empty())
e = Entry(entry.v.read(), (l,) + x, entry.empty())
#self.temp_storage[self.temp_size-1] = e * 1
#self.temp_levels[self.temp_size-1] = 0
@@ -738,7 +743,8 @@ class PathORAM(TreeORAM):
except Exception:
print self
raise
self.evict()
if evict:
self.evict()
class LocalPathORAM(PathORAM):
""" Debugging only. Path ORAM using index revealing the access

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from random import randint
import math
@@ -361,7 +361,7 @@ def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
if config is None:
config = configure_waksman(random_perm(n))
for i,c in enumerate(config):
config[i] = [value_type(b) for b in c]
config[i] = [value_type.bit_type(b) for b in c]
waksman(x, config, reverse=reverse)
waksman(x, config, reverse=reverse)

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.config import *
from Compiler.exceptions import *
@@ -66,6 +66,8 @@ class Program(object):
self.free_threads = set()
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
self.types = {}
self.to_merge = [Compiler.instructions.startopen_class]
self.stop_class = Compiler.instructions.stopopen_class
Program.prog = self
self.reset_values()
@@ -125,6 +127,7 @@ class Program(object):
self.name = progname
if len(args) > 1:
self.name += '-' + '-'.join(args[1:])
self.progname = progname
def new_tape(self, function, args=[], name=None):
if name is None:
@@ -534,7 +537,9 @@ class Tape:
(block.name, i, len(self.basicblocks), \
len(block.instructions))
# the next call is necessary for allocation later even without merging
merger = al.Merger(block, options)
merger = al.Merger(block, options, \
tuple(self.program.to_merge), \
self.program.stop_class)
if options.dead_code_elimination:
if len(block.instructions) > 10000:
print 'Eliminate dead code...'
@@ -545,8 +550,8 @@ class Tape:
block.defined_registers = set()
continue
if len(block.instructions) > 10000:
print 'Merging open instructions...'
numrounds = merger.longest_paths_merge()
print 'Merging instructions...'
numrounds = merger.longest_paths_merge(self.program.stop_class != type(None))
if numrounds > 0:
print 'Program requires %d rounds of communication' % numrounds
numinv = sum(len(i.args) for i in block.instructions if isinstance(i, Compiler.instructions.startopen_class))
@@ -555,7 +560,9 @@ class Tape:
if options.dead_code_elimination:
block.instructions = filter(lambda x: x is not None, block.instructions)
if not (options.merge_opens and self.merge_opens):
print 'Not merging open instructions in tape %s' % self.name
print 'Not merging instructions in tape %s' % self.name
else:
print 'Rounds determined by', self.program.to_merge
# add jumps
offset = 0

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import itertools
@@ -7,3 +7,9 @@ class chain(object):
self.args = args
def __iter__(self):
return itertools.chain(*self.args)
class cycle(object):
def __init__(self, *args):
self.args = args
def __iter__(self):
return itertools.cycle(*self.args)

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
from Compiler.program import Tape
from Compiler.exceptions import *
@@ -791,13 +791,8 @@ class regint(_register, _int):
return cint(self).mod2m(*args, **kwargs)
def bit_decompose(self, bit_length=None):
res = []
x = self
two = regint(2)
for i in range(bit_length or program.bit_length):
y = x / two
res.append(x - two * y)
x = y
res = [regint() for i in range(bit_length or program.bit_length)]
bitdecint(self, *res)
return res
@staticmethod
@@ -819,6 +814,9 @@ class regint(_register, _int):
class _secret(_register):
__slots__ = []
PreOR = staticmethod(lambda l: floatingpoint.PreORC(l))
PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l))
@vectorized_classmethod
@set_instruction_type
def protect_memory(cls, start, end):
@@ -971,6 +969,9 @@ class sint(_secret, _int):
clear_type = cint
reg_type = 's'
PreOp = staticmethod(floatingpoint.PreOpL)
PreOR = staticmethod(floatingpoint.PreOR)
@vectorized_classmethod
def get_random_int(cls, bits):
res = sint()
@@ -1164,6 +1165,10 @@ class sgf2n(_secret, _gf2n):
clear_type = cgf2n
reg_type = 'sg'
@classmethod
def get_type(cls, length):
return cls
@classmethod
def get_raw_input_from(cls, player):
res = cls()
@@ -1265,8 +1270,10 @@ class sgf2n(_secret, _gf2n):
masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal()
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
sint.basic_type = sint
sgf2n.basic_type = sgf2n
for t in (sint, sgf2n):
t.bit_type = t
t.basic_type = t
t.default_type = t
class sgf2nint(sgf2n):
@@ -2305,13 +2312,13 @@ class Array(object):
self[i] = value[source_index]
source_index.iadd(1)
return
self._store(self.value_type.conv(value), self.get_address(index))
self._store(value, self.get_address(index))
def _load(self, address):
return self.value_type.load_mem(address)
def _store(self, value, address):
value.store_in_mem(address)
self.value_type.conv(value).store_in_mem(address)
def __len__(self):
return self.length
@@ -2335,10 +2342,10 @@ class Array(object):
self[i] = j
return self
def assign_all(self, value):
mem_value = self.value_type.MemValue(value)
n_loops = 8 if len(self) > 2**20 else 1
@library.for_range_multithread(n_loops, 1024, len(self))
def assign_all(self, value, use_threads=True):
mem_value = MemValue(value)
n_threads = 8 if use_threads and len(self) > 2**20 else 1
@library.for_range_multithread(n_threads, 1024, len(self))
def f(i):
self[i] = mem_value
return self

View File

@@ -1,4 +1,4 @@
# (C) 2018 University of Bristol. See License.txt
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
import math
import operator

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _Exceptions
#define _Exceptions

View File

@@ -1,4 +1,4 @@
(C) 2018 University of Bristol. See License.txt.
(C) 2018 University of Bristol, Bar-Ilan University. See License.txt.
The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).

View File

@@ -1,5 +1,5 @@
/*
* (C) 2018 University of Bristol. See License.txt
* (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
*
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.

View File

@@ -1,5 +1,5 @@
/*
* (C) 2018 University of Bristol. See License.txt
* (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
*
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* AddableVector.h

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "Ciphertext.h"
#include "Exceptions/Exceptions.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _Ciphertext
#define _Ciphertext

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "DiscreteGauss.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _DiscreteGauss
#define _DiscreteGauss

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE/FFT.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _FFT
#define _FFT

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE/FFT_Data.h"
#include "FHE/FFT.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _FFT_Data
#define _FFT_Data

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE_Keys.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _FHE_Keys
#define _FHE_Keys

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE_Params.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _FHE_Params
#define _FHE_Params

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* Generator.h

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE/Matrix.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _myHNF

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#include "FHE/NTL-Subs.h"

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
#ifndef _NTL_Subs
#define _NTL_Subs

View File

@@ -1,4 +1,4 @@
// (C) 2018 University of Bristol. See License.txt
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
/*
* NoiseBound.cpp

Some files were not shown because too many files have changed in this diff Show More