mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
ORAM in SPDZ-BMR.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -44,6 +44,7 @@ callgrind.out.*
|
||||
|
||||
# Compiled source #
|
||||
###################
|
||||
Programs/Source/*
|
||||
Programs/Bytecode/*
|
||||
Programs/Schedules/*
|
||||
Programs/Public-Input/*
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "Auth/MAC_Check.h"
|
||||
@@ -366,16 +366,76 @@ void Direct_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S
|
||||
this->CheckIfNeeded(P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Passing_MAC_Check<T>::Passing_MAC_Check(const T& ai, Names& Nms, int num) :
|
||||
Separate_MAC_Check<T>(ai, Nms, num)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T, int t>
|
||||
void passing_add_openings(vector<T>& values, octetStream& os)
|
||||
{
|
||||
octetStream new_os;
|
||||
for (unsigned int i=0; i<values.size(); i++)
|
||||
{
|
||||
T tmp;
|
||||
tmp.unpack(os);
|
||||
(tmp + values[i]).pack(new_os);
|
||||
}
|
||||
os = new_os;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Passing_MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
values.resize(S.size());
|
||||
this->os.reset_write_head();
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
{
|
||||
S[i].get_share().pack(this->os);
|
||||
values[i] = S[i].get_share();
|
||||
}
|
||||
this->AddToMacs(S);
|
||||
|
||||
for (int i = 0; i < P.num_players() - 1; i++)
|
||||
{
|
||||
P.pass_around(this->os);
|
||||
if (T::t() == 2)
|
||||
passing_add_openings<T,2>(values, this->os);
|
||||
else
|
||||
passing_add_openings<T,0>(values, this->os);
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < values.size(); i++)
|
||||
{
|
||||
T tmp;
|
||||
tmp.unpack(this->os);
|
||||
this->vals.push_back(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Passing_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
(void)S;
|
||||
this->GetValues(values);
|
||||
this->popen_cnt += values.size();
|
||||
this->CheckIfNeeded(P);
|
||||
}
|
||||
|
||||
template class MAC_Check<gfp>;
|
||||
template class Direct_MAC_Check<gfp>;
|
||||
template class Parallel_MAC_Check<gfp>;
|
||||
template class Passing_MAC_Check<gfp>;
|
||||
|
||||
template class MAC_Check<gf2n>;
|
||||
template class Direct_MAC_Check<gf2n>;
|
||||
template class Parallel_MAC_Check<gf2n>;
|
||||
template class Passing_MAC_Check<gf2n>;
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template class MAC_Check<gf2n_short>;
|
||||
template class Direct_MAC_Check<gf2n_short>;
|
||||
template class Parallel_MAC_Check<gf2n_short>;
|
||||
template class Passing_MAC_Check<gf2n_short>;
|
||||
#endif
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _MAC_Check
|
||||
#define _MAC_Check
|
||||
@@ -161,6 +161,16 @@ public:
|
||||
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class Passing_MAC_Check : public Separate_MAC_Check<T>
|
||||
{
|
||||
public:
|
||||
Passing_MAC_Check(const T& ai, Names& Nms, int thread_num);
|
||||
|
||||
void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
};
|
||||
|
||||
|
||||
enum mc_timer { SEND, RECV_ADD, BCAST, RECV_SUM, SEED, COMMIT, WAIT_SUMMER, RECV, SUM, SELECT, MAX_TIMER };
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "Auth/Subroutines.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _Subroutines
|
||||
#define _Subroutines
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Summer.cpp
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Summer.h
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#ifndef _fake_stuff
|
||||
|
||||
38
BMR/AndJob.cpp
Normal file
38
BMR/AndJob.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* AndJob.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "AndJob.h"
|
||||
#include "Party.h"
|
||||
#include "Register_inline.h"
|
||||
|
||||
int AndJob::run()
|
||||
{
|
||||
#ifdef DEBUG_AND_JOB
|
||||
printf("thread %d: run and job from %d to %d with %d gates\n",
|
||||
pthread_self(), start, end, gates.size());
|
||||
#endif
|
||||
__m128i prf_output[PAD_TO_8(MAX_N_PARTIES)];
|
||||
auto gate = gates.begin();
|
||||
vector< GC::Secret<EvalRegister> >& S = *this->S;
|
||||
const vector<int>& args = *this->args;
|
||||
int i_gate = 0;
|
||||
for (size_t i = start; i < end; i += 4)
|
||||
{
|
||||
GC::Secret<EvalRegister>& dest = S[args[i + 1]];
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
{
|
||||
i_gate++;
|
||||
gate->init_inputs(gate_id + i_gate,
|
||||
ProgramParty::s().get_n_parties());
|
||||
dest.get_reg(j).eval(S[args[i + 2]].get_reg(j),
|
||||
S[args[i + 3]].get_reg(0), *gate,
|
||||
ProgramParty::s().get_id(), (char*) prf_output, 0, 0, 0);
|
||||
gate++;
|
||||
}
|
||||
}
|
||||
return i_gate;
|
||||
}
|
||||
44
BMR/AndJob.h
Normal file
44
BMR/AndJob.h
Normal file
@@ -0,0 +1,44 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* AndJob.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef BMR_ANDJOB_H_
|
||||
#define BMR_ANDJOB_H_
|
||||
|
||||
#include "GC/Secret.h"
|
||||
#include "Register.h"
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
class AndJob
|
||||
{
|
||||
vector< GC::Secret<EvalRegister> >* S;
|
||||
const vector<int>* args;
|
||||
|
||||
public:
|
||||
vector<GarbledGate> gates;
|
||||
size_t start, end;
|
||||
gate_id_t gate_id;
|
||||
|
||||
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
|
||||
|
||||
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
|
||||
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
|
||||
{
|
||||
this->S = &S;
|
||||
this->args = &args;
|
||||
this->start = start;
|
||||
this->end = start;
|
||||
this->gate_id = gate_id;
|
||||
if (gates.size() < n_gates)
|
||||
gates.resize(n_gates, {n_parties});
|
||||
}
|
||||
|
||||
int run();
|
||||
};
|
||||
|
||||
#endif /* BMR_ANDJOB_H_ */
|
||||
@@ -1,14 +1,16 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "BooleanCircuit.h"
|
||||
|
||||
#include "prf.h"
|
||||
|
||||
static void throw_party_exists(pid_t pid, unsigned int pos) {
|
||||
std::cout << "ERROR: in circuit description" << std::endl
|
||||
<< "\tPosition: " << pos << std::endl
|
||||
<< "\tPlayer id " << pid << " already exists" << std::endl;
|
||||
throw std::invalid_argument( "player id error" );
|
||||
}
|
||||
//static void throw_party_exists(pid_t pid, unsigned int pos) {
|
||||
// std::cout << "ERROR: in circuit description" << std::endl
|
||||
// << "\tPosition: " << pos << std::endl
|
||||
// << "\tPlayer id " << pid << " already exists" << std::endl;
|
||||
// throw std::invalid_argument( "player id error" );
|
||||
//}
|
||||
|
||||
static void throw_bad_circuit_file() {
|
||||
std::cout << "ERROR: could not read circuit file" << std::endl;
|
||||
@@ -27,7 +29,7 @@ void BooleanCircuit::_parse_circuit(const char* desc_file)
|
||||
unsigned int total_input_wires;
|
||||
circuit_file >> total_input_wires;
|
||||
|
||||
for (int idx_party = 1; idx_party <= _num_parties; idx_party++) {
|
||||
for (size_t idx_party = 1; idx_party <= _num_parties; idx_party++) {
|
||||
unsigned int num_input_wires = total_input_wires/_num_parties;
|
||||
if (idx_party == _num_parties) {
|
||||
num_input_wires = total_input_wires-(total_input_wires/_num_parties)*(_num_parties-1);
|
||||
@@ -53,7 +55,7 @@ void BooleanCircuit::_parse_circuit(const char* desc_file)
|
||||
/* Parse gates */
|
||||
for (gate_id_t gate_id = 1; gate_id <= _num_gates; gate_id++)
|
||||
{
|
||||
int fan_in, fan_out, left, right, out;
|
||||
size_t fan_in, fan_out, left, right, out;
|
||||
std::string func;
|
||||
|
||||
circuit_file >> fan_in >> fan_out >> left >> right >> out >> func;
|
||||
@@ -163,11 +165,11 @@ int BooleanCircuit::__make_layers(gate_id_t g)
|
||||
return 0;
|
||||
}
|
||||
int layer_left, layer_right;
|
||||
if(_gates[g]._left == NULL)
|
||||
if(_gates[g]._left == 0)
|
||||
layer_left = -1;
|
||||
else
|
||||
layer_left = __make_layers(_wires[_gates[g]._left]._out_from);
|
||||
if(_gates[g]._right == NULL)
|
||||
if(_gates[g]._right == 0)
|
||||
layer_right = -1;
|
||||
else
|
||||
layer_right = __make_layers(_wires[_gates[g]._right]._out_from);
|
||||
@@ -183,7 +185,7 @@ int BooleanCircuit::__make_layers(gate_id_t g)
|
||||
void BooleanCircuit::_add_to_layer(int layer, gate_id_t g)
|
||||
{
|
||||
// printf("adding %u to layer %d\n", g, layer);
|
||||
if (_layers.size()<layer+1) {
|
||||
if (_layers.size()<(size_t)layer+1) {
|
||||
// printf("layer doesn't exist, creating it\n");
|
||||
_layers.resize(layer+1);
|
||||
}
|
||||
@@ -192,11 +194,11 @@ void BooleanCircuit::_add_to_layer(int layer, gate_id_t g)
|
||||
|
||||
void BooleanCircuit::_print_layers()
|
||||
{
|
||||
printf("num layers = %d\n",_layers.size());
|
||||
for(int i=0; i<_layers.size(); i++) {
|
||||
printf ("\nlayer %d size=%d\n",i,_layers[i].size());
|
||||
for(int j=0; j<_layers[i].size(); j++) {
|
||||
printf("%u ", _layers[i][j]);
|
||||
printf("num layers = %lu\n",_layers.size());
|
||||
for(size_t i=0; i<_layers.size(); i++) {
|
||||
printf ("\nlayer %lu size=%lu\n",i,_layers[i].size());
|
||||
for(size_t j=0; j<_layers[i].size(); j++) {
|
||||
printf("%lu ", _layers[i][j]);
|
||||
}
|
||||
printf ("\n");
|
||||
}
|
||||
@@ -207,11 +209,11 @@ void BooleanCircuit::_validate_layers()
|
||||
_max_layer_sz = 0;
|
||||
for (gate_id_t g=1; g<_num_gates; g++){
|
||||
if(_gates[g]._layer == NO_LAYER)
|
||||
printf("gate %d has no layer!\n",g);
|
||||
printf("gate %lu has no layer!\n",g);
|
||||
assert(_gates[g]._layer != NO_LAYER);
|
||||
}
|
||||
gate_id_t sum_gates_in_layers = 0;
|
||||
for(int i=0; i<_layers.size(); i++) {
|
||||
for(size_t i=0; i<_layers.size(); i++) {
|
||||
sum_gates_in_layers += _layers[i].size();
|
||||
if(_layers[i].size() > _max_layer_sz) {
|
||||
_max_layer_sz = _layers[i].size();
|
||||
@@ -245,7 +247,7 @@ void BooleanCircuit::Inputs(const char* inputs_file_path)
|
||||
}
|
||||
|
||||
BooleanCircuit::BooleanCircuit(const char* desc_file)
|
||||
:_num_evaluated_out_wires(0),_num_input_wires(0), _keys(NULL), _masks(NULL), _prf_inputs(NULL)
|
||||
:_num_evaluated_out_wires(0)
|
||||
{
|
||||
_parse_circuit(desc_file);
|
||||
_make_layers();
|
||||
@@ -259,18 +261,17 @@ void BooleanCircuit::EvaluateByLayerLinearly(party_id_t my_id) {
|
||||
mpz_t temp_mpz;
|
||||
init_temp_mpz_t(temp_mpz);
|
||||
#endif
|
||||
for(int i=0; i<_layers.size(); i++) {
|
||||
for (int j=0; j<_layers[i].size(); j++) {
|
||||
for(size_t i=0; i<_layers.size(); i++) {
|
||||
for (size_t j=0; j<_layers[i].size(); j++) {
|
||||
gate_id_t gid = _layers[i][j];
|
||||
#ifdef __PURE_SHE__
|
||||
signal_t s = _eval_gate(gid, my_id, prf_output, temp_mpz);
|
||||
_eval_gate(gid, my_id, prf_output, temp_mpz);
|
||||
#else
|
||||
signal_t s = _eval_gate(gid, my_id, prf_output);
|
||||
_eval_gate(gid, my_id, prf_output);
|
||||
#endif
|
||||
// printf("%d ", s);
|
||||
_externals[_gates[gid]._out] = s;
|
||||
}
|
||||
}
|
||||
delete[] prf_output;
|
||||
}
|
||||
|
||||
void BooleanCircuit::EvaluateByLayer(int num_threads, party_id_t my_id)
|
||||
@@ -291,7 +292,7 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
|
||||
mpz_t temp_mpz;
|
||||
init_temp_mpz_t(temp_mpz);
|
||||
#endif
|
||||
for(int l=0; l<_layers.size(); l++) {
|
||||
for(size_t l=0; l<_layers.size(); l++) {
|
||||
int layer_sz = _layers[l].size();
|
||||
int start_idx = (layer_sz/num_threads)*i;
|
||||
int end_idx = (layer_sz/num_threads)*(i+1)-1;
|
||||
@@ -301,11 +302,10 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
|
||||
for(int g=start_idx; g<=end_idx; g++) {
|
||||
gate_id_t gid = _layers[l][g];
|
||||
#ifdef __PURE_SHE__
|
||||
signal_t s = _eval_gate(gid, my_id, prf_output, temp_mpz);
|
||||
_eval_gate(gid, my_id, prf_output, temp_mpz);
|
||||
#else
|
||||
signal_t s = _eval_gate(gid, my_id, prf_output);
|
||||
_eval_gate(gid, my_id, prf_output);
|
||||
#endif
|
||||
_externals[_gates[gid]._out] = s;
|
||||
}
|
||||
// printf("done eval layer %d\n", l);
|
||||
|
||||
@@ -375,75 +375,17 @@ void BooleanCircuit::_eval_by_layer(int i, int num_threads, party_id_t my_id)
|
||||
//}
|
||||
|
||||
#ifdef __PURE_SHE__
|
||||
signal_t BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz)
|
||||
void BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz)
|
||||
#else
|
||||
signal_t BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output)
|
||||
void BooleanCircuit::_eval_gate(gate_id_t g, party_id_t my_id, char* prf_output)
|
||||
#endif
|
||||
{
|
||||
// std::cout << std::endl << "evaluate gate " << g << std::endl;
|
||||
wire_id_t w_l = _gates[g]._left;
|
||||
wire_id_t w_r = _gates[g]._right;
|
||||
wire_id_t w_o = _gates[g]._out;
|
||||
int sig_l = _externals[w_l];
|
||||
int sig_r = _externals[w_r];
|
||||
int entry = 2 * sig_l + sig_r;
|
||||
|
||||
Key* garbled_entry = _garbled_entry(g, entry);
|
||||
int ext_l = entry%2 ? 1 : 0 ;
|
||||
int ext_r = entry<2 ? 0 : 1 ;
|
||||
|
||||
Key k;
|
||||
for(party_id_t i=1; i<=_num_parties; i++) {
|
||||
// std::cout << "using key: " << *_key(i,w_l,sig_l) << ": ";
|
||||
PRF_chunk(_key(i,w_l,sig_l), _input(ext_l,g,1), prf_output, PAD_TO_8(_num_parties));
|
||||
for(party_id_t j=1; j<=_num_parties; j++) {
|
||||
k = *(Key*)(prf_output+16*(j-1));
|
||||
#ifdef __PRIME_FIELD__
|
||||
k.adjust();
|
||||
#endif
|
||||
// printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_l, sig_l,ext_l,g,j);
|
||||
// std::cout << k << std::endl;
|
||||
garbled_entry[j-1] -= k;
|
||||
}
|
||||
|
||||
PRF_chunk(_key(i,w_r,sig_r), _input(ext_r,g,1) , prf_output, PAD_TO_8(_num_parties));
|
||||
for(party_id_t j=1; j<=_num_parties; j++) {
|
||||
// std::cout << "using key: " << *_key(i,w_r,sig_r) << ": ";
|
||||
k = *(Key*)(prf_output+16*(j-1));
|
||||
#ifdef __PRIME_FIELD__
|
||||
k.adjust();
|
||||
#endif
|
||||
// printf("Fk^%d_{%u,%d}(%d,%u,%d) = ",i, w_r, sig_r,ext_r,g,j);
|
||||
// std::cout << k << std::endl;
|
||||
garbled_entry[j-1] -= k;
|
||||
}
|
||||
}
|
||||
|
||||
#if __PURE_SHE__
|
||||
for(party_id_t j=1; j<=_num_parties; j++) {
|
||||
garbled_entry[j-1].sqr_in_place(tmp_mpz);
|
||||
}
|
||||
#endif
|
||||
|
||||
// for(party_id_t i=1; i<=_num_parties; i++) {
|
||||
// std::cout << garbled_entry[i-1] << " ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
if(garbled_entry[my_id-1] == *_key(my_id, w_o, 0)) {
|
||||
// std::cout << "k^"<<my_id<<"_{"<<w_o<<",0} = " << *_key(my_id, w_o, 0) << std::endl;
|
||||
memcpy(_key(1,w_o,0), garbled_entry, sizeof(Key)*_num_parties);
|
||||
return 0;
|
||||
} else if (garbled_entry[my_id-1] == *_key(my_id, w_o, 1)) {
|
||||
// std::cout << "k^"<<my_id<<"_{"<<w_o<<",1} = " << *_key(my_id, w_o, 1) << std::endl;
|
||||
memcpy(_key(1,w_o,1), garbled_entry, sizeof(Key)*_num_parties);
|
||||
return 1;
|
||||
} else {
|
||||
printf("\nERROR!!!\n");
|
||||
// throw std::invalid_argument("result key doesn't fit any of my keys");
|
||||
// return NO_SIGNAL;
|
||||
}
|
||||
|
||||
party->registers[w_o].eval(party->registers[w_l], party->registers[w_r],
|
||||
party->_garbled_tbl[g-1], my_id, prf_output, w_o, w_l, w_r);
|
||||
}
|
||||
|
||||
|
||||
@@ -522,12 +464,11 @@ std::string BooleanCircuit::Output()
|
||||
{
|
||||
std::cout << "output:" <<std::endl;
|
||||
std::stringstream ss;
|
||||
// printf("masks\n");
|
||||
// phex(_masks, _wires.size());
|
||||
// printf("externals\n");
|
||||
// phex(_externals, _wires.size());
|
||||
for( int i=0; i<_num_output_wires; i++ ) {
|
||||
int output = _externals[_output_start+i] ^ _masks[_output_start+i];
|
||||
printf("masks/externals\n");
|
||||
for( size_t i=0; i<_num_output_wires; i++ ) {
|
||||
cout << "mask " << i << ": " << (int)party->registers[_output_start+i].mask << endl;
|
||||
cout << "external " << i << ": " << (int)party->registers[_output_start+i].get_external() << endl;
|
||||
int output = party->registers[_output_start+i].get_output();
|
||||
ss << output;
|
||||
}
|
||||
std::cout << ss.str();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#ifndef __BOOLEAN_CIRCUIT__
|
||||
#define __BOOLEAN_CIRCUIT__
|
||||
@@ -22,10 +24,10 @@
|
||||
#include "Gate.h"
|
||||
#include "Wire.h"
|
||||
#include "Key.h"
|
||||
#include "Register.h"
|
||||
#include "GarbledGate.h"
|
||||
|
||||
typedef unsigned int party_id_t;
|
||||
|
||||
//#include "Party.h"
|
||||
#include "Party.h"
|
||||
|
||||
|
||||
#define INIT_PARTY(W,N) {.wires=W, .n_wires=N }
|
||||
@@ -41,13 +43,11 @@ typedef struct party_t{
|
||||
#define GARBLED_GATE_SIZE(N) (4*N)
|
||||
#define MSG_KEYS_HEADER_SZ (16)
|
||||
|
||||
//#define PAD_TO_8(n) (n+8-n%8)
|
||||
#define PAD_TO_8(n) (n)
|
||||
|
||||
class BooleanCircuit
|
||||
{
|
||||
friend class Party;
|
||||
friend class TrustedParty;
|
||||
friend class CommonCircuitParty;
|
||||
public:
|
||||
BooleanCircuit(const char* desc_file);
|
||||
// void RawInputs(std::string raw_inputs);
|
||||
@@ -63,15 +63,7 @@ public:
|
||||
inline wire_id_t OutWiresStart() { return _output_start; }
|
||||
inline gate_id_t NumGates() {return _num_gates; }
|
||||
inline wire_id_t NumParties() { return _num_parties; }
|
||||
inline party_t Party(party_id_t id) {return _parties[id];}
|
||||
inline void Keys(Key* keys) {_keys = keys;}
|
||||
inline Key* Keys() { return _keys; }
|
||||
inline void Masks(char* masks) {_masks = masks;}
|
||||
inline char* Masks() { return _masks; }
|
||||
inline char* PrfInputs() {return _prf_inputs;}
|
||||
inline void PrfInputs(char* prf_inputs) { _prf_inputs = prf_inputs; }
|
||||
inline char* Prfs(){return _prfs;}
|
||||
inline void Prfs(char* prfs){_prfs=prfs;}
|
||||
inline party_t get_party(party_id_t id) {return _parties[id];}
|
||||
|
||||
|
||||
private:
|
||||
@@ -88,7 +80,7 @@ private:
|
||||
|
||||
|
||||
std::vector<std::vector<gate_id_t>> _layers;
|
||||
int _max_layer_sz;
|
||||
size_t _max_layer_sz;
|
||||
void _make_layers();
|
||||
int __make_layers(gate_id_t g);
|
||||
void _add_to_layer(int layer, gate_id_t g);
|
||||
@@ -111,13 +103,13 @@ private:
|
||||
void _parse_circuit(const char* desc_file);
|
||||
void _eval_thread(party_id_t my_id);
|
||||
#ifdef __PURE_SHE__
|
||||
signal_t _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz);
|
||||
void _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output, mpz_t& tmp_mpz);
|
||||
#else
|
||||
signal_t _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output);
|
||||
void _eval_gate(gate_id_t g, party_id_t my_id, char* prf_output);
|
||||
#endif
|
||||
|
||||
inline bool is_wire_ready(wire_id_t w) {
|
||||
return _externals[w] != NO_SIGNAL;
|
||||
return party->registers[w].get_external() != NO_SIGNAL;
|
||||
}
|
||||
inline bool is_gate_ready(gate_id_t g) {
|
||||
bool ready = is_wire_ready(_gates[g]._left)
|
||||
@@ -125,73 +117,7 @@ private:
|
||||
return ready;
|
||||
}
|
||||
|
||||
/* Additional data stored per per party per wire: */
|
||||
Key* _keys; /* Total of n*W*2 keys
|
||||
* For every w={0,...,W}
|
||||
* For every b={0,1}
|
||||
* For every i={1...n}
|
||||
* k^i_{w,b}
|
||||
* This is helpful that the keys for specific w and b are adjacent
|
||||
* for pipelining matters.
|
||||
*/
|
||||
inline Key* _key(party_id_t i, wire_id_t w,int b) {return _keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
|
||||
|
||||
#ifdef __PURE_SHE__
|
||||
Key* _sqr_keys;
|
||||
inline Key* _sqr_key(party_id_t i, wire_id_t w,int b) {return _sqr_keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
|
||||
#endif
|
||||
|
||||
char* _masks; /* There are W masks, one per wire. beginning with 0 */
|
||||
char* _externals; /* Same as _masks */
|
||||
|
||||
Key* _garbled_tbl; /* will be allocated 4*n*G keys;
|
||||
* (n keys for each of A,B,C,D entries).
|
||||
* For each g in G:
|
||||
* A1, A2, ... , An
|
||||
* B1, B2, ... , Bn
|
||||
* C1, C2, ... , Cn
|
||||
* D1, D2, ... , Dn
|
||||
*/
|
||||
Key* _garbled_tbl_copy;
|
||||
inline Key* _garbled_entry(gate_id_t g, int entry) {return _garbled_tbl+(g-1)*4*_num_parties+entry*_num_parties;}
|
||||
|
||||
char* _prfs; /*
|
||||
* Total of n*(G*2*2*2*n+4) = n*(8*G*n+RESERVE_FOR_MSG_TYPE) = n*(PRFS_PER_PARTY+RESERVE_FOR_MSG_TYPE)
|
||||
*
|
||||
* For every party i={1...n}
|
||||
* <Message_type> = saves us copying memory to another location - only for Party.
|
||||
* For every gate g={1...G}
|
||||
* For input wires x={left,right}
|
||||
* For b={0,1} (0-key/1-key)
|
||||
* For e={0,1} (extension)
|
||||
* for every party j={1...n}
|
||||
* F_{k^i_{x,b}}(e,j,g,)
|
||||
*/
|
||||
|
||||
|
||||
char* _prf_inputs; /*
|
||||
* These are all possible inputs to the prf,
|
||||
* This is not efficient in terms of storage but increase
|
||||
* performance in terms of speed since no need to generate
|
||||
* (new allocation plus filling) those inputs every time
|
||||
* we compute the prf.
|
||||
* Structure:
|
||||
* - Total of G*n*2 inputs. (i.e. for every gate g, for every
|
||||
* party j and for every extension e (from {0,1}).
|
||||
* - We want to be able to set key once and use it to encrypt
|
||||
* several inputs, so we want those inputs to be adjacent in
|
||||
* memory in order to save time of building the block of
|
||||
* inputs. So, the structure of the inputs is as follows:
|
||||
* - First half of inputs (G*n inputs) are:
|
||||
* - For every gate g=1,...,G we store the inputs:
|
||||
* (0||g||1),(0||g||2),...,(0||g||n)
|
||||
* - Second half of the inputs are:
|
||||
* - For every gate g=1,...,G we store the inputs:
|
||||
* (1||g||1),(1||g||2),...,(1||g||n)
|
||||
*/
|
||||
inline char* _input(int e, gate_id_t g, party_id_t j)
|
||||
{ return _prf_inputs + (e*_num_gates*_num_parties + (g-1)*_num_parties + j-1) * 16 ;}
|
||||
|
||||
Party* party;
|
||||
};
|
||||
|
||||
|
||||
|
||||
219
BMR/CommonParty.cpp
Normal file
219
BMR/CommonParty.cpp
Normal file
@@ -0,0 +1,219 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* CommonParty.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "CommonParty.h"
|
||||
#include "BooleanCircuit.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
|
||||
CommonParty* CommonParty::singleton = 0;
|
||||
|
||||
CommonParty::CommonParty() :
|
||||
_node(0), gate_counter(0), gate_counter2(0), garbled_tbl_size(0),
|
||||
cpu_timer(CLOCK_PROCESS_CPUTIME_ID), buffers(TYPE_MAX)
|
||||
{
|
||||
insecure("MPC emulation");
|
||||
if (singleton != 0)
|
||||
throw runtime_error("there can only be one");
|
||||
singleton = this;
|
||||
prng.ReSeed();
|
||||
#ifdef DEBUG_PRNG
|
||||
octet seed[SEED_SIZE];
|
||||
memset(seed, 0, sizeof(seed));
|
||||
prng.SetSeed(seed);
|
||||
#endif
|
||||
cpu_timer.start();
|
||||
timer.start();
|
||||
gf2n::init_field(128);
|
||||
mac_key.randomize(prng);
|
||||
}
|
||||
|
||||
CommonParty::~CommonParty()
|
||||
{
|
||||
if (_node)
|
||||
delete _node;
|
||||
cout << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl;
|
||||
cout << "CPU time: " << cpu_timer.elapsed() << endl;
|
||||
cout << "Total time: " << timer.elapsed() << endl;
|
||||
cout << "First phase time: " << timers[0].elapsed() << endl;
|
||||
cout << "Second phase time: " << timers[1].elapsed() << endl;
|
||||
cout << "Number of gates: " << gate_counter << endl;
|
||||
}
|
||||
|
||||
void CommonParty::init(const char* netmap_file, int id, int n_parties)
|
||||
{
|
||||
#ifdef N_PARTIES
|
||||
if (n_parties != N_PARTIES)
|
||||
throw runtime_error("wrong number of parties");
|
||||
#else
|
||||
#ifdef MAX_N_PARTIES
|
||||
if (n_parties > MAX_N_PARTIES)
|
||||
throw runtime_error("too many parties");
|
||||
#endif
|
||||
_N = n_parties;
|
||||
#endif // N_PARTIES
|
||||
printf("netmap_file: %s\n", netmap_file);
|
||||
if (0 == strcmp(netmap_file, LOOPBACK_STR)) {
|
||||
_node = new Node( NULL, id, this, _N + 1);
|
||||
} else {
|
||||
_node = new Node(netmap_file, id, this);
|
||||
}
|
||||
}
|
||||
|
||||
int CommonParty::init(const char* netmap_file, int id)
|
||||
{
|
||||
int n_parties;
|
||||
if (string(netmap_file) != string(LOOPBACK_STR))
|
||||
{
|
||||
ifstream(netmap_file) >> n_parties;
|
||||
n_parties--;
|
||||
}
|
||||
else
|
||||
n_parties = 2;
|
||||
init(netmap_file, id, n_parties);
|
||||
return n_parties;
|
||||
}
|
||||
|
||||
void CommonParty::reset()
|
||||
{
|
||||
garbled_tbl_size = 0;
|
||||
}
|
||||
|
||||
gate_id_t CommonParty::new_gate()
|
||||
{
|
||||
gate_counter++;
|
||||
garbled_tbl_size++;
|
||||
return gate_counter;
|
||||
}
|
||||
|
||||
void CommonParty::next_gate(GarbledGate& gate)
|
||||
{
|
||||
gate_counter2++;
|
||||
gate.init_inputs(gate_counter2, _N);
|
||||
}
|
||||
|
||||
void CommonParty::input(Register& reg, party_id_t from)
|
||||
{
|
||||
(void)reg;
|
||||
(void)from;
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
SendBuffer& CommonParty::get_buffer(MSG_TYPE type)
|
||||
{
|
||||
SendBuffer& buffer = buffers[type];
|
||||
buffer.clear();
|
||||
fill_message_type(buffer, type);
|
||||
#ifdef DEBUG_BUFFER
|
||||
cout << type << " buffer:";
|
||||
phex(buffers.data(), 4);
|
||||
#endif
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void CommonCircuitParty::print_masks(const vector<int>& indices)
|
||||
{
|
||||
vector<char> bits;
|
||||
for (auto i = indices.begin(); i != indices.end(); i++)
|
||||
bits.push_back(registers[*i].get_mask_no_check());
|
||||
print_bit_array(bits);
|
||||
}
|
||||
|
||||
void CommonCircuitParty::print_outputs(const vector<int>& indices)
|
||||
{
|
||||
vector<char> bits;
|
||||
for (auto i = indices.begin(); i != indices.end(); i++)
|
||||
bits.push_back(registers[*i].get_output_no_check());
|
||||
print_bit_array(bits);
|
||||
}
|
||||
|
||||
|
||||
template <class T, class U>
|
||||
GC::BreakType CommonParty::first_phase(GC::Program<U>& program,
|
||||
GC::Processor<T>& processor, GC::Machine<T>& machine)
|
||||
{
|
||||
(void)machine;
|
||||
timers[0].start();
|
||||
reset();
|
||||
wires.clear();
|
||||
GC::BreakType next = (reinterpret_cast<GC::Program<T>*>(&program))->execute(processor);
|
||||
#ifdef DEBUG_ROUNDS
|
||||
cout << "finished first phase at pc " << processor.PC
|
||||
<< " reason " << next << endl;
|
||||
#endif
|
||||
timers[0].stop();
|
||||
cout << "First round time: " << timers[0].elapsed() << " / "
|
||||
<< timer.elapsed() << endl;
|
||||
#ifdef DEBUG_WIRES
|
||||
cout << "Storing wires with " << 1e-9 * wires.size() << " GB on disk" << endl;
|
||||
#endif
|
||||
wire_storage.push(wires);
|
||||
return next;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
GC::BreakType CommonParty::second_phase(GC::Program<T>& program,
|
||||
GC::Processor<T>& processor, GC::Machine<T>& machine)
|
||||
{
|
||||
(void)machine;
|
||||
wire_storage.pop(wires);
|
||||
wires.reset_head();
|
||||
timers[1].start();
|
||||
GC::BreakType next = GC::TIME_BREAK;
|
||||
next = program.execute(processor);
|
||||
#ifdef DEBUG_ROUNDS
|
||||
cout << "finished second phase at " << processor.PC
|
||||
<< " reason " << next << endl;
|
||||
#endif
|
||||
timers[1].stop();
|
||||
// cout << "Second round time: " << timers[1].elapsed() << ", ";
|
||||
// cout << "total time: " << timer.elapsed() << endl;
|
||||
if (false)
|
||||
return GC::CAP_BREAK;
|
||||
else
|
||||
return next;
|
||||
}
|
||||
|
||||
void CommonCircuitParty::prepare_input_regs(party_id_t from)
|
||||
{
|
||||
party_t sender = _circuit->_parties[from];
|
||||
wire_id_t s = sender.wires; //the beginning of my wires
|
||||
wire_id_t n = sender.n_wires; // number of my wires
|
||||
input_regs_queue.clear();
|
||||
input_regs_queue.push_back(_N + 1);
|
||||
(*input_regs)[from].clear();
|
||||
for (wire_id_t i = 0; i < n; i++) {
|
||||
wire_id_t w = s + i;
|
||||
(*input_regs)[from].push_back(w);
|
||||
}
|
||||
}
|
||||
|
||||
void CommonCircuitParty::prepare_output_regs()
|
||||
{
|
||||
output_regs.clear();
|
||||
for (size_t i = 0; i < _OW; i++)
|
||||
output_regs.push_back(_circuit->OutWiresStart()+i);
|
||||
}
|
||||
|
||||
template GC::BreakType CommonParty::first_phase(
|
||||
GC::Program<GC::Secret<GarbleRegister> >& program,
|
||||
GC::Processor<GC::Secret<RandomRegister> >& processor,
|
||||
GC::Machine<GC::Secret<RandomRegister> >& machine);
|
||||
|
||||
template GC::BreakType CommonParty::first_phase(
|
||||
GC::Program<GC::Secret<EvalRegister> >& program,
|
||||
GC::Processor<GC::Secret<PRFRegister> >& processor,
|
||||
GC::Machine<GC::Secret<PRFRegister> >& machine);
|
||||
|
||||
template GC::BreakType CommonParty::second_phase(
|
||||
GC::Program<GC::Secret<GarbleRegister> >& program,
|
||||
GC::Processor<GC::Secret<GarbleRegister> >& processor,
|
||||
GC::Machine<GC::Secret<GarbleRegister> >& machine);
|
||||
|
||||
template GC::BreakType CommonParty::second_phase(
|
||||
GC::Program<GC::Secret<EvalRegister> >& program,
|
||||
GC::Processor<GC::Secret<EvalRegister> >& processor,
|
||||
GC::Machine<GC::Secret<EvalRegister> >& machine);
|
||||
163
BMR/CommonParty.h
Normal file
163
BMR/CommonParty.h
Normal file
@@ -0,0 +1,163 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* CommonParty.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef BMR_COMMONPARTY_H_
|
||||
#define BMR_COMMONPARTY_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "GarbledGate.h"
|
||||
#include "Register.h"
|
||||
#include "proto_utils.h"
|
||||
#include "network/Node.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "GC/Program.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
#if (defined(DEBUG) || defined(DEBUG_COMM)) && !defined(DEBUG_STEPS)
|
||||
#define DEBUG_STEPS
|
||||
#endif
|
||||
|
||||
enum SpdzOp
|
||||
{
|
||||
SPDZ_LOAD,
|
||||
SPDZ_STORE,
|
||||
SPDZ_MAC,
|
||||
SPDZ_OP_N,
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class PersistentFront
|
||||
{
|
||||
deque<T>& container;
|
||||
int i;
|
||||
public:
|
||||
PersistentFront(deque<T>& container) : container(container), i(0) {}
|
||||
T& operator*() { return container.front(); }
|
||||
void operator++(int) { i++; }
|
||||
void reset() {}
|
||||
int get_i() { return i; }
|
||||
};
|
||||
|
||||
class CommonParty : public NodeUpdatable
|
||||
{
|
||||
protected:
|
||||
friend class Register;
|
||||
|
||||
#ifdef N_PARTIES
|
||||
const party_id_t _N = N_PARTIES;
|
||||
#else
|
||||
party_id_t _N;
|
||||
#endif
|
||||
Node* _node;
|
||||
|
||||
int gate_counter, gate_counter2;
|
||||
int garbled_tbl_size;
|
||||
|
||||
Timer cpu_timer;
|
||||
Timer timers[2];
|
||||
Timer timer;
|
||||
|
||||
gf2n mac_key;
|
||||
|
||||
mutex global_lock;
|
||||
|
||||
LocalBuffer wires;
|
||||
ReceivedMsgStore wire_storage;
|
||||
|
||||
template<class T, class U>
|
||||
GC::BreakType first_phase(GC::Program<U>& program, GC::Processor<T>& processor,
|
||||
GC::Machine<T>& machine);
|
||||
template<class T>
|
||||
GC::BreakType second_phase(GC::Program<T>& program, GC::Processor<T>& processor,
|
||||
GC::Machine<T>& machine);
|
||||
|
||||
public:
|
||||
static CommonParty* singleton;
|
||||
static CommonParty& s();
|
||||
|
||||
vector<SendBuffer> buffers;
|
||||
|
||||
PRNG prng;
|
||||
|
||||
CommonParty();
|
||||
virtual ~CommonParty();
|
||||
|
||||
#ifdef N_PARTIES
|
||||
static int get_n_parties() { return N_PARTIES; }
|
||||
#else
|
||||
static int get_n_parties() { return s()._N; }
|
||||
#endif
|
||||
|
||||
void init(const char* netmap_file, int id, int n_parties);
|
||||
int init(const char* netmap_file, int id);
|
||||
virtual void reset();
|
||||
|
||||
gate_id_t new_gate();
|
||||
void next_gate(GarbledGate& gate);
|
||||
gate_id_t next_gate(int skip) { return gate_counter2 += skip; }
|
||||
size_t get_garbled_tbl_size() { return garbled_tbl_size; }
|
||||
|
||||
void input(Register& reg, party_id_t from);
|
||||
|
||||
SendBuffer& get_buffer(MSG_TYPE type);
|
||||
|
||||
gf2n get_mac_key() { return mac_key; }
|
||||
};
|
||||
|
||||
class BooleanCircuit;
|
||||
|
||||
class CommonCircuitParty : virtual public CommonParty
|
||||
{
|
||||
protected:
|
||||
BooleanCircuit* _circuit;
|
||||
gate_id_t _G;
|
||||
wire_id_t _W;
|
||||
wire_id_t _OW;
|
||||
|
||||
CheckVector<Register> registers;
|
||||
|
||||
deque< CheckVector< CheckVector<int> > > input_regs_queue;
|
||||
PersistentFront< CheckVector< CheckVector<int> > > input_regs;
|
||||
vector<int> output_regs;
|
||||
|
||||
CommonCircuitParty() : input_regs(input_regs_queue) {}
|
||||
|
||||
void prepare_input_regs(party_id_t id);
|
||||
void prepare_output_regs();
|
||||
|
||||
void resize_registers() { registers.resize(_W, {(int)_N}); }
|
||||
|
||||
Register& get_reg(int reg);
|
||||
|
||||
void print_masks(const vector<int>& indices);
|
||||
void print_outputs(const vector<int>& indices);
|
||||
void print_round_regs();
|
||||
};
|
||||
|
||||
|
||||
inline Register& CommonCircuitParty::get_reg(int reg)
|
||||
{
|
||||
#ifdef DEBUG_REGS
|
||||
cout << "get reg " << reg << endl << registers.at(reg).keys[0] << endl
|
||||
<< registers.at(reg).keys[1] << endl;
|
||||
#endif
|
||||
return registers.at(reg);
|
||||
}
|
||||
|
||||
inline CommonParty& CommonParty::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
#endif /* BMR_COMMONPARTY_H_ */
|
||||
97
BMR/GarbledGate.cpp
Normal file
97
BMR/GarbledGate.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* GarbledGate.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GarbledGate.h"
|
||||
#include "prf.h"
|
||||
#include "CommonParty.h"
|
||||
|
||||
GarbledGate::~GarbledGate() {
|
||||
}
|
||||
|
||||
void GarbledGate::init_inputs(gate_id_t g, int n_parties)
|
||||
{
|
||||
n_parties = CommonParty::get_n_parties();
|
||||
id = g;
|
||||
for (unsigned int e = 0; e <= 1; e++) {
|
||||
prf_inputs[e].resize(n_parties);
|
||||
for (unsigned int j = 0; j < (size_t)n_parties; j++) {
|
||||
prf_inputs[e][j] = 0;
|
||||
|
||||
/* fill out this buffer s.t. first 4 bytes are the extension (0/1),
|
||||
* next 4 bytes are gate_id and next 4 bytes are party id.
|
||||
* For the first half we dont need to fill the extension because
|
||||
* it is zero anyway.
|
||||
*/
|
||||
|
||||
unsigned int* prf_input_index =
|
||||
(unsigned int*) &prf_inputs[e][j]; //easier to refer as integers
|
||||
// printf("e,g,j=%u,%u,%u\n",e,g,j);
|
||||
*prf_input_index = e;
|
||||
*(prf_input_index + 1) = g;
|
||||
*(prf_input_index + 2) = j + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id,
|
||||
SendBuffer& buffer, gate_id_t g)
|
||||
{
|
||||
int n_parties = CommonParty::get_n_parties();
|
||||
init_inputs(g, n_parties);
|
||||
PRFOutputs prf_output(n_parties);
|
||||
for(int w=0; w<=1; w++) {
|
||||
for (int b=0; b<=1; b++) {
|
||||
const Key& key = in_wires[w]->key(my_id, b);
|
||||
AES_KEY aes_key;
|
||||
AES_128_Key_Expansion((unsigned char*)&key.r, &aes_key);
|
||||
#ifdef DEBUG
|
||||
cout << "using key " << key << endl;
|
||||
#endif
|
||||
for (int e=0; e<=1; e++) {
|
||||
for (int j=1; j<= n_parties; j++) {
|
||||
prf_output[my_id-1][j-1].outputs[w][b][e][0] =
|
||||
aes_128_encrypt(*(__m128i*)input(e, j), (octet*)aes_key.rd_key);
|
||||
#ifdef __PRIME_FIELD__
|
||||
((Key*)prf_outputs_index)->adjust();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < n_parties; i++)
|
||||
buffer.serialize(prf_output[my_id - 1][i]);
|
||||
#ifdef DEBUG
|
||||
wire_id_t wire_ids[] = { (wire_id_t)in_wires[0]->get_id(), (wire_id_t)in_wires[1]->get_id() };
|
||||
prf_output.print_prfs(g, wire_ids, my_id, n_parties);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PRFOutputs::print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties)
|
||||
{
|
||||
for(int w=0; w<=1; w++) {
|
||||
for (int b=0; b<=1; b++) {
|
||||
for (int e=0; e<=1; e++) {
|
||||
for(party_id_t j=1; j<=(size_t)n_parties; j++) {
|
||||
printf("F_k^%d_{%lu,%u}(%d,%lu,%u) = ", my_id, in_wires[w], b, e, g, j);
|
||||
Key k = *((Key*)(*this)[my_id-1][j-1].outputs[w][b][e]);
|
||||
std::cout << k << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GarbledGate::print()
|
||||
{
|
||||
cout << "garbled gate " << id << endl;
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
for (size_t j = 0; j < keys[i].size(); j++)
|
||||
cout << keys[i][j] << " ";
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
90
BMR/GarbledGate.h
Normal file
90
BMR/GarbledGate.h
Normal file
@@ -0,0 +1,90 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* GarbledGate.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_
|
||||
#define BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_
|
||||
|
||||
#include "Register.h"
|
||||
#include "common.h"
|
||||
|
||||
struct PRFTuple {
|
||||
Key outputs[2][2][2][1];
|
||||
};
|
||||
|
||||
/*
|
||||
* Total of n*(G*2*2*2*n+4) = n*(8*G*n+RESERVE_FOR_MSG_TYPE) = n*(PRFS_PER_PARTY+RESERVE_FOR_MSG_TYPE)
|
||||
*
|
||||
* For every party i={1...n}
|
||||
* <Message_type> = saves us copying memory to another location - only for Party.
|
||||
* For every gate g={1...G}
|
||||
* For input wires x={left,right}
|
||||
* For b={0,1} (0-key/1-key)
|
||||
* For e={0,1} (extension)
|
||||
* for every party j={1...n}
|
||||
* F_{k^i_{x,b}}(e,j,g,)
|
||||
*/
|
||||
struct PRFOutputs {
|
||||
#ifdef MAX_N_PARTIES
|
||||
PRFTuple tuples[MAX_N_PARTIES][MAX_N_PARTIES];
|
||||
PRFOutputs(int n_parties) { (void)n_parties; }
|
||||
PRFTuple* operator[](int i) { return tuples[i]; }
|
||||
#else
|
||||
vector<PRFTuple> tuples;
|
||||
int n_parties;
|
||||
|
||||
PRFOutputs(int n_parties) : n_parties(n_parties), tuples(n_parties * n_parties) {}
|
||||
PRFTuple* operator[](int i) { return &tuples[i*n_parties]; }
|
||||
#endif
|
||||
void print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties);
|
||||
};
|
||||
|
||||
class GarbledGate : public KeyTuple<4> {
|
||||
/* will be allocated 4*n keys;
|
||||
* (n keys for each of A,B,C,D entries):
|
||||
* A1, A2, ... , An
|
||||
* B1, B2, ... , Bn
|
||||
* C1, C2, ... , Cn
|
||||
* D1, D2, ... , Dn
|
||||
*/
|
||||
|
||||
public:
|
||||
KeyVector prf_inputs[2]; /*
|
||||
* These are all possible inputs to the prf,
|
||||
* This is not efficient in terms of storage but increase
|
||||
* performance in terms of speed since no need to generate
|
||||
* (new allocation plus filling) those inputs every time
|
||||
* we compute the prf.
|
||||
* Structure:
|
||||
* - Total of G*n*2 inputs. (i.e. for every gate g, for every
|
||||
* party j and for every extension e (from {0,1}).
|
||||
* - We want to be able to set key once and use it to encrypt
|
||||
* several inputs, so we want those inputs to be adjacent in
|
||||
* memory in order to save time of building the block of
|
||||
* inputs. So, the structure of the inputs is as follows:
|
||||
* - First half of inputs (G*n inputs) are:
|
||||
* - For every gate g=1,...,G we store the inputs:
|
||||
* (0||g||1),(0||g||2),...,(0||../circuit_bmr/inc/GarbledGate.h:43:19: error: ‘gate_id_t’ has not been declared
|
||||
* g||n)
|
||||
* - Second half of the inputs are:
|
||||
* - For every gate g=1,...,G we store the inputs:
|
||||
* (1||g||1),(1||g||2),...,(1||g||n)
|
||||
*/
|
||||
|
||||
gate_id_t id;
|
||||
|
||||
GarbledGate(int n_parties) : KeyTuple<4>(n_parties), id(-1) {}
|
||||
virtual ~GarbledGate();
|
||||
|
||||
void init_inputs(gate_id_t g, int n_parties);
|
||||
|
||||
char* input(int e, party_id_t j) { return (char*)&prf_inputs[e][j-1]; }
|
||||
|
||||
void compute_prfs_outputs(const Register** in_wires, int my_id, SendBuffer& buffer, gate_id_t g);
|
||||
void print();
|
||||
};
|
||||
|
||||
#endif /* BMR_PRIME_CIRCUIT_BMR_INC_GARBLEDGATE_H_ */
|
||||
19
BMR/Gate.h
19
BMR/Gate.h
@@ -1,32 +1,33 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#ifndef __GATE_H__
|
||||
#define __GATE_H__
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <boost/thread/mutex.hpp>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "Key.h"
|
||||
|
||||
#define NO_LAYER (-1)
|
||||
|
||||
typedef struct Gate {
|
||||
class Gate {
|
||||
public:
|
||||
wire_id_t _left;
|
||||
wire_id_t _right;
|
||||
wire_id_t _out;
|
||||
uint8_t _func[4];
|
||||
Function _func;
|
||||
|
||||
int _layer;
|
||||
|
||||
Gate() : _left(-1), _right(-1), _out(-1), _layer(-1) {}
|
||||
|
||||
inline void init(wire_id_t left, wire_id_t right, wire_id_t out,std::string func) {
|
||||
_left = left;
|
||||
_right = right;
|
||||
_out = out;
|
||||
_func[0] = (func[0]=='0')?0:1;//TODO change to func[0]-0x30
|
||||
_func[1] = (func[1]=='0')?0:1;
|
||||
_func[2] = (func[2]=='0')?0:1;
|
||||
_func[3] = (func[3]=='0')?0:1;
|
||||
_func = func;
|
||||
_layer = NO_LAYER;
|
||||
}
|
||||
inline uint8_t func(uint8_t left, uint8_t right) {
|
||||
@@ -34,9 +35,9 @@ typedef struct Gate {
|
||||
}
|
||||
|
||||
void print(int id) {
|
||||
printf ("gate %d: l:%d, r:%d, o:%d, func:%d%d%d%d\n", id, _left, _right, _out, _func[0], _func[1], _func[2], _func[3] );
|
||||
printf ("gate %d: l:%lu, r:%lu, o:%lu, func:%d%d%d%d\n", id, _left, _right, _out, _func[0], _func[1], _func[2], _func[3] );
|
||||
}
|
||||
|
||||
} Gate;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
31
BMR/Key.cpp
31
BMR/Key.cpp
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Key.cpp
|
||||
*
|
||||
* Created on: Oct 27, 2015
|
||||
* Author: marcel
|
||||
*/
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ ostream& operator<<(ostream& o, const Key& key)
|
||||
|
||||
ostream& operator<<(ostream& o, const __m128i& x) {
|
||||
o.fill('0');
|
||||
o << hex;
|
||||
o << hex << noshowbase;
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
o.width(16);
|
||||
@@ -27,26 +27,11 @@ ostream& operator<<(ostream& o, const __m128i& x) {
|
||||
return o;
|
||||
}
|
||||
|
||||
Key& Key::operator=(const Key& other) {
|
||||
r= other.r;
|
||||
// memcpy(&r, &other.r, sizeof(r));
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool Key::operator==(const Key& other) {
|
||||
__m128i neq = _mm_xor_si128(r, other.r);
|
||||
return _mm_test_all_zeros(neq,neq);
|
||||
}
|
||||
|
||||
Key& Key::operator-=(const Key& other) {
|
||||
r ^= other.r;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Key& Key::operator+=(const Key& other) {
|
||||
r ^= other.r;
|
||||
return *this;
|
||||
}
|
||||
//Key& Key::operator=(const Key& other) {
|
||||
// r= other.r;
|
||||
//// memcpy(&r, &other.r, sizeof(r));
|
||||
// return *this;
|
||||
//}
|
||||
|
||||
#else //__PRIME_FIELD__ is defined
|
||||
|
||||
|
||||
56
BMR/Key.h
56
BMR/Key.h
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Key.h
|
||||
*
|
||||
* Created on: Oct 27, 2015
|
||||
* Author: marcel
|
||||
*/
|
||||
|
||||
#ifndef COMMON_INC_KEY_H_
|
||||
@@ -13,7 +13,7 @@
|
||||
#include <smmintrin.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "proto_utils.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@@ -23,21 +23,63 @@ class Key {
|
||||
public:
|
||||
__m128i r;
|
||||
|
||||
Key() : r(_mm_set1_epi64x(0)) {}
|
||||
Key(long long a) : r(_mm_set1_epi64x(a)) {}
|
||||
Key(const Key& other) {r= other.r;}
|
||||
Key() {}
|
||||
Key(long long a) : r(_mm_cvtsi64_si128(a)) {}
|
||||
Key(long long a, long long b) : r(_mm_set_epi64x(a, b)) {}
|
||||
Key(__m128i r) : r(r) {}
|
||||
// Key(const Key& other) {r= other.r;}
|
||||
|
||||
Key& operator=(const Key& other);
|
||||
// Key& operator=(const Key& other);
|
||||
bool operator==(const Key& other);
|
||||
bool operator!=(const Key& other) { return !(*this == other); }
|
||||
|
||||
Key& operator-=(const Key& other);
|
||||
Key& operator+=(const Key& other);
|
||||
|
||||
Key operator^(const Key& other) const { return r ^ other.r; }
|
||||
Key operator^=(const Key& other) { r ^= other.r; return *this; }
|
||||
|
||||
void serialize(SendBuffer& output) const { output.serialize(r); }
|
||||
void serialize_no_allocate(SendBuffer& output) const { output.serialize_no_allocate(r); }
|
||||
|
||||
bool get_signal() { return _mm_cvtsi128_si64(r) & 1; }
|
||||
|
||||
template <class T>
|
||||
T get() const;
|
||||
};
|
||||
|
||||
ostream& operator<<(ostream& o, const Key& key);
|
||||
ostream& operator<<(ostream& o, const __m128i& x);
|
||||
|
||||
|
||||
inline bool Key::operator==(const Key& other) {
|
||||
__m128i neq = _mm_xor_si128(r, other.r);
|
||||
return _mm_test_all_zeros(neq,neq);
|
||||
}
|
||||
|
||||
inline Key& Key::operator-=(const Key& other) {
|
||||
r ^= other.r;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline Key& Key::operator+=(const Key& other) {
|
||||
r ^= other.r;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline unsigned long Key::get() const
|
||||
{
|
||||
return _mm_cvtsi128_si64(r);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __m128i Key::get() const
|
||||
{
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
#else //__PRIME_FIELD__ is defined
|
||||
|
||||
const __uint128_t MODULUS= 0xffffffffffffffffffffffffffffff61;
|
||||
|
||||
953
BMR/Party.cpp
953
BMR/Party.cpp
File diff suppressed because it is too large
Load Diff
241
BMR/Party.h
241
BMR/Party.h
@@ -1,94 +1,253 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Party.h
|
||||
*
|
||||
* Created on: Feb 15, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOL_PARTY_H_
|
||||
#define PROTOCOL_PARTY_H_
|
||||
|
||||
#include "BooleanCircuit.h"
|
||||
#include <mutex>
|
||||
#include <boost/atomic.hpp>
|
||||
#include "Node.h"
|
||||
#include "Register.h"
|
||||
#include "GarbledGate.h"
|
||||
#include "network/Node.h"
|
||||
#include "CommonParty.h"
|
||||
#include "SpdzWire.h"
|
||||
#include "AndJob.h"
|
||||
|
||||
#include "GC/Machine.h"
|
||||
#include "GC/Program.h"
|
||||
#include "GC/Processor.h"
|
||||
#include "GC/Secret.h"
|
||||
#include "Tools/Worker.h"
|
||||
|
||||
class BooleanCircuit;
|
||||
|
||||
#define SERVER_ID (0)
|
||||
#define INPUT_KEYS_MSG_TYPE_SIZE (16) // so memory will by alligned
|
||||
|
||||
#ifndef N_EVAL_THREADS
|
||||
// Default Intel desktop processor has 8 half cores.
|
||||
// This is beneficial if only one AES available per full core.
|
||||
#define N_EVAL_THREADS (8)
|
||||
#endif
|
||||
|
||||
|
||||
typedef struct {
|
||||
unsigned long min=0;
|
||||
unsigned long long acc=0;
|
||||
} exec_props_t;
|
||||
|
||||
class Party : public NodeUpdatable {
|
||||
class BaseParty : virtual public CommonParty {
|
||||
public:
|
||||
BaseParty(party_id_t id);
|
||||
virtual ~BaseParty();
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
void NodeReady();
|
||||
void NewMessage(int from, ReceivedMsg& msg);
|
||||
void NodeAborted(struct sockaddr_in* from) { (void)from; }
|
||||
|
||||
void Start();
|
||||
|
||||
party_id_t get_id() { return _id; }
|
||||
Key get_delta() { return delta; }
|
||||
|
||||
protected:
|
||||
party_id_t _id;
|
||||
|
||||
// int _num_evaluation_threads;
|
||||
struct timeval _start_online_net, _end_online_net;
|
||||
|
||||
vector<char> input_masks;
|
||||
vector<char>::iterator input_mask;
|
||||
|
||||
Timer online_timer;
|
||||
|
||||
Key delta;
|
||||
|
||||
virtual void _compute_prfs_outputs(Key* keys) = 0;
|
||||
void _send_prfs();
|
||||
|
||||
virtual void _process_external_received(char* externals,
|
||||
party_id_t from) = 0;
|
||||
virtual void _process_all_external_received(char* externals) = 0;
|
||||
virtual void _process_input_keys(Key* keys, party_id_t from) = 0;
|
||||
virtual void _process_all_input_keys(char* keys) = 0;
|
||||
|
||||
virtual void store_garbled_circuit(ReceivedMsg& msg) = 0;
|
||||
virtual void _check_evaluate() = 0;
|
||||
|
||||
virtual void mask_output(ReceivedMsg& msg) = 0;
|
||||
|
||||
void done();
|
||||
|
||||
virtual void start_online_round() = 0;
|
||||
|
||||
virtual void receive_spdz_wires(ReceivedMsg& msg) = 0;
|
||||
};
|
||||
|
||||
class Party : public BaseParty, public CommonCircuitParty {
|
||||
friend class BooleanCircuit;
|
||||
|
||||
public:
|
||||
Party(const char* netmap_file, const char* circuit_file, party_id_t id, const std::string input, int numthreads=5, int numtries=2);
|
||||
virtual ~Party();
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
void NodeReady();
|
||||
void NewMessage(int from, char* message, unsigned int len);
|
||||
void NodeAborted(struct sockaddr_in* from) {}
|
||||
|
||||
void Start();
|
||||
|
||||
/* TEST methods */
|
||||
|
||||
private:
|
||||
party_id_t _id;
|
||||
Node* _node;
|
||||
BooleanCircuit* _circuit;
|
||||
|
||||
gate_id_t _G;
|
||||
party_id_t _N;
|
||||
wire_id_t _W;
|
||||
wire_id_t _OW;
|
||||
wire_id_t _IO;
|
||||
wire_id_t _IO;
|
||||
|
||||
std::string _all_input;
|
||||
char* _input;
|
||||
|
||||
char* _external_values_msg;
|
||||
unsigned int _external_values_msg_sz;
|
||||
int _NUMTHREADS;
|
||||
int _NUMTRIES;
|
||||
|
||||
vector<GarbledGate> _garbled_tbl;
|
||||
|
||||
vector<char> _input;
|
||||
|
||||
SendBuffer _external_values_msg;
|
||||
int _num_externals_msg_received;
|
||||
std::mutex _process_externals_mx;
|
||||
|
||||
char* _input_wire_keys_msg;
|
||||
unsigned int _input_wire_keys_msg_sz;
|
||||
SendBuffer _input_wire_keys_msg;
|
||||
int _num_inputkeys_msg_received;
|
||||
std::mutex _process_keys_mx;
|
||||
|
||||
std::mutex _sync_mx;
|
||||
|
||||
// int _num_evaluation_threads;
|
||||
struct timeval* _start_online_net, *_end_online_net;
|
||||
int _NUMTHREADS;
|
||||
int _NUMTRIES;
|
||||
#ifdef __PURE_SHE__
|
||||
Key* _sqr_keys;
|
||||
inline Key* _sqr_key(party_id_t i, wire_id_t w,int b) {return _sqr_keys+ w*2*_num_parties + b*_num_parties + i-1 ; }
|
||||
#endif
|
||||
|
||||
inline Key& _key(party_id_t i, wire_id_t w,int b) {return registers[w][b][i-1] ; }
|
||||
inline KeyVector& _garbled_entry(gate_id_t g, int entry) {return _garbled_tbl[g-1][entry];}
|
||||
vector<GarbledGate>::iterator get_garbled_tbl_end() { return _garbled_tbl.begin() + garbled_tbl_size; }
|
||||
void resize_garbled_tbl() { _garbled_tbl.resize(_G, _N); garbled_tbl_size = _G; }
|
||||
|
||||
void _initialize_input();
|
||||
void _generate_prf_inputs();
|
||||
void _allocate_prf_outputs();
|
||||
void _compute_prfs_outputs();
|
||||
void _print_prfs();
|
||||
void _send_prfs();
|
||||
void _compute_prfs_outputs(Key* keys);
|
||||
void _print_keys();
|
||||
void _printf_garbled_table();
|
||||
|
||||
void _allocate_external_values();
|
||||
void _generate_external_values_msg(char * masks);
|
||||
void _process_external_received(char* externals, party_id_t from);
|
||||
void _generate_external_values_msg();
|
||||
void _process_external_received(char* externals,
|
||||
party_id_t from);
|
||||
void _process_all_external_received(char* externals);
|
||||
inline void _allocate_input_wire_keys();
|
||||
void _print_input_keys_msg();
|
||||
void _print_keys_of_party(Key *keys, int id);
|
||||
void _print_input_keys_checksum();
|
||||
void _process_input_keys(Key* keys, party_id_t from);
|
||||
void _process_all_input_keys(char* keys);
|
||||
|
||||
void _print_input_keys_msg();
|
||||
void _print_keys_of_party(Key *keys, int id);
|
||||
void _printf_garbled_table();
|
||||
|
||||
void store_garbled_circuit(ReceivedMsg& msg);
|
||||
void load_garbled_circuit() {}
|
||||
|
||||
void _check_evaluate();
|
||||
|
||||
void receive_keys(Key* keys);
|
||||
|
||||
void receive_spdz_wires(ReceivedMsg& msg) { (void)msg; }
|
||||
|
||||
void start_online_round();
|
||||
|
||||
void mask_output(ReceivedMsg& msg);
|
||||
|
||||
int get_n_inputs();
|
||||
};
|
||||
|
||||
class ProgramParty : public BaseParty
|
||||
{
|
||||
friend class PRFRegister;
|
||||
friend class EvalRegister;
|
||||
friend class Register;
|
||||
|
||||
char* prf_output;
|
||||
Key* keys_for_prf;
|
||||
|
||||
deque<octetStream> spdz_wires[SPDZ_OP_N];
|
||||
size_t spdz_storage;
|
||||
size_t garbled_storage;
|
||||
vector<size_t> spdz_counters;
|
||||
|
||||
Worker<AndJob> eval_threads[N_EVAL_THREADS];
|
||||
AndJob and_jobs[N_EVAL_THREADS];
|
||||
|
||||
ReceivedMsgStore output_masks_store;
|
||||
|
||||
GC::Memory< GC::Secret<EvalRegister>::DynamicType > dynamic_memory;
|
||||
GC::Machine< GC::Secret<EvalRegister> > machine;
|
||||
GC::Processor<GC::Secret<EvalRegister> > processor;
|
||||
GC::Program<GC::Secret<EvalRegister> > program;
|
||||
|
||||
GC::Machine< GC::Secret<PRFRegister> > prf_machine;
|
||||
GC::Processor<GC::Secret<PRFRegister> > prf_processor;
|
||||
|
||||
void _compute_prfs_outputs(Key* keys);
|
||||
|
||||
void _process_external_received(char* externals,
|
||||
party_id_t from) { (void)externals; (void)from; }
|
||||
void _process_all_external_received(char* externals) { (void)externals; }
|
||||
void _process_input_keys(Key* keys, party_id_t from)
|
||||
{ (void)keys; (void)from; }
|
||||
void _process_all_input_keys(char* keys) { (void)keys; }
|
||||
|
||||
void store_garbled_circuit(ReceivedMsg& msg);
|
||||
void load_garbled_circuit();
|
||||
|
||||
void _check_evaluate();
|
||||
|
||||
void receive_keys(Register& reg);
|
||||
void receive_all_keys(Register& reg, bool external);
|
||||
|
||||
void receive_spdz_wires(ReceivedMsg& msg);
|
||||
|
||||
void start_online_round();
|
||||
|
||||
void mask_output(ReceivedMsg& msg) { output_masks_store.push(msg); }
|
||||
|
||||
public:
|
||||
static ProgramParty* singleton;
|
||||
|
||||
ReceivedMsg garbled_circuit;
|
||||
ReceivedMsgStore garbled_circuits;
|
||||
|
||||
ReceivedMsg output_masks;
|
||||
|
||||
MAC_Check<gf2n>* MC;
|
||||
Player* P;
|
||||
Names N;
|
||||
|
||||
int threshold;
|
||||
|
||||
static ProgramParty& s();
|
||||
|
||||
ProgramParty(int argc, char** argv);
|
||||
~ProgramParty();
|
||||
|
||||
void reset();
|
||||
|
||||
void input_value(party_id_t from, char value);
|
||||
|
||||
void get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire);
|
||||
|
||||
void store_wire(const Register& reg);
|
||||
void load_wire(Register& reg);
|
||||
};
|
||||
|
||||
inline ProgramParty& ProgramParty::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
#endif /* PROTOCOL_PARTY_H_ */
|
||||
|
||||
1094
BMR/Register.cpp
Normal file
1094
BMR/Register.cpp
Normal file
File diff suppressed because it is too large
Load Diff
394
BMR/Register.h
Normal file
394
BMR/Register.h
Normal file
@@ -0,0 +1,394 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Register.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOL_SRC_REGISTER_H_
|
||||
#define PROTOCOL_SRC_REGISTER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <stdint.h>
|
||||
using namespace std;
|
||||
|
||||
#include "Key.h"
|
||||
#include "Wire.h"
|
||||
#include "GC/Clear.h"
|
||||
#include "GC/Memory.h"
|
||||
#include "GC/Access.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
#ifndef FREE_XOR
|
||||
#warning not using free XOR has not been tested in a while
|
||||
#endif
|
||||
|
||||
typedef unsigned int party_id_t;
|
||||
|
||||
//#define PAD_TO_8(n) (n+8-n%8)
|
||||
#define PAD_TO_8(n) (n)
|
||||
|
||||
#ifdef N_PARTIES
|
||||
#define MAX_N_PARTIES N_PARTIES
|
||||
#endif
|
||||
|
||||
#ifdef MAX_N_PARTIES
|
||||
class BaseKeyVector
|
||||
{
|
||||
Key keys[MAX_N_PARTIES];
|
||||
public:
|
||||
Key& operator[](int i) { return keys[i]; }
|
||||
const Key& operator[](int i) const { return keys[i]; }
|
||||
Key* data() { return keys; }
|
||||
const Key* data() const { return keys; }
|
||||
#ifdef N_PARTIES
|
||||
BaseKeyVector(int n_parties = 0) { (void)n_parties; avx_memzero(keys, sizeof(keys)); }
|
||||
size_t size() const { return N_PARTIES; }
|
||||
void resize(int size) { (void)size; }
|
||||
#else
|
||||
BaseKeyVector(int n_parties = 0) : n_parties(n_parties) { memset(keys, 0, sizeof(keys)); }
|
||||
size_t size() const { return n_parties; }
|
||||
void resize(int size) { n_parties = size; }
|
||||
private:
|
||||
int n_parties;
|
||||
#endif
|
||||
};
|
||||
#else
|
||||
typedef vector<Key> BaseKeyVector;
|
||||
#endif
|
||||
|
||||
class KeyVector : public BaseKeyVector
|
||||
{
|
||||
public:
|
||||
KeyVector(int size = 0) : BaseKeyVector(size) {}
|
||||
size_t byte_size() const { return size() * sizeof(Key); }
|
||||
void operator=(const KeyVector& source);
|
||||
KeyVector operator^(const KeyVector& other) const;
|
||||
template <class T>
|
||||
void serialize_no_allocate(T& output) const { output.serialize_no_allocate(data(), byte_size()); }
|
||||
template <class T>
|
||||
void serialize(T& output) const { output.serialize(data(), byte_size()); }
|
||||
void unserialize(ReceivedMsg& source, int n_parties);
|
||||
friend ostream& operator<<(ostream& os, const KeyVector& kv);
|
||||
};
|
||||
|
||||
class GarbledGate;
|
||||
class CommonParty;
|
||||
|
||||
template <int I>
|
||||
class KeyTuple {
|
||||
friend class Register;
|
||||
|
||||
static long counter;
|
||||
|
||||
protected:
|
||||
KeyVector keys[I];
|
||||
|
||||
int part_size() { return keys[0].size() * sizeof(Key); }
|
||||
public:
|
||||
KeyTuple() {}
|
||||
KeyTuple(int n_parties) { init(n_parties); }
|
||||
void init(int n_parties);
|
||||
int byte_size() { return I * keys[0].byte_size(); }
|
||||
KeyVector& operator[](int i) { return keys[i]; }
|
||||
const KeyVector& operator[](int i) const { return keys[i]; }
|
||||
KeyTuple<I> operator^(const KeyTuple<I>& other) const;
|
||||
void copy_to(Key* dest);
|
||||
void unserialize(ReceivedMsg& source, int n_parties);
|
||||
void copy_from(Key* source, int n_parties, int except);
|
||||
template <class T>
|
||||
void serialize_no_allocate(T& output) const;
|
||||
template <class T>
|
||||
void serialize(T& output) const;
|
||||
template <class T>
|
||||
void serialize(T& output, party_id_t pid) const;
|
||||
void unserialize(vector<char>& output);
|
||||
template <class T>
|
||||
void unserialize(T& output);
|
||||
void randomize();
|
||||
void reset();
|
||||
void print(int wire_id) const;
|
||||
void print(int wire_id, party_id_t pid);
|
||||
};
|
||||
|
||||
namespace GC
|
||||
{
|
||||
class AuthValue;
|
||||
class Mask;
|
||||
class SpdzShare;
|
||||
template <class T>
|
||||
class Secret;
|
||||
}
|
||||
|
||||
class Register {
|
||||
protected:
|
||||
static int counter;
|
||||
|
||||
KeyVector garbled_entry;
|
||||
char external;
|
||||
|
||||
public:
|
||||
char mask;
|
||||
KeyTuple<2> keys; /* Additional data stored per per party per wire: */
|
||||
/* Total of n*W*2 keys
|
||||
* For every w={0,...,W}
|
||||
* For every b={0,1}
|
||||
* For every i={1...n}
|
||||
* k^i_{w,b}
|
||||
* This is helpful that the keys for specific w and b are adjacent
|
||||
* for pipelining matters.
|
||||
*/
|
||||
|
||||
Register(int n_parties);
|
||||
|
||||
void init(int n_parties);
|
||||
void init(int rfd, int n_parties);
|
||||
KeyVector& operator[](int i) { return keys[i]; }
|
||||
const Key& key(party_id_t i, int b) const { return keys[b][i-1]; }
|
||||
Key& key(party_id_t i, int b) { return keys[b][i-1]; }
|
||||
void set_eval_keys();
|
||||
void set_eval_keys(Key* keys, int n_parties, int except);
|
||||
const Key& external_key(party_id_t i) const { return garbled_entry[i-1]; }
|
||||
void set_external_key(party_id_t i, const Key& key);
|
||||
void reset_non_external_key(party_id_t i);
|
||||
void set_external(char ext);
|
||||
char get_external() const { check_external(); return external; }
|
||||
char get_external_no_check() const { return external; }
|
||||
void set_mask(char mask);
|
||||
int get_mask() const { check_mask(); return mask; }
|
||||
char get_mask_no_check() { return mask; }
|
||||
char get_output() { check_external(); check_mask(); return mask ^ external; }
|
||||
char get_output_no_check() { return mask ^ external; }
|
||||
const KeyVector& get_garbled_entry() const { return garbled_entry; }
|
||||
void print_input(int id);
|
||||
void print() const { keys.print(get_id()); }
|
||||
void check_external() const;
|
||||
void check_mask() const;
|
||||
void check_signal_key(int my_id, KeyVector& garbled_entry);
|
||||
|
||||
void eval(const Register& left, const Register& right, GarbledGate& gate,
|
||||
party_id_t my_id, char* prf_output, int, int, int);
|
||||
|
||||
void garble(const Register& left, const Register& right, Function func,
|
||||
Gate* gate, int g, vector<ReceivedMsg>& prf_outputs, SendBuffer& buffer);
|
||||
|
||||
size_t get_id() const { return (size_t)this; }
|
||||
|
||||
template <class T>
|
||||
void set_trace();
|
||||
};
|
||||
|
||||
|
||||
// this is to fake a "cout" that does nothing
|
||||
class BlackHole
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
BlackHole& operator<<(T) { return *this; }
|
||||
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
|
||||
};
|
||||
inline BlackHole& endl(BlackHole& b) { return b; }
|
||||
inline BlackHole& flush(BlackHole& b) { return b; }
|
||||
|
||||
class ProgramRegister : public Register
|
||||
{
|
||||
public:
|
||||
typedef BlackHole out_type;
|
||||
static const BlackHole out;
|
||||
|
||||
static Register new_reg();
|
||||
static Register tmp_reg() { return new_reg(); }
|
||||
static Register and_reg() { return new_reg(); }
|
||||
|
||||
static void check(const int128& value, word share, int128 mac)
|
||||
{ (void)value; (void)share; (void)mac; }
|
||||
static void get_dyn_mask(GC::Mask& mask, int length, int mac_length)
|
||||
{ (void)mask; (void)length; (void)mac_length; }
|
||||
template <class T>
|
||||
static void store_clear_in_dynamic(T& mem, const vector<GC::ClearWriteAccess>& accesses)
|
||||
{ (void)mem; (void)accesses; }
|
||||
static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share,
|
||||
word masked, int128 masked_mac)
|
||||
{ (void)dest; (void)mask_share; (void)mac_mask_share; (void)masked; (void)masked_mac; }
|
||||
|
||||
template<class T>
|
||||
static void store(GC::Memory<GC::SpdzShare>& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
|
||||
template<class T>
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const GC::Memory<GC::SpdzShare>& source) { (void)accesses; (void)source; }
|
||||
|
||||
template <class T>
|
||||
static void andrs(T& processor, const vector<int>& args) { processor.andrs(args); }
|
||||
|
||||
void input(party_id_t from, char value = -1) { (void)from; (void)value; }
|
||||
void public_input(bool value) { (void)value; }
|
||||
void random() {}
|
||||
char get_output() { return 0; }
|
||||
};
|
||||
|
||||
class FirstRoundRegister : public ProgramRegister
|
||||
{
|
||||
public:
|
||||
};
|
||||
|
||||
class SecondRoundRegister : public ProgramRegister
|
||||
{
|
||||
public:
|
||||
};
|
||||
|
||||
class PRFRegister : public FirstRoundRegister
|
||||
{
|
||||
public:
|
||||
static string name() { return "PRF"; }
|
||||
|
||||
template<class T>
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const GC::Memory<GC::SpdzShare>& source);
|
||||
|
||||
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
void input(party_id_t from, char input = -1);
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
};
|
||||
|
||||
class EvalRegister : public SecondRoundRegister
|
||||
{
|
||||
public:
|
||||
static string name() { return "Evaluation"; }
|
||||
|
||||
typedef ostream& out_type;
|
||||
static ostream& out;
|
||||
|
||||
static void check(const int128& value, word share, int128 mac);
|
||||
static void get_dyn_mask(GC::Mask& mask, int length, int mac_length);
|
||||
static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share,
|
||||
word masked, int128 masked_mac);
|
||||
|
||||
template<class T>
|
||||
static void store(GC::Memory<GC::SpdzShare>& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses);
|
||||
template<class T>
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const GC::Memory<GC::SpdzShare>& source);
|
||||
|
||||
template <class T>
|
||||
static void andrs(T& processor, const vector<int>& args);
|
||||
|
||||
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
unsigned long long get_output() { return Register::get_output(); }
|
||||
|
||||
template <class T>
|
||||
static void store_clear_in_dynamic(GC::Memory<T>& mem,
|
||||
const vector<GC::ClearWriteAccess>& accesses);
|
||||
void input(party_id_t from, char value = -1);
|
||||
};
|
||||
|
||||
class GarbleRegister : public SecondRoundRegister
|
||||
{
|
||||
public:
|
||||
static string name() { return "Garbling"; }
|
||||
|
||||
template<class T>
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const GC::Memory<GC::SpdzShare>& source);
|
||||
|
||||
void op(const Register& left, const Register& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output() {}
|
||||
};
|
||||
|
||||
class RandomRegister : public FirstRoundRegister
|
||||
{
|
||||
public:
|
||||
static string name() { return "Randomization"; }
|
||||
|
||||
template<class T>
|
||||
static void store(GC::Memory<GC::SpdzShare>& dest,
|
||||
const vector<GC::WriteAccess<T> >& accesses);
|
||||
template<class T>
|
||||
static void load(vector<GC::ReadAccess<T> >& accesses,
|
||||
const GC::Memory<GC::SpdzShare>& source);
|
||||
|
||||
void randomize();
|
||||
|
||||
void op(const Register& left, const Register& right, Function func);
|
||||
void XOR(const Register& left, const Register& right);
|
||||
|
||||
void input(party_id_t from, char value = -1);
|
||||
void public_input(bool value);
|
||||
void random();
|
||||
void output();
|
||||
};
|
||||
|
||||
|
||||
inline Register::Register(int n_parties) :
|
||||
garbled_entry(n_parties), external(NO_SIGNAL),
|
||||
mask(NO_SIGNAL), keys(n_parties)
|
||||
{
|
||||
}
|
||||
|
||||
inline void KeyVector::unserialize(ReceivedMsg& source, int n_parties)
|
||||
{
|
||||
resize(n_parties);
|
||||
source.unserialize(data(), size() * sizeof(Key));
|
||||
}
|
||||
|
||||
template <int I>
|
||||
inline void KeyTuple<I>::init(int n_parties) {
|
||||
for (int i = 0; i < I; i++)
|
||||
keys[i].resize(n_parties);
|
||||
}
|
||||
|
||||
template<int I>
|
||||
inline void KeyTuple<I>::reset()
|
||||
{
|
||||
for (int i = 0; i < I; i++)
|
||||
for (size_t j = 0; j < keys[i].size(); j++)
|
||||
keys[i][j] = 0;
|
||||
}
|
||||
|
||||
template <int I>
|
||||
inline void KeyTuple<I>::unserialize(ReceivedMsg& source, int n_parties) {
|
||||
for (int b = 0; b < I; b++)
|
||||
keys[b].unserialize(source, n_parties);
|
||||
}
|
||||
|
||||
template<int I> template <class T>
|
||||
void KeyTuple<I>::serialize_no_allocate(T& output) const {
|
||||
for (int i = 0; i < I; i++)
|
||||
keys[i].serialize_no_allocate(output);
|
||||
}
|
||||
|
||||
template<int I> template <class T>
|
||||
void KeyTuple<I>::serialize(T& output) const {
|
||||
for (int i = 0; i < I; i++)
|
||||
for (unsigned int j = 0; j < keys[i].size(); j++)
|
||||
keys[i][j].serialize(output);
|
||||
}
|
||||
|
||||
template<int I> template <class T>
|
||||
void KeyTuple<I>::serialize(T& output, party_id_t pid) const {
|
||||
for (int i = 0; i < I; i++)
|
||||
keys[i][pid - 1].serialize(output);
|
||||
}
|
||||
|
||||
template<int I> template <class T>
|
||||
void KeyTuple<I>::unserialize(T& output) {
|
||||
for (int i = 0; i < I; i++)
|
||||
for (unsigned int j = 0; j < keys[i].size(); j++)
|
||||
output.unserialize(keys[i][j]);
|
||||
}
|
||||
|
||||
#endif /* PROTOCOL_SRC_REGISTER_H_ */
|
||||
20
BMR/Register_inline.h
Normal file
20
BMR/Register_inline.h
Normal file
@@ -0,0 +1,20 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Register_inline.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef BMR_REGISTER_INLINE_H_
|
||||
#define BMR_REGISTER_INLINE_H_
|
||||
|
||||
#include "CommonParty.h"
|
||||
#include "Party.h"
|
||||
|
||||
|
||||
inline Register ProgramRegister::new_reg()
|
||||
{
|
||||
return Register(CommonParty::s().get_n_parties());
|
||||
}
|
||||
|
||||
#endif /* BMR_REGISTER_INLINE_H_ */
|
||||
26
BMR/SpdzWire.cpp
Normal file
26
BMR/SpdzWire.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* SpdzWire.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "SpdzWire.h"
|
||||
|
||||
SpdzWire::SpdzWire()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void SpdzWire::pack(octetStream& os) const
|
||||
{
|
||||
mask.pack(os);
|
||||
os.serialize(my_keys);
|
||||
}
|
||||
|
||||
void SpdzWire::unpack(octetStream& os, size_t wanted_size)
|
||||
{
|
||||
(void)wanted_size;
|
||||
mask.unpack(os);
|
||||
os.unserialize(my_keys);
|
||||
}
|
||||
25
BMR/SpdzWire.h
Normal file
25
BMR/SpdzWire.h
Normal file
@@ -0,0 +1,25 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* SpdzWire.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef BMR_SPDZWIRE_H_
|
||||
#define BMR_SPDZWIRE_H_
|
||||
|
||||
#include "Math/Share.h"
|
||||
#include "Key.h"
|
||||
|
||||
class SpdzWire
|
||||
{
|
||||
public:
|
||||
Share<gf2n> mask;
|
||||
Key my_keys[2];
|
||||
|
||||
SpdzWire();
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os, size_t wanted_size);
|
||||
};
|
||||
|
||||
#endif /* BMR_SPDZWIRE_H_ */
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* TrustedParty.cpp
|
||||
*
|
||||
* Created on: Feb 15, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#include "TrustedParty.h"
|
||||
@@ -16,9 +16,25 @@
|
||||
|
||||
#include "proto_utils.h"
|
||||
#include "msg_types.h"
|
||||
#include "SpdzWire.h"
|
||||
#include "Auth/fake-stuff.h"
|
||||
|
||||
TrustedProgramParty* TrustedProgramParty::singleton = 0;
|
||||
|
||||
|
||||
|
||||
BaseTrustedParty::BaseTrustedParty()
|
||||
{
|
||||
#ifdef __PURE_SHE__
|
||||
init_modulos();
|
||||
init_temp_mpz_t(_temp_mpz);
|
||||
std::cout << "_temp_mpz: " << _temp_mpz << std::endl;
|
||||
#endif
|
||||
_num_prf_received = 0;
|
||||
_received_gc_received = 0;
|
||||
n_received = 0;
|
||||
randomfd = open("/dev/urandom", O_RDONLY);
|
||||
}
|
||||
|
||||
TrustedParty::TrustedParty(const char* netmap_file, // required to init Node
|
||||
const char* circuit_file // required to init BooleanCircuit
|
||||
@@ -26,126 +42,231 @@ TrustedParty::TrustedParty(const char* netmap_file, // required to init Node
|
||||
{
|
||||
_circuit = new BooleanCircuit( circuit_file );
|
||||
_G = _circuit->NumGates();
|
||||
#ifndef N_PARTIES
|
||||
_N = _circuit->NumParties();
|
||||
#endif
|
||||
_W = _circuit->NumWires();
|
||||
_OW = _circuit->NumOutWires();
|
||||
#ifdef __PURE_SHE__
|
||||
init_modulos();
|
||||
init_temp_mpz_t(_temp_mpz);
|
||||
std::cout << "_temp_mpz: " << _temp_mpz << std::endl;
|
||||
#endif
|
||||
_allocate_prf_outputs();
|
||||
_allocate_garbled_table();
|
||||
_num_prf_received = 0;
|
||||
_received_gc_received = 0;
|
||||
if (0 == strcmp(netmap_file, LOOPBACK_STR)) {
|
||||
_node = new Node( NULL, 0, this , _N+1);
|
||||
} else {
|
||||
_node = new Node( netmap_file, 0, this );
|
||||
}
|
||||
reset();
|
||||
garbled_tbl_size = _G;
|
||||
init(netmap_file, 0, _N);
|
||||
}
|
||||
|
||||
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :
|
||||
machine(dynamic_memory), processor(machine),
|
||||
random_machine(dynamic_memory), random_processor(random_machine)
|
||||
{
|
||||
if (argc < 2)
|
||||
{
|
||||
cerr << "Usage: " << argv[0] << " <program> [netmap]" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
ifstream file((string("Programs/Bytecode/") + argv[1] + "-0.bc").c_str());
|
||||
program.parse(file);
|
||||
processor.reset(program);
|
||||
machine.reset(program);
|
||||
random_processor.reset(program.cast< GC::Secret<RandomRegister> >());
|
||||
random_machine.reset(program.cast< GC::Secret<RandomRegister> >());
|
||||
if (singleton)
|
||||
throw runtime_error("there can only be one");
|
||||
singleton = this;
|
||||
if (argc == 3)
|
||||
init(argv[2], 0);
|
||||
else
|
||||
init("LOOPBACK", 0);
|
||||
#ifdef FREE_XOR
|
||||
deltas.resize(_N);
|
||||
for (size_t i = 0; i < _N; i++)
|
||||
{
|
||||
deltas[i] = prng.get_doubleword();
|
||||
#ifdef DEBUG
|
||||
deltas[i] = Key(i + 1, 0);
|
||||
#endif
|
||||
#ifdef KEY_SIGNAL
|
||||
if (deltas[i].get_signal() == 0)
|
||||
deltas[i] ^= Key(1);
|
||||
#endif
|
||||
cout << "Delta " << i << ": " << deltas[i] << endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TrustedProgramParty::~TrustedProgramParty()
|
||||
{
|
||||
cout << "Random timer: " << random_timer.elapsed() << endl;
|
||||
}
|
||||
|
||||
TrustedParty::~TrustedParty() {
|
||||
// TODO Auto-generated destructor stub
|
||||
}
|
||||
|
||||
void TrustedParty::NodeReady()
|
||||
void BaseTrustedParty::NodeReady()
|
||||
{
|
||||
#ifdef DEBUG_STEPS
|
||||
printf("\n\nNode ready \n\n");
|
||||
sleep(1);
|
||||
#endif
|
||||
//sleep(1);
|
||||
prepare_randomness();
|
||||
send_randomness();
|
||||
prf_outputs.resize(get_n_parties());
|
||||
}
|
||||
|
||||
_generate_masks();
|
||||
void BaseTrustedParty::prepare_randomness()
|
||||
{
|
||||
msg_keys.resize(_N);
|
||||
for (size_t i = 0; i < msg_keys.size(); i++)
|
||||
{
|
||||
msg_keys[i].clear();
|
||||
fill_message_type(msg_keys[i], TYPE_KEYS);
|
||||
msg_keys[i].resize(MSG_KEYS_HEADER_SZ);
|
||||
}
|
||||
|
||||
unsigned int number_of_keys = 2* _W * _N;
|
||||
unsigned int size_of_keys = number_of_keys*sizeof(Key);
|
||||
unsigned int msg_keys_size = size_of_keys + MSG_KEYS_HEADER_SZ;
|
||||
|
||||
Key* all_keys = new Key[number_of_keys];
|
||||
memset(all_keys, 0, size_of_keys);
|
||||
#ifdef __PURE_SHE__
|
||||
_circuit->_sqr_keys = new Key[number_of_keys];
|
||||
memset(_circuit->_sqr_keys, 0, size_of_keys);
|
||||
#endif
|
||||
|
||||
for(party_id_t pid=1; pid<=_N; pid++)
|
||||
{
|
||||
/* generating and sending keys */
|
||||
char* msg_keys = new char[msg_keys_size];
|
||||
memset(msg_keys, 0, msg_keys_size);
|
||||
fill_message_type(msg_keys, TYPE_KEYS);
|
||||
Key* party_keys = (Key*)(msg_keys + MSG_KEYS_HEADER_SZ);
|
||||
#ifdef __PURE_SHE__
|
||||
_fill_keys_for_party(_circuit->_sqr_keys, party_keys, pid);
|
||||
#else
|
||||
_fill_keys_for_party(party_keys, pid);
|
||||
done_filling = _fill_keys();
|
||||
#endif
|
||||
// printf("keys for party %u\n", pid);
|
||||
// phex(party_keys, size_of_keys);
|
||||
|
||||
_merge_keys(all_keys, party_keys);
|
||||
}
|
||||
|
||||
void BaseTrustedParty::send_randomness()
|
||||
{
|
||||
for(party_id_t pid=1; pid<=_N; pid++)
|
||||
{
|
||||
// printf("all keys\n");
|
||||
// phex(all_keys, size_of_keys);
|
||||
_node->Send(pid, msg_keys, msg_keys_size);
|
||||
|
||||
/* sending masks for input wires */
|
||||
party_t party = _circuit->Party(pid);
|
||||
char* msg_input_masks = new char[sizeof(MSG_TYPE) + party.n_wires];
|
||||
fill_message_type(msg_input_masks, TYPE_MASK_INPUTS);
|
||||
memcpy(msg_input_masks+sizeof(MSG_TYPE), _circuit->Masks()+party.wires, party.n_wires);
|
||||
_node->Send(pid, msg_input_masks, sizeof(MSG_TYPE) + party.n_wires);
|
||||
// TODO: test only
|
||||
// printf("input masks for party %d\n", pid);
|
||||
// phex(msg_input_masks + sizeof(MSG_TYPE) , party.n_wires);
|
||||
_node->Send(pid, msg_keys[pid - 1]);
|
||||
// printf("msg keys\n");
|
||||
// phex(msg_keys, msg_keys_size);
|
||||
send_input_masks(pid);
|
||||
}
|
||||
|
||||
_circuit->Keys(all_keys);
|
||||
send_output_masks();
|
||||
}
|
||||
|
||||
void TrustedParty::send_input_masks(party_id_t pid)
|
||||
{
|
||||
prepare_input_regs(pid);
|
||||
/* sending masks for input wires */
|
||||
msg_input_masks.resize(get_n_parties());
|
||||
SendBuffer& buffer = msg_input_masks[pid-1];
|
||||
buffer.clear();
|
||||
fill_message_type(buffer, TYPE_MASK_INPUTS);
|
||||
for (auto input_regs = input_regs_queue.begin();
|
||||
input_regs != input_regs_queue.end(); input_regs++)
|
||||
{
|
||||
int n_wires = (*input_regs)[pid].size();
|
||||
#ifdef DEBUG_ROUNDS
|
||||
cout << dec << n_wires << " inputs from " << pid << endl;
|
||||
#endif
|
||||
for (int i = 0; i < n_wires; i++)
|
||||
buffer.push_back(registers[(*input_regs)[pid][i]].get_mask());
|
||||
#ifdef DEBUG2
|
||||
printf("input masks for party %d\n", pid);
|
||||
phex(buffer);
|
||||
#endif
|
||||
#ifdef DEBUG_VALUES
|
||||
printf("input masks for party %d:\t", pid);
|
||||
print_masks((*input_regs)[pid]);
|
||||
cout << "on registers:" << endl;
|
||||
print_indices((*input_regs)[pid]);
|
||||
#endif
|
||||
}
|
||||
#ifdef DEBUG_ROUNDS
|
||||
cout << "sending " << dec << buffer.size() - 4 << " input masks" << endl;
|
||||
#endif
|
||||
_node->Send(pid, buffer);
|
||||
}
|
||||
|
||||
void TrustedParty::send_output_masks()
|
||||
{
|
||||
prepare_output_regs();
|
||||
/* output wires' masks are the same for all players */
|
||||
/* sending masks for output wires */
|
||||
|
||||
char* msg_output_masks = new char[sizeof(MSG_TYPE) + _OW];
|
||||
fill_message_type(msg_output_masks, TYPE_MASK_OUTPUT);
|
||||
memcpy(msg_output_masks + sizeof(MSG_TYPE), _circuit->Masks()+_circuit->OutWiresStart(), _OW);
|
||||
_node->Broadcast(msg_output_masks, sizeof(MSG_TYPE) + _OW);
|
||||
// TODO: test only
|
||||
// printf("output masks\n");
|
||||
// phex(msg_output_masks+ sizeof(MSG_TYPE) , _OW);
|
||||
int _OW = output_regs.size();
|
||||
SendBuffer& msg_output_masks = get_buffer(TYPE_MASK_OUTPUT);
|
||||
for (int i = 0; i < _OW; i++)
|
||||
msg_output_masks.push_back(get_reg(output_regs[i]).get_mask());
|
||||
_node->Broadcast(msg_output_masks);
|
||||
#ifdef DEBUG2
|
||||
printf("output masks\n");
|
||||
phex(msg_output_masks);
|
||||
#endif
|
||||
#ifdef DEBUG_VALUES
|
||||
printf("output masks:\t\t\t");
|
||||
print_masks(output_regs);
|
||||
cout << "on registers:" << endl;
|
||||
print_indices(output_regs);
|
||||
#endif
|
||||
}
|
||||
|
||||
void TrustedParty::NewMessage(int from, char* message, unsigned int len)
|
||||
void TrustedProgramParty::send_output_masks()
|
||||
{
|
||||
#ifdef DEBUG_OUTPUT_MASKS
|
||||
cout << "sending " << msg_output_masks.size() - 4 << " output masks" << endl;
|
||||
#endif
|
||||
_node->Broadcast(msg_output_masks);
|
||||
}
|
||||
|
||||
void BaseTrustedParty::NewMessage(int from, ReceivedMsg& msg)
|
||||
{
|
||||
char* message = msg.data();
|
||||
int len = msg.size();
|
||||
MSG_TYPE message_type;
|
||||
memcpy(&message_type, message, sizeof(MSG_TYPE));
|
||||
msg.unserialize(message_type);
|
||||
unique_lock<mutex> locker(global_lock);
|
||||
switch(message_type) {
|
||||
case TYPE_PRF_OUTPUTS:
|
||||
{
|
||||
#ifdef DEBUG
|
||||
cout << "TYPE_PRF_OUTPUTS" << endl;
|
||||
#endif
|
||||
_print_mx.lock();
|
||||
// printf("got message of len %u from %d\n", len, from);
|
||||
#ifdef DEBUG2
|
||||
printf("got message of len %u from %d\n", len, from);
|
||||
phex(message, len);
|
||||
cout << "garbled table size " << get_garbled_tbl_size() << endl;
|
||||
#endif
|
||||
#ifdef DEBUG_STEPS
|
||||
printf("\n Got prfs from %d\n",from);
|
||||
#endif
|
||||
|
||||
char* party_prfs = _circuit->Prfs() + (PRFS_PER_PARTY(_G, _N)*sizeof(Key)) *(from-1) ;
|
||||
memcpy(party_prfs, message + sizeof(MSG_TYPE), PRFS_PER_PARTY(_G, _N)*sizeof(Key));
|
||||
// phex(party_prfs, PRFS_PER_PARTY(G, N)*sizeof(Key));
|
||||
prf_outputs[from-1] = msg;
|
||||
_print_mx.unlock();
|
||||
|
||||
_num_prf_received ++;
|
||||
if(_num_prf_received == _N) {
|
||||
if(++_num_prf_received == _N) {
|
||||
_num_prf_received = 0;
|
||||
_compute_send_garbled_circuit();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TYPE_RECEIVED_GC:
|
||||
{
|
||||
_received_gc_received++;
|
||||
if(_received_gc_received == _N) {
|
||||
_launch_online();
|
||||
if(++_received_gc_received == _N) {
|
||||
_received_gc_received = 0;
|
||||
if (done_filling)
|
||||
_launch_online();
|
||||
else
|
||||
NodeReady();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TYPE_NEXT:
|
||||
if (++n_received == _N)
|
||||
{
|
||||
n_received = 0;
|
||||
send_randomness();
|
||||
}
|
||||
break;
|
||||
case TYPE_DONE:
|
||||
if (++n_received == _N)
|
||||
_node->Stop();
|
||||
break;
|
||||
default:
|
||||
{
|
||||
_print_mx.lock();
|
||||
@@ -155,6 +276,7 @@ void TrustedParty::NewMessage(int from, char* message, unsigned int len)
|
||||
phex(message, len);
|
||||
_print_mx.unlock();
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -163,155 +285,57 @@ void TrustedParty::_launch_online()
|
||||
{
|
||||
printf("press to launch online\n");
|
||||
getchar();
|
||||
char* launch_msg = new char[sizeof(MSG_TYPE)];
|
||||
fill_message_type(launch_msg, TYPE_LAUNCH_ONLINE);
|
||||
_node->Broadcast(launch_msg, sizeof(MSG_TYPE));
|
||||
_node->Broadcast(get_buffer(TYPE_LAUNCH_ONLINE));
|
||||
printf("launched\n");
|
||||
}
|
||||
|
||||
void TrustedParty::_allocate_garbled_table()
|
||||
void TrustedProgramParty::_launch_online()
|
||||
{
|
||||
unsigned int garbled_table_sz = _G*4*_N*sizeof(Key)+RESERVE_FOR_MSG_TYPE;
|
||||
_circuit->_garbled_tbl = (Key*)malloc(garbled_table_sz);
|
||||
memset(_circuit->_garbled_tbl,0,garbled_table_sz);
|
||||
|
||||
_node->Broadcast(get_buffer(TYPE_LAUNCH_ONLINE));
|
||||
}
|
||||
|
||||
void TrustedParty::_compute_send_garbled_circuit()
|
||||
void TrustedParty::garble()
|
||||
{
|
||||
|
||||
Key* tbl_start = _circuit->_garbled_tbl+(RESERVE_FOR_MSG_TYPE/sizeof(Key));
|
||||
unsigned int prfs_per_party_sz = PRFS_PER_PARTY(_G,_N)*sizeof(Key);
|
||||
std::vector<Gate>* gates = &_circuit->_gates;
|
||||
char* masks = _circuit->_masks;
|
||||
|
||||
for(gate_id_t g=1; g<=_G; g++) {
|
||||
// std::cout << "garbling gate " << g << std::endl ;
|
||||
Key* gg_A = tbl_start + GARBLED_GATE_SIZE(_N)*(g-1);
|
||||
Key* gg_B = gg_A + _N;
|
||||
Key* gg_C = gg_A + 2*_N;
|
||||
Key* gg_D = gg_A + 3*_N;
|
||||
unsigned int prfs_left_offset = (g-1)*8*_N*sizeof(Key);
|
||||
unsigned int prfs_right_offset = prfs_left_offset + 4*_N*sizeof(Key);
|
||||
|
||||
for(party_id_t i=1; i<=_N; i++) {
|
||||
// std::cout << "adding prfs of party " << i << std::endl ;
|
||||
char* party_prfs = _circuit->Prfs() + (i-1)*prfs_per_party_sz;
|
||||
Key left_i_j, right_i_j;
|
||||
for (party_id_t j=1; j<=_N; j++) {
|
||||
//A
|
||||
// std::cout << "A" << std::endl;
|
||||
left_i_j = *(Key*)(party_prfs+prfs_left_offset + (j-1)*sizeof(Key));
|
||||
right_i_j = *(Key*)(party_prfs+prfs_right_offset + (j-1)*sizeof(Key));
|
||||
// cout << *(gg_A+j-1) << std::endl;
|
||||
// cout << left_i_j << std::endl;
|
||||
// cout << right_i_j << std::endl;
|
||||
*(gg_A+j-1) += left_i_j; //left wire of party i in part j
|
||||
*(gg_A+j-1) += right_i_j; //right wire of party i in part j
|
||||
// cout << *(gg_A+j-1) << std::endl<< std::endl;
|
||||
|
||||
//B
|
||||
// std::cout << "B" << std::endl;
|
||||
left_i_j = *(Key*)(party_prfs+prfs_left_offset + _N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
right_i_j = *(Key*)(party_prfs+prfs_right_offset + 2*_N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
// cout << *(gg_B+j-1) << std::endl;
|
||||
// cout << left_i_j << std::endl;
|
||||
// cout << right_i_j << std::endl;
|
||||
*(gg_B+j-1) += left_i_j; //left wire of party i in part j
|
||||
*(gg_B+j-1) += right_i_j; //right wire of party i in part j
|
||||
// cout << *(gg_B+j-1) << std::endl<< std::endl;
|
||||
|
||||
//C
|
||||
// std::cout << "C" << std::endl;
|
||||
left_i_j = *(Key*)(party_prfs+prfs_left_offset + 2*_N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
right_i_j = *(Key*)(party_prfs+prfs_right_offset + _N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
// cout << *(gg_C+j-1) << std::endl;
|
||||
// cout << left_i_j << std::endl;
|
||||
// cout << right_i_j << std::endl;
|
||||
*(gg_C+j-1) += left_i_j; //left wire of party i in part j
|
||||
*(gg_C+j-1) += right_i_j; //right wire of party i in part j
|
||||
// cout << *(gg_C+j-1) << std::endl<< std::endl;
|
||||
|
||||
//D
|
||||
// std::cout << "D" << std::endl;
|
||||
left_i_j = *(Key*)(party_prfs+prfs_left_offset + 3*_N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
right_i_j = *(Key*)(party_prfs+prfs_right_offset + 3*_N*sizeof(Key) + (j-1)*sizeof(Key));
|
||||
// cout << *(gg_D+j-1) << std::endl;
|
||||
// cout << left_i_j << std::endl;
|
||||
// cout << right_i_j << std::endl;
|
||||
*(gg_D+j-1) += left_i_j; //left wire of party i in part j
|
||||
*(gg_D+j-1) += right_i_j; //right wire of party i in part j
|
||||
// cout << *(gg_D+j-1) << std::endl<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
//Adding the hidden keys
|
||||
Gate* gate = &gates->at(g);
|
||||
char maskl = masks[gate->_left];
|
||||
char maskr = masks[gate->_right];
|
||||
char masko = masks[gate->_out];
|
||||
// printf("\ngate %u, leftwire=%u, rightwire=%u, outwire=%u: func=%d%d%d%d, msk_l=%d, msk_r=%d, msk_o=%d\n"
|
||||
// , g,gate->_left, gate->_right, gate->_out
|
||||
// ,gate->_func[0],gate->_func[1],gate->_func[2],gate->_func[3], maskl, maskr, masko);
|
||||
|
||||
// printf("\n");
|
||||
// printf("maskl=%d, maskr=%d, masko=%d\n",maskl,maskr,masko);
|
||||
// printf("gate func = %d%d%d%d\n",gate->_func[0],gate->_func[1],gate->_func[2],gate->_func[3]);
|
||||
bool xa = gate->_func[2*maskl+maskr] != masko;
|
||||
bool xb = gate->_func[2*maskl+(1-maskr)] != masko;
|
||||
bool xc = gate->_func[2*(1-maskl)+maskr] != masko;
|
||||
bool xd = gate->_func[2*(1-maskl)+(1-maskr)] != masko;
|
||||
// printf("xa=%d, xb=%d, xc=%d, xd=%d\n", xa,xb,xc,xd);
|
||||
|
||||
// these are the 0-keys
|
||||
#ifdef __PURE_SHE__
|
||||
Key* outwire_start = _circuit->_sqr_keys + gate->_out*2*_N;
|
||||
#else
|
||||
Key* outwire_start = _circuit->_keys + gate->_out*2*_N;
|
||||
#ifdef DEBUG
|
||||
std::cout << "garbling gate " << g << std::endl ;
|
||||
#endif
|
||||
Key* keyxa = outwire_start + (xa?_N:0);
|
||||
Key* keyxb = outwire_start + (xb?_N:0);
|
||||
Key* keyxc = outwire_start + (xc?_N:0);
|
||||
Key* keyxd = outwire_start + (xd?_N:0);
|
||||
|
||||
for(party_id_t i=1; i<=_N; i++) {
|
||||
// std::cout << "adding to A = " << keyxa[i-1] << std::endl;
|
||||
// std::cout << "adding to B = " << keyxb[i-1] << std::endl;
|
||||
// std::cout << "adding to C = " << keyxc[i-1] << std::endl;
|
||||
// std::cout << "adding to D = " << keyxd[i-1] << std::endl;
|
||||
*(gg_A+i-1) += keyxa[i-1];
|
||||
*(gg_B+i-1) += keyxb[i-1];
|
||||
*(gg_C+i-1) += keyxc[i-1];
|
||||
*(gg_D+i-1) += keyxd[i-1];
|
||||
}
|
||||
Gate& gate = _circuit->_gates[g];
|
||||
registers[gate._out].garble(registers[gate._left], registers[gate._right],
|
||||
gate._func, &gate, g, prf_outputs, buffers[TYPE_GARBLED_CIRCUIT]);
|
||||
}
|
||||
|
||||
//sending to parties:
|
||||
fill_message_type(((char*)tbl_start)-4, TYPE_GARBLED_CIRCUIT );
|
||||
_node->Broadcast( ((char*)tbl_start)-4 , _G*4*_N*sizeof(Key)+sizeof(MSG_TYPE));
|
||||
|
||||
}
|
||||
|
||||
void TrustedParty::Start()
|
||||
void BaseTrustedParty::_compute_send_garbled_circuit()
|
||||
{
|
||||
SendBuffer& buffer = get_buffer(TYPE_GARBLED_CIRCUIT );
|
||||
buffer.allocate(get_garbled_tbl_size() * 4 * get_n_parties() * sizeof(Key));
|
||||
garble();
|
||||
//sending to parties:
|
||||
#ifdef DEBUG
|
||||
cout << "sending garbled circuit" << endl;
|
||||
#endif
|
||||
#ifdef DEBUG2
|
||||
phex(buffer);
|
||||
#endif
|
||||
_node->Broadcast(buffer);
|
||||
|
||||
//prepare_randomness();
|
||||
}
|
||||
|
||||
void BaseTrustedParty::Start()
|
||||
{
|
||||
_node->Start();
|
||||
}
|
||||
|
||||
/* keys - a 2*W*n keys buffer to be filled only in the places belong to party pid */
|
||||
void TrustedParty::_fill_keys_for_party(Key* keys, party_id_t pid)
|
||||
bool TrustedParty::_fill_keys()
|
||||
{
|
||||
int nullfd = open("/dev/urandom", O_RDONLY);
|
||||
|
||||
resize_registers();
|
||||
for (wire_id_t w=0; w<_W; w++) {
|
||||
read(nullfd, (char*)(keys+pid-1), sizeof(Key));
|
||||
read(nullfd, (char*)(keys+_N+pid-1), sizeof(Key));
|
||||
#ifdef __PRIME_FIELD__
|
||||
keys[pid-1].adjust();
|
||||
keys[pid+_N-1].adjust();
|
||||
#endif
|
||||
keys = keys + 2*_N;
|
||||
registers[w].init(randomfd, _N);
|
||||
add_keys(registers[w]);
|
||||
}
|
||||
close(nullfd);
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef __PURE_SHE__
|
||||
@@ -335,56 +359,98 @@ void TrustedParty::_fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Merge the two Key buffers into the dest buffer */
|
||||
void TrustedParty::_merge_keys(Key* dest, Key* src)
|
||||
{
|
||||
for (wire_id_t w=0; w<_W; w++) {
|
||||
for (int i=0; i<2; i++) {
|
||||
for (party_id_t p=0; p<_N; p++) {
|
||||
int offset = w*2*_N+i*_N+p;
|
||||
Key kk;
|
||||
kk = *(src + offset);
|
||||
dest[offset] += kk;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TrustedParty::_generate_masks ()
|
||||
{
|
||||
char* masks = new char[_W];
|
||||
fill_random(masks, _W);
|
||||
|
||||
//need to convert from total random to 0/1
|
||||
for (unsigned int i=0 ; i<_W; i++)
|
||||
{
|
||||
//because it is a SIGNED char ~half the chars whould be negetives
|
||||
masks[i] = masks[i]>0 ? 1 : 0;
|
||||
}
|
||||
_circuit->Masks(masks);
|
||||
// printf("masks\n");
|
||||
// phex(_circuit->Masks(), _W);
|
||||
}
|
||||
|
||||
|
||||
void TrustedParty::_allocate_prf_outputs()
|
||||
{
|
||||
unsigned int prf_output_size = (PRFS_PER_PARTY(_G, _N)*sizeof(Key)) *_N ;
|
||||
void* prf_outputs = malloc(prf_output_size);
|
||||
memset(prf_outputs,0,prf_output_size);
|
||||
_circuit->Prfs((char*)prf_outputs);
|
||||
}
|
||||
|
||||
void TrustedParty::_print_keys()
|
||||
{
|
||||
Key* key_idx = _circuit->_key(1,0,0);
|
||||
for (wire_id_t w=0; w<_W; w++) {
|
||||
for(int b=0; b<=1; b++) {
|
||||
for (party_id_t i=1; i<=_N; i++) {
|
||||
printf("k^%d_{%u,%d}: ",i,w,b);
|
||||
std::cout << *key_idx << std::endl;
|
||||
key_idx++;
|
||||
}
|
||||
registers[w].keys.print(w);
|
||||
}
|
||||
}
|
||||
|
||||
void TrustedProgramParty::NodeReady()
|
||||
{
|
||||
#ifdef FREE_XOR
|
||||
for (int i = 0; i < get_n_parties(); i++)
|
||||
{
|
||||
SendBuffer& buffer = get_buffer(TYPE_DELTA);
|
||||
buffer.serialize(deltas[i]);
|
||||
_node->Send(i + 1, buffer);
|
||||
}
|
||||
#endif
|
||||
this->BaseTrustedParty::NodeReady();
|
||||
}
|
||||
|
||||
bool TrustedProgramParty::_fill_keys()
|
||||
{
|
||||
for (int i = 0; i < SPDZ_OP_N; i++)
|
||||
{
|
||||
spdz_wires[i].clear();
|
||||
spdz_wires[i].resize(get_n_parties());
|
||||
}
|
||||
msg_output_masks = get_buffer(TYPE_MASK_OUTPUT);
|
||||
return GC::DONE_BREAK == first_phase(program, random_processor, random_machine);
|
||||
}
|
||||
|
||||
void TrustedProgramParty::garble()
|
||||
{
|
||||
second_phase(program, processor, machine);
|
||||
|
||||
vector< Share<gf2n> > tmp;
|
||||
make_share<gf2n>(tmp, 1, get_n_parties(), mac_key, prng);
|
||||
for (int i = 0; i < get_n_parties(); i++)
|
||||
tmp[i].get_mac().pack(spdz_wires[SPDZ_MAC][i]);
|
||||
for (int i = 0; i < get_n_parties(); i++)
|
||||
{
|
||||
for (int j = 0; j < SPDZ_OP_N; j++)
|
||||
{
|
||||
SendBuffer buffer;
|
||||
fill_message_type(buffer, TYPE_SPDZ_WIRES);
|
||||
buffer.serialize(j);
|
||||
buffer.serialize(spdz_wires[j][i].get_data(), spdz_wires[j][i].get_length());
|
||||
#ifdef DEBUG_SPDZ_WIRE
|
||||
cout << "send " << spdz_wires[j][i].get_length() << "/" << buffer.size()
|
||||
<< " bytes for type " << j << " to " << i << endl;
|
||||
#endif
|
||||
_node->Send(i + 1, buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TrustedProgramParty::store_spdz_wire(SpdzOp op, const Register& reg)
|
||||
{
|
||||
make_share(mask_shares, gf2n(reg.get_mask()), get_n_parties(), gf2n(get_mac_key()), prng);
|
||||
for (int i = 0; i < get_n_parties(); i++)
|
||||
{
|
||||
SpdzWire wire;
|
||||
wire.mask = mask_shares[i];
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
wire.my_keys[j] = reg.keys[j][i];
|
||||
}
|
||||
wire.pack(spdz_wires[op][i]);
|
||||
}
|
||||
#ifdef DEBUG_SPDZ_WIRE
|
||||
cout << "stored SPDZ wire of type " << op << ":" << endl;
|
||||
reg.keys.print(reg.get_id());
|
||||
#endif
|
||||
}
|
||||
|
||||
void TrustedProgramParty::store_wire(const Register& reg)
|
||||
{
|
||||
wires.serialize(reg.mask);
|
||||
reg.keys.serialize(wires);
|
||||
#ifdef DEBUG
|
||||
cout << "storing wire" << endl;
|
||||
reg.print();
|
||||
#endif
|
||||
}
|
||||
|
||||
void TrustedProgramParty::load_wire(Register& reg)
|
||||
{
|
||||
wires.unserialize(reg.mask);
|
||||
reg.keys.unserialize(wires);
|
||||
#ifdef DEBUG
|
||||
cout << "loading wire" << endl;
|
||||
reg.print();
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* TrustedParty.h
|
||||
*
|
||||
* Created on: Feb 15, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOL_TRUSTEDPARTY_H_
|
||||
@@ -13,48 +13,150 @@
|
||||
#include "network/Node.h"
|
||||
#include <atomic>
|
||||
|
||||
#include "Register.h"
|
||||
#include "CommonParty.h"
|
||||
|
||||
class TrustedParty : public NodeUpdatable {
|
||||
class BaseTrustedParty : virtual public CommonParty {
|
||||
public:
|
||||
vector<ReceivedMsg> prf_outputs;
|
||||
|
||||
BaseTrustedParty();
|
||||
virtual ~BaseTrustedParty() {}
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
virtual void NodeReady();
|
||||
void NewMessage(int from, ReceivedMsg& msg);
|
||||
void NodeAborted(struct sockaddr_in* from) { (void)from; }
|
||||
|
||||
void Start();
|
||||
|
||||
protected:
|
||||
boost::mutex _print_mx;
|
||||
|
||||
std::atomic_uint _num_prf_received;
|
||||
std::atomic_uint _received_gc_received;
|
||||
std::atomic_uint n_received;
|
||||
|
||||
vector<SendBuffer> msg_keys, msg_input_masks;
|
||||
|
||||
int randomfd;
|
||||
|
||||
bool done_filling;
|
||||
|
||||
#ifdef __PURE_SHE__
|
||||
mpz_t _temp_mpz;
|
||||
void _fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid);
|
||||
#endif
|
||||
virtual bool _fill_keys() = 0;
|
||||
|
||||
void _compute_send_garbled_circuit();
|
||||
virtual void _launch_online() = 0;
|
||||
|
||||
void prepare_randomness();
|
||||
void send_randomness();
|
||||
virtual void send_input_masks(party_id_t pid) = 0;
|
||||
virtual void send_output_masks() = 0;
|
||||
virtual void garble() = 0;
|
||||
|
||||
void add_keys(const Register& reg);
|
||||
};
|
||||
|
||||
class TrustedParty : public BaseTrustedParty, public CommonCircuitParty {
|
||||
|
||||
public:
|
||||
TrustedParty(const char* netmap_file, const char* circuit_file );
|
||||
virtual ~TrustedParty();
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
void NodeReady();
|
||||
void NewMessage(int from, char* message, unsigned int len);
|
||||
void NodeAborted(struct sockaddr_in* from) {}
|
||||
|
||||
void Start();
|
||||
|
||||
|
||||
/* TEST methods */
|
||||
|
||||
private:
|
||||
Node* _node;
|
||||
BooleanCircuit* _circuit;
|
||||
gate_id_t _G;
|
||||
party_id_t _N;
|
||||
wire_id_t _W;
|
||||
wire_id_t _OW;
|
||||
|
||||
boost::mutex _print_mx;
|
||||
|
||||
std::atomic_int _num_prf_received;
|
||||
std::atomic_int _received_gc_received;
|
||||
#ifdef __PURE_SHE__
|
||||
mpz_t _temp_mpz;
|
||||
void _fill_keys_for_party(Key* sqr_keys, Key* keys, party_id_t pid);
|
||||
#endif
|
||||
void _fill_keys_for_party(Key* keys, party_id_t pid);
|
||||
void _merge_keys(Key* dest, Key* src);
|
||||
void _generate_masks();
|
||||
void _allocate_prf_outputs();
|
||||
void _allocate_garbled_table();
|
||||
void _compute_send_garbled_circuit();
|
||||
bool _fill_keys();
|
||||
|
||||
void _launch_online();
|
||||
void _print_keys();
|
||||
|
||||
void send_input_masks(party_id_t pid);
|
||||
void send_output_masks();
|
||||
void garble();
|
||||
};
|
||||
|
||||
class TrustedProgramParty : public BaseTrustedParty {
|
||||
public:
|
||||
SendBuffer msg_output_masks;
|
||||
|
||||
TrustedProgramParty(int argc, char** argv);
|
||||
~TrustedProgramParty();
|
||||
|
||||
void NodeReady();
|
||||
|
||||
void store_spdz_wire(SpdzOp op, const Register& reg);
|
||||
|
||||
void store_wire(const Register& reg);
|
||||
void load_wire(Register& reg);
|
||||
|
||||
#ifdef FREE_XOR
|
||||
const Key& delta(int i) { return deltas[i]; }
|
||||
const KeyVector& get_deltas() { return deltas; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
friend class GarbleRegister;
|
||||
friend class RandomRegister;
|
||||
|
||||
static TrustedProgramParty* singleton;
|
||||
static TrustedProgramParty& s();
|
||||
|
||||
GC::Memory< GC::Secret<GarbleRegister>::DynamicType > dynamic_memory;
|
||||
GC::Machine< GC::Secret<GarbleRegister> > machine;
|
||||
GC::Processor< GC::Secret<GarbleRegister> > processor;
|
||||
GC::Program< GC::Secret<GarbleRegister> > program;
|
||||
|
||||
GC::Machine< GC::Secret<RandomRegister> > random_machine;
|
||||
GC::Processor< GC::Secret<RandomRegister> > random_processor;
|
||||
|
||||
#ifdef FREE_XOR
|
||||
KeyVector deltas;
|
||||
#endif
|
||||
|
||||
vector<octetStream> spdz_wires[SPDZ_OP_N];
|
||||
vector< Share<gf2n> > mask_shares;
|
||||
|
||||
Timer random_timer;
|
||||
|
||||
bool _fill_keys();
|
||||
void _launch_online();
|
||||
|
||||
void send_input_masks(party_id_t pid) { (void)pid; }
|
||||
void send_output_masks();
|
||||
void garble();
|
||||
|
||||
void add_all_keys(const Register& reg, bool external);
|
||||
};
|
||||
|
||||
|
||||
inline void BaseTrustedParty::add_keys(const Register& reg)
|
||||
{
|
||||
for(int p = 0; p < get_n_parties(); p++)
|
||||
reg.keys.serialize(msg_keys[p], p + 1);
|
||||
}
|
||||
|
||||
inline void TrustedProgramParty::add_all_keys(const Register& reg, bool external)
|
||||
{
|
||||
for(int p = 0; p < get_n_parties(); p++)
|
||||
for (int i = 0; i < get_n_parties(); i++)
|
||||
reg.keys[external][i].serialize(msg_keys[p]);
|
||||
}
|
||||
|
||||
inline TrustedProgramParty& TrustedProgramParty::s()
|
||||
{
|
||||
if (singleton)
|
||||
return *singleton;
|
||||
else
|
||||
throw runtime_error("no singleton");
|
||||
}
|
||||
|
||||
#endif /* PROTOCOL_TRUSTEDPARTY_H_ */
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#ifndef __WIRE_H__
|
||||
#define __WIRE_H__
|
||||
|
||||
#include <vector>
|
||||
#include <boost/thread/mutex.hpp>
|
||||
#include <boost/thread.hpp>
|
||||
|
||||
#include "Key.h"
|
||||
#include "common.h"
|
||||
|
||||
class Gate;
|
||||
|
||||
@@ -16,10 +17,10 @@ typedef char signal_t;
|
||||
typedef struct Wire {
|
||||
// signal_t _sig; // the actual value that is passing through the wire (after inputs have been set)
|
||||
std::vector<gate_id_t> _enters_to; //TODO make it a regular c array
|
||||
gate_id_t _out_from;
|
||||
bool _is_output;
|
||||
gate_id_t _out_from;
|
||||
|
||||
Wire(bool out):/*_sig(NO_SIGNAL),*/ _is_output(out),_out_from(NULL){}
|
||||
Wire(bool out):/*_sig(NO_SIGNAL),*/ _is_output(out),_out_from(0){}
|
||||
// inline signal_t Sig() { return _sig; }
|
||||
// inline void Sig(signal_t s) { _sig = s; }
|
||||
// void print(int w_id) {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#include "aes.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
42
BMR/common.h
42
BMR/common.h
@@ -1,15 +1,53 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* common.h
|
||||
*
|
||||
* Created on: Jan 21, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef CIRCUIT_INC_COMMON_H_
|
||||
#define CIRCUIT_INC_COMMON_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
typedef unsigned long wire_id_t;
|
||||
typedef unsigned long gate_id_t;
|
||||
|
||||
class Function {
|
||||
bool rep[4];
|
||||
int shift(int i) { return 4 * (3 - i); }
|
||||
public:
|
||||
Function() { memset(rep, 0, sizeof(rep)); }
|
||||
Function(std::string& func)
|
||||
{
|
||||
for (int i = 0; i < 4; i++)
|
||||
if (func[i] != '0')
|
||||
rep[i] = 1;
|
||||
else
|
||||
rep[i] = 0;
|
||||
}
|
||||
Function(int int_rep)
|
||||
{
|
||||
for (int i = 0; i < 4; i++)
|
||||
rep[i] = (int_rep << shift(i)) & 1;
|
||||
}
|
||||
uint8_t operator[](int i) { return rep[i]; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class CheckVector : public vector<T>
|
||||
{
|
||||
public:
|
||||
CheckVector() : vector<T>() {}
|
||||
CheckVector(size_t size) : vector<T>(size) {}
|
||||
CheckVector(size_t size, const T& def) : vector<T>(size, def) {}
|
||||
#ifdef CHECK_SIZE
|
||||
T& operator[](size_t i) { return this->at(i); }
|
||||
const T& operator[](size_t i) const { return this->at(i); }
|
||||
#else
|
||||
T& at(size_t i) { return (*this)[i]; }
|
||||
const T& at(size_t i) const { return (*this)[i]; }
|
||||
#endif
|
||||
};
|
||||
|
||||
#endif /* CIRCUIT_INC_COMMON_H_ */
|
||||
|
||||
16
BMR/msg_types.cpp
Normal file
16
BMR/msg_types.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* msg_types.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "msg_types.h"
|
||||
|
||||
#define X(NAME) #NAME,
|
||||
|
||||
const char* message_type_names[] = {
|
||||
MESSAGE_TYPES
|
||||
};
|
||||
|
||||
#undef X
|
||||
@@ -1,30 +1,42 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* msg_types.h
|
||||
*
|
||||
* Created on: Feb 16, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOL_INC_MSG_TYPES_H_
|
||||
#define PROTOCOL_INC_MSG_TYPES_H_
|
||||
|
||||
|
||||
typedef enum {
|
||||
TYPE_KEYS = 0,
|
||||
TYPE_MASK_INPUTS,
|
||||
TYPE_MASK_OUTPUT,
|
||||
TYPE_PRF_OUTPUTS,
|
||||
TYPE_GARBLED_CIRCUIT,
|
||||
TYPE_EXTERNAL_VALUES,
|
||||
TYPE_ALL_EXTERNAL_VALUES,
|
||||
TYPE_KEY_PER_IN_WIRE,
|
||||
TYPE_ALL_KEYS_PER_IN_WIRE,
|
||||
TYPE_LAUNCH_ONLINE,
|
||||
TYPE_RECEIVED_GC,
|
||||
#define MESSAGE_TYPES \
|
||||
X(TYPE_KEYS) \
|
||||
X(TYPE_MASK_INPUTS) \
|
||||
X(TYPE_MASK_OUTPUT) \
|
||||
X(TYPE_PRF_OUTPUTS) \
|
||||
X(TYPE_GARBLED_CIRCUIT) \
|
||||
X(TYPE_EXTERNAL_VALUES) \
|
||||
X(TYPE_ALL_EXTERNAL_VALUES) \
|
||||
X(TYPE_KEY_PER_IN_WIRE) \
|
||||
X(TYPE_ALL_KEYS_PER_IN_WIRE) \
|
||||
X(TYPE_LAUNCH_ONLINE) \
|
||||
X(TYPE_RECEIVED_GC) \
|
||||
X(TYPE_SPDZ_WIRES) \
|
||||
X(TYPE_DELTA) \
|
||||
X(TYPE_CHECKSUM) \
|
||||
X(TYPE_NEXT) \
|
||||
X(TYPE_DONE) \
|
||||
X(TYPE_MAX) \
|
||||
|
||||
TYPE_CHECKSUM,
|
||||
TYPE_MAX
|
||||
|
||||
#define X(NAME) NAME,
|
||||
|
||||
typedef enum {
|
||||
MESSAGE_TYPES
|
||||
} MSG_TYPE;
|
||||
|
||||
#undef X
|
||||
|
||||
extern const char* message_type_names[];
|
||||
|
||||
#endif /* PROTOCOL_INC_MSG_TYPES_H_ */
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Client.cpp
|
||||
*
|
||||
* Created on: Jan 27, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#include "Client.h"
|
||||
@@ -28,20 +28,15 @@ static void throw_bad_ip(const char* ip) {
|
||||
}
|
||||
|
||||
Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable, unsigned int max_message_size)
|
||||
:_numservers(numservers),
|
||||
_max_msg_sz(max_message_size),
|
||||
_updatable(updatable),
|
||||
_new_message(false)
|
||||
:_max_msg_sz(max_message_size),
|
||||
_numservers(numservers),
|
||||
_updatable(updatable)
|
||||
{
|
||||
_sockets = new int[_numservers](); // 0 initialized
|
||||
_servers = new struct sockaddr_in[_numservers];
|
||||
_msg_queues = new Queue<Msg>[_numservers]();
|
||||
_msg_queues = new WaitQueue< shared_ptr<SendBuffer> >[_numservers];
|
||||
|
||||
_lockqueue = new std::mutex[_numservers];
|
||||
_queuecheck = new std::condition_variable[_numservers];
|
||||
_new_message = new bool[_numservers]();
|
||||
|
||||
memset(_servers, 0, sizeof(_servers));
|
||||
memset(_servers, 0, sizeof(*_servers));
|
||||
|
||||
for (int i=0; i<_numservers; i++) {
|
||||
_sockets[i] = socket(AF_INET, SOCK_STREAM, 0);
|
||||
@@ -56,20 +51,32 @@ Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable
|
||||
}
|
||||
|
||||
Client::~Client() {
|
||||
Stop();
|
||||
for (int i=0; i<_numservers; i++)
|
||||
close(_sockets[i]);
|
||||
delete[] _sockets;
|
||||
delete[] _servers;
|
||||
delete[] _msg_queues;
|
||||
delete[] _lockqueue;
|
||||
delete[] _queuecheck;
|
||||
delete[] _new_message;
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Client:: Client deleted\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
void Client::Connect() {
|
||||
for (int i=0; i<_numservers; i++)
|
||||
new boost::thread(&Client::_send_thread, this, i);
|
||||
new boost::thread(&Client::_connect, this);
|
||||
threads.add_thread(new boost::thread(&Client::_send_thread, this, i));
|
||||
threads.add_thread(new boost::thread(&Client::_connect, this));
|
||||
}
|
||||
|
||||
void Client::Stop() {
|
||||
for (int i=0; i<_numservers; i++)
|
||||
_msg_queues[i].stop();
|
||||
threads.join_all();
|
||||
for (int i=0; i<_numservers; i++)
|
||||
shutdown(_sockets[i], SHUT_RDWR);
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Stopped client\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
void Client::_connect() {
|
||||
@@ -89,95 +96,100 @@ void Client::_connect_to_server(int i) {
|
||||
int port = ntohs(_servers[i].sin_port);
|
||||
ip = inet_ntoa(_servers[i].sin_addr);
|
||||
int error = 0;
|
||||
int interval = 10000;
|
||||
int total_wait = 0;
|
||||
while (true ) {
|
||||
error = connect(_sockets[i], (struct sockaddr *)&_servers[i], sizeof(struct sockaddr));
|
||||
if (interval < CONNECT_INTERVAL)
|
||||
interval *= 2;
|
||||
if(!error)
|
||||
break;
|
||||
if (errno == 111) {
|
||||
fprintf(stderr,".");
|
||||
} else {
|
||||
fprintf(stderr,"Client:: Error (%d): connect to %s:%d: \"%s\"\n",errno, ip,port,strerror(errno));
|
||||
fprintf(stderr,"Client:: socket %d sleeping for %u usecs\n",i, CONNECT_INTERVAL);
|
||||
fprintf(stderr,"Client:: socket %d sleeping for %u usecs\n",i, interval);
|
||||
}
|
||||
usleep(CONNECT_INTERVAL);
|
||||
usleep(interval);
|
||||
total_wait += interval;
|
||||
if (total_wait > 60e6)
|
||||
throw runtime_error("waiting for too long");
|
||||
}
|
||||
|
||||
printf("\nClient:: connected to %s:%d\n", ip,port);
|
||||
setsockopt(_sockets[i], SOL_SOCKET, SO_SNDBUF, &BUFFER_SIZE, sizeof(BUFFER_SIZE));
|
||||
// Using the following disables the automatic buffer size (ipv4.tcp_wmem)
|
||||
// in favour of the core.wmem_max, which is worse.
|
||||
//setsockopt(_sockets[i], SOL_SOCKET, SO_SNDBUF, &NETWORK_BUFFER_SIZE, sizeof(NETWORK_BUFFER_SIZE));
|
||||
}
|
||||
|
||||
void Client::Send(int id, const char* message, unsigned int len) {
|
||||
Msg new_msg = {message, len};
|
||||
void Client::Send(int id, SendBuffer& buffer) {
|
||||
{
|
||||
std::unique_lock<std::mutex> locker(_lockqueue[id]);
|
||||
// printf ("Client:: queued %u bytes to %d\n", len, id);
|
||||
_msg_queues[id].Enqueue(new_msg);
|
||||
_new_message[id] = true;
|
||||
_queuecheck[id].notify_one();
|
||||
#ifdef DEBUG_COMM
|
||||
printf ("Client:: queued %u bytes to %d\n", buffer.size(), id);
|
||||
phex(buffer.data(), 4);
|
||||
#endif
|
||||
SendBuffer* tmp = new SendBuffer;
|
||||
*tmp = buffer;
|
||||
shared_ptr<SendBuffer> new_msg(tmp);
|
||||
_msg_queues[id].push(new_msg);
|
||||
}
|
||||
}
|
||||
|
||||
void Client::Broadcast(const char* message, unsigned int len) {
|
||||
void Client::Broadcast(SendBuffer& buffer) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf ("Client:: queued %u bytes to broadcast\n", buffer.size());
|
||||
phex(buffer.data(), 4);
|
||||
#endif
|
||||
SendBuffer* tmp = new SendBuffer;
|
||||
*tmp = buffer;
|
||||
shared_ptr<SendBuffer> new_msg(tmp);
|
||||
for(int i=0;i<_numservers; i++) {
|
||||
std::unique_lock<std::mutex> locker(_lockqueue[i]);
|
||||
Msg new_msg = {message, len};
|
||||
_msg_queues[i].Enqueue(new_msg);
|
||||
_new_message[i] = true;
|
||||
_queuecheck[i].notify_one();
|
||||
_msg_queues[i].push(new_msg);
|
||||
}
|
||||
}
|
||||
|
||||
void Client::Broadcast2(const char* message, unsigned int len) {
|
||||
void Client::Broadcast2(SendBuffer& buffer) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf ("Client:: queued %u bytes to broadcast to all but the server\n", buffer.size());
|
||||
phex(buffer.data(), 4);
|
||||
#endif
|
||||
SendBuffer* tmp = new SendBuffer;
|
||||
*tmp = buffer;
|
||||
shared_ptr<SendBuffer> new_msg(tmp);
|
||||
// first server is always the trusted party so we start with i=1
|
||||
for(int i=1;i<_numservers; i++) {
|
||||
std::unique_lock<std::mutex> locker(_lockqueue[i]);
|
||||
Msg new_msg = {message, len};
|
||||
_msg_queues[i].Enqueue(new_msg);
|
||||
_new_message[i] = true;
|
||||
_queuecheck[i].notify_one();
|
||||
_msg_queues[i].push(new_msg);
|
||||
}
|
||||
}
|
||||
|
||||
void Client::_send_thread(int i) {
|
||||
while(true)
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> locker(_lockqueue[i]);
|
||||
//printf("Client:: waiting for a notification to send to %d\n", i);
|
||||
_queuecheck[i].wait(locker);
|
||||
if (!_new_message[i]) {
|
||||
// printf("Client:: Spurious notification!\n");
|
||||
continue;
|
||||
}
|
||||
//printf("Client:: notified!!\n");
|
||||
}
|
||||
while (true)
|
||||
{
|
||||
Msg msg = {0};
|
||||
{
|
||||
std::unique_lock<std::mutex> locker(_lockqueue[i]);
|
||||
if(_msg_queues[i].Empty()) {
|
||||
//printf("Client:: no more messages in queue\n");
|
||||
break; // out of the inner while
|
||||
}
|
||||
msg = _msg_queues[i].Dequeue();
|
||||
}
|
||||
_send_blocking(msg, i);
|
||||
}
|
||||
_new_message[i] = false;
|
||||
}
|
||||
shared_ptr<SendBuffer> msg;
|
||||
while(_msg_queues[i].pop_dont_stop(msg))
|
||||
_send_blocking(*msg, i);
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Shutting down sender thread %d\n", i);
|
||||
#endif
|
||||
}
|
||||
|
||||
void Client::_send_blocking(Msg msg, int id) {
|
||||
// printf ("Client:: sending %u bytes to %d\n", msg.len, id);
|
||||
void Client::_send_blocking(SendBuffer& msg, int id) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf ("Client:: sending %llu bytes at 0x%x to %d\n", msg.size(), msg.data(), id);
|
||||
fflush(0);
|
||||
#ifdef DEBUG2
|
||||
phex(msg.data(), msg.size());
|
||||
#else
|
||||
phex(msg.data(), 4);
|
||||
#endif
|
||||
#endif
|
||||
int cur_sent = 0;
|
||||
cur_sent = send(_sockets[id], &msg.len, LENGTH_FIELD, 0);
|
||||
if(LENGTH_FIELD == cur_sent) {
|
||||
size_t len = msg.size();
|
||||
cur_sent = send(_sockets[id], &len, sizeof(len), 0);
|
||||
if(sizeof(len) == cur_sent) {
|
||||
unsigned int total_sent = 0;
|
||||
unsigned int remaining = 0;
|
||||
while(total_sent != msg.len) {
|
||||
remaining = (msg.len-total_sent)>_max_msg_sz ? _max_msg_sz : (msg.len-total_sent);
|
||||
cur_sent = send(_sockets[id], msg.msg+total_sent, remaining, 0);
|
||||
while(total_sent != msg.size()) {
|
||||
remaining = (msg.size()-total_sent)>_max_msg_sz ? _max_msg_sz : (msg.size()-total_sent);
|
||||
cur_sent = send(_sockets[id], msg.data()+total_sent, remaining, 0);
|
||||
//printf("Client:: msg.len=%u, remaining=%u, total_sent=%u, cur_sent = %d\n",msg.len, remaining, total_sent,cur_sent);
|
||||
if(cur_sent == -1) {
|
||||
fprintf(stderr,"Client:: Error: send msg failed: %s\n",strerror(errno));
|
||||
@@ -188,4 +200,10 @@ void Client::_send_blocking(Msg msg, int id) {
|
||||
} else if (-1 == cur_sent){
|
||||
fprintf(stderr,"Client:: Error: send header failed: %s\n",strerror(errno));
|
||||
}
|
||||
#ifdef DEBUG_COMM
|
||||
printf ("Client:: sent %u bytes at 0x%x to %d\n", msg.size(), msg.data(), id);
|
||||
fflush(0);
|
||||
phex(msg.data(), 4);
|
||||
fflush(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Client.h
|
||||
*
|
||||
* Created on: Jan 27, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_INC_CLIENT_H_
|
||||
@@ -17,6 +17,10 @@
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <boost/thread.hpp>
|
||||
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
#define CONNECT_INTERVAL (1000000)
|
||||
|
||||
@@ -34,22 +38,20 @@ public:
|
||||
virtual ~Client();
|
||||
|
||||
void Connect();
|
||||
void Send(int id, const char* message, unsigned int len);
|
||||
void Broadcast(const char* message, unsigned int len);
|
||||
void Broadcast2(const char* message, unsigned int len);
|
||||
void Send(int id, SendBuffer& new_msg);
|
||||
void Broadcast(SendBuffer& new_msg);
|
||||
void Broadcast2(SendBuffer& new_msg);
|
||||
void Stop();
|
||||
|
||||
private:
|
||||
|
||||
Queue<Msg>* _msg_queues;
|
||||
WaitQueue< shared_ptr<SendBuffer> >* _msg_queues;
|
||||
// std::queue<Msg>* _msg_queues;
|
||||
// boost::mutex* _msg_mux;
|
||||
// boost::mutex* _thd_mux;
|
||||
unsigned int _max_msg_sz;
|
||||
std::mutex* _lockqueue;
|
||||
std::condition_variable* _queuecheck;
|
||||
bool* _new_message;
|
||||
void _send_thread(int i);
|
||||
void _send_blocking(Msg msg, int id);
|
||||
void _send_blocking(SendBuffer& msg, int id);
|
||||
|
||||
void _connect();
|
||||
void _connect_to_server(int i);
|
||||
@@ -58,6 +60,8 @@ private:
|
||||
struct sockaddr_in* _servers;
|
||||
int* _sockets;
|
||||
ClientUpdatable* _updatable;
|
||||
|
||||
boost::thread_group threads;
|
||||
};
|
||||
|
||||
#endif /* NETWORK_INC_CLIENT_H_ */
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Node.cpp
|
||||
*
|
||||
* Created on: Jan 27, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
@@ -26,12 +26,12 @@ static void throw_bad_id(int id) {
|
||||
|
||||
Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num_parties)
|
||||
:_id(my_id),
|
||||
_updatable(updatable),
|
||||
_connected_to_servers(false),
|
||||
_num_parties_identified(0)
|
||||
_num_parties_identified(0),
|
||||
_updatable(updatable)
|
||||
{
|
||||
_parse_map(netmap_file, num_parties);
|
||||
unsigned int max_message_size = BUFFER_SIZE/2;
|
||||
unsigned int max_message_size = NETWORK_BUFFER_SIZE/2;
|
||||
if(_id < 0 || _id > _numparties)
|
||||
throw_bad_id(_id);
|
||||
_ready_nodes = new bool[_numparties](); //initialized to false
|
||||
@@ -41,21 +41,30 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num
|
||||
}
|
||||
|
||||
Node::~Node() {
|
||||
print_waiting();
|
||||
delete(_client);
|
||||
delete(_server);
|
||||
delete (_endpoints);
|
||||
delete (_ready_nodes);
|
||||
delete (_clients_connected);
|
||||
delete[] (_endpoints);
|
||||
delete[] (_ready_nodes);
|
||||
delete[] (_clients_connected);
|
||||
}
|
||||
|
||||
|
||||
void Node::Start() {
|
||||
_client->Connect();
|
||||
new boost::thread(&Node::_start, this);
|
||||
boost::thread(&Node::_start, this).join();
|
||||
_server->starter->join();
|
||||
for (unsigned int i = 0; i < _server->listeners.size(); i++)
|
||||
_server->listeners[i]->join();
|
||||
}
|
||||
|
||||
void Node::Stop() {
|
||||
_client->Stop();
|
||||
}
|
||||
|
||||
void Node::_start() {
|
||||
usleep(START_INTERVAL);
|
||||
int interval = 10000;
|
||||
int total_wait = 0;
|
||||
while(true) {
|
||||
bool all_ready = true;
|
||||
if (_connected_to_servers) {
|
||||
@@ -72,15 +81,27 @@ void Node::_start() {
|
||||
break;
|
||||
fprintf(stderr,"+");
|
||||
// fprintf(stderr,"Node:: waiting for all nodes to get ready ; sleeping for %u usecs\n", START_INTERVAL);
|
||||
usleep(START_INTERVAL);
|
||||
if (interval < CONNECT_INTERVAL)
|
||||
interval *= 2;
|
||||
usleep(interval);
|
||||
total_wait += interval;
|
||||
if (total_wait > 60e6)
|
||||
throw runtime_error("waiting for too long");
|
||||
}
|
||||
printf("All identified\n");
|
||||
_client->Broadcast(ALL_IDENTIFIED,strlen(ALL_IDENTIFIED));
|
||||
SendBuffer buffer;
|
||||
buffer.serialize(ALL_IDENTIFIED, strlen(ALL_IDENTIFIED));
|
||||
_client->Broadcast(buffer);
|
||||
}
|
||||
|
||||
void Node::NewMsg(char* msg, unsigned int len, struct sockaddr_in* from) {
|
||||
//printf("Node:: got message of length %d ",len);
|
||||
// printf("from %s:%d\n", inet_ntoa(from->sin_addr), ntohs(from->sin_port));
|
||||
void Node::NewMsg(ReceivedMsg& message, struct sockaddr_in* from) {
|
||||
char* msg = message.data();
|
||||
size_t len = message.size();
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Node:: got message of length %u at 0x%x\n",len,msg);
|
||||
printf("from %s:%d\n", inet_ntoa(from->sin_addr), ntohs(from->sin_port));
|
||||
phex(msg, 4);
|
||||
#endif
|
||||
if(len == strlen(ID_HDR)+sizeof(_id) && 0==strncmp(msg, ID_HDR, strlen(ID_HDR))) {
|
||||
int *id = (int*)(msg+strlen(ID_HDR));
|
||||
printf("Node:: identified as party: %d\n", *id);
|
||||
@@ -94,12 +115,16 @@ void Node::NewMsg(char* msg, unsigned int len, struct sockaddr_in* from) {
|
||||
printf("Node:: received ALL_IDENTIFIED from %d\n",_clientsmap[from]);
|
||||
_num_parties_identified++;
|
||||
if(_num_parties_identified == _numparties-1) {
|
||||
printf("Node:: received ALL_IDENTIFIED from ALL\n",_clientsmap[from]);
|
||||
printf("Node:: received ALL_IDENTIFIED from ALL\n");
|
||||
_updatable->NodeReady();
|
||||
}
|
||||
return;
|
||||
}
|
||||
_updatable->NewMessage(_clientsmap[from], msg, len );
|
||||
_updatable->NewMessage(_clientsmap[from], message);
|
||||
#ifdef DEBUG_COMM
|
||||
printf("finished with %d bytes at 0x%x\n", len, msg);
|
||||
phex(msg, 4);
|
||||
#endif
|
||||
}
|
||||
|
||||
void Node::ClientsConnected() {
|
||||
@@ -109,6 +134,7 @@ void Node::ClientsConnected() {
|
||||
void Node::NodeAborted(struct sockaddr_in* from)
|
||||
{
|
||||
printf("Node:: party %d has aborted\n",_clientsmap[from]);
|
||||
Stop();
|
||||
}
|
||||
|
||||
void Node::ConnectedToServers() {
|
||||
@@ -117,25 +143,43 @@ void Node::ConnectedToServers() {
|
||||
_identify();
|
||||
}
|
||||
|
||||
void Node::Send(int to, const char* msg, unsigned int len) {
|
||||
void Node::Send(int to, SendBuffer& msg) {
|
||||
int new_recipient = to>_id?to-1:to;
|
||||
//printf("Node:: new_recipient=%d\n",new_recipient);
|
||||
_client->Send(new_recipient, msg, len);
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Node:: new_recipient=%d\n",new_recipient);
|
||||
printf("Send %d bytes at 0x%x\n", msg.size(), msg.data());
|
||||
phex(msg.data(), 4);
|
||||
#endif
|
||||
_client->Send(new_recipient, msg);
|
||||
}
|
||||
|
||||
void Node::Broadcast(const char* msg, unsigned int len) {
|
||||
_client->Broadcast(msg, len);
|
||||
void Node::Broadcast(SendBuffer& msg) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Broadcast %d bytes at 0x%x\n", msg.size(), msg.data());
|
||||
phex(msg.data(), 4);
|
||||
#endif
|
||||
_client->Broadcast(msg);
|
||||
}
|
||||
void Node::Broadcast2(const char* msg, unsigned int len) {
|
||||
_client->Broadcast2(msg, len);
|
||||
void Node::Broadcast2(SendBuffer& msg) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Broadcast2 %d bytes at 0x%x\n", msg.size(), msg.data());
|
||||
phex(msg.data(), 4);
|
||||
#endif
|
||||
_client->Broadcast2(msg);
|
||||
}
|
||||
|
||||
void Node::_identify() {
|
||||
char* msg = new char[strlen(ID_HDR)+sizeof(_id)];
|
||||
char* msg = id_msg;
|
||||
strncpy(msg, ID_HDR, strlen(ID_HDR));
|
||||
strncpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
|
||||
//printf("Node:: identifying myself:\n");
|
||||
_client->Broadcast(msg,strlen(ID_HDR)+4);
|
||||
SendBuffer buffer;
|
||||
buffer.serialize(msg, strlen(ID_HDR)+4);
|
||||
#ifdef DEBUG_COMM
|
||||
cout << "message for identification:";
|
||||
phex(buffer.data(), 4);
|
||||
#endif
|
||||
_client->Broadcast(buffer);
|
||||
}
|
||||
|
||||
void Node::_parse_map(const char* netmap_file, int num_parties) {
|
||||
@@ -165,11 +209,27 @@ void Node::_parse_map(const char* netmap_file, int num_parties) {
|
||||
for(int i=0; i<_numparties; i++) {
|
||||
if(_id == i) {
|
||||
netmap >> _ip >> _port;
|
||||
//printf("Node:: my address: %s:%d\n", _ip.c_str(),_port);
|
||||
#ifdef DEBUG_NETMAP
|
||||
printf("Node:: my address: %s:%d\n", _ip.c_str(),_port);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
netmap >> _endpoints[j].ip >> _endpoints[j].port;
|
||||
#ifdef DEBUG_NETMAP
|
||||
printf("Node:: other address (%d): %s:%d\n", j,
|
||||
_endpoints[j].ip.c_str(), _endpoints[j].port);
|
||||
#endif
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Node::print_waiting()
|
||||
{
|
||||
for (unsigned i = 0; i < _server->timers.size(); i++)
|
||||
{
|
||||
cout << "Waited " << _server->timers[i].elapsed()
|
||||
<< " seconds for client "
|
||||
<< _clientsmap[_server->get_client_addr(i)] << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Node.h
|
||||
*
|
||||
* Created on: Jan 27, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_NODE_H_
|
||||
@@ -11,11 +11,14 @@
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "Client.h"
|
||||
#include "Server.h"
|
||||
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
class Server;
|
||||
class Client;
|
||||
class ServerUpdatable;
|
||||
@@ -24,7 +27,7 @@ class ClientUpdatable;
|
||||
class NodeUpdatable {
|
||||
public:
|
||||
virtual void NodeReady()=0;
|
||||
virtual void NewMessage(int from, char* message, unsigned int len) =0;
|
||||
virtual void NewMessage(int from, ReceivedMsg& msg) =0;
|
||||
virtual void NodeAborted(struct sockaddr_in* from) =0;
|
||||
};
|
||||
|
||||
@@ -34,26 +37,29 @@ typedef void (*msg_id_cb_t)(int from, char* msg, unsigned int len);
|
||||
const char ALL_IDENTIFIED[] = "ALID";
|
||||
#define LOOPBACK NULL
|
||||
#define LOCALHOST_IP "127.0.0.1"
|
||||
#define PORT_BASE (4000)
|
||||
#define PORT_BASE (14000)
|
||||
|
||||
class Node : public ServerUpdatable, public ClientUpdatable {
|
||||
public:
|
||||
Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num_parties=0);
|
||||
virtual ~Node();
|
||||
void Send(int to, const char* msg, unsigned int len);
|
||||
void Broadcast(const char* msg, unsigned int len);
|
||||
void Broadcast2(const char* msg, unsigned int len);
|
||||
void Send(int to, SendBuffer& msg);
|
||||
void Broadcast(SendBuffer& msg);
|
||||
void Broadcast2(SendBuffer& msg);
|
||||
inline int NumParties(){return _numparties;}
|
||||
void Start();
|
||||
void Stop();
|
||||
// void Close();
|
||||
|
||||
//derived from ServerUpdateable
|
||||
void NewMsg(char* msg, unsigned int len, struct sockaddr_in* from);
|
||||
void NewMsg(ReceivedMsg& msg, struct sockaddr_in* from);
|
||||
void ClientsConnected();
|
||||
void NodeAborted(struct sockaddr_in* from);
|
||||
//derived from ClientUpdatable
|
||||
void ConnectedToServers();
|
||||
|
||||
void print_waiting();
|
||||
|
||||
private:
|
||||
void _parse_map(const char* netmap_file, int num_parties);
|
||||
void _identify();
|
||||
@@ -74,6 +80,8 @@ private:
|
||||
std::map<struct sockaddr_in*,int> _clientsmap;
|
||||
bool* _clients_connected;
|
||||
NodeUpdatable* _updatable;
|
||||
|
||||
char id_msg[strlen(ID_HDR)+sizeof(_id)];
|
||||
};
|
||||
|
||||
#endif /* NETWORK_NODE_H_ */
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include <stdio.h>
|
||||
#include <errno.h>
|
||||
@@ -9,19 +11,21 @@
|
||||
#include <boost/thread.hpp>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "Server.h"
|
||||
|
||||
|
||||
/* Opens server socket for listening - not yet accepting */
|
||||
Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size)
|
||||
:_port(port),
|
||||
:starter(0),
|
||||
_expected_clients(expected_clients),
|
||||
_port(port),
|
||||
_updatable(updatable),
|
||||
_max_msg_sz(max_message_size)
|
||||
{
|
||||
_clients = new int[expected_clients]();
|
||||
_clients_addr = new struct sockaddr_in[expected_clients]();
|
||||
timers.resize(expected_clients);
|
||||
|
||||
_servfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (-1 == _servfd)
|
||||
@@ -41,16 +45,21 @@ Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsig
|
||||
printf("Server:: Error listen: \n%s\n",strerror(errno));
|
||||
|
||||
|
||||
new boost::thread(&Server::_start_server, this);
|
||||
starter = new boost::thread(&Server::_start_server, this);
|
||||
}
|
||||
|
||||
Server::~Server() {
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Server:: Server being deleted\n");
|
||||
#endif
|
||||
close(_servfd);
|
||||
for (int i=0; i<_expected_clients; i++)
|
||||
close(_clients[i]);
|
||||
delete (_clients);
|
||||
delete (_clients_addr);
|
||||
delete[] (_clients);
|
||||
delete[] (_clients_addr);
|
||||
delete starter;
|
||||
for (unsigned i = 0; i < listeners.size(); i++)
|
||||
delete listeners[i];
|
||||
}
|
||||
|
||||
void Server::_start_server() {
|
||||
@@ -64,8 +73,10 @@ void Server::_start_server() {
|
||||
printf("Server:: accept: error in connecting socket\n%s\n",strerror(errno));
|
||||
} else {
|
||||
printf("Server:: Incoming connection from %s:%d\n",inet_ntoa(_clients_addr[i].sin_addr), ntohs(_clients_addr[i].sin_port));
|
||||
setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &BUFFER_SIZE, sizeof(BUFFER_SIZE));
|
||||
boost::thread* listener = new boost::thread(&Server::_listen_to_client, this, i);
|
||||
// Using the following disables the automatic buffer size (ipv4.tcp_rmem)
|
||||
// in favour of the core.rmem_max, which is worse.
|
||||
//setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &NETWORK_BUFFER_SIZE, sizeof(NETWORK_BUFFER_SIZE));
|
||||
listeners.push_back(new boost::thread(&Server::_listen_to_client, this, i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,23 +84,32 @@ void Server::_start_server() {
|
||||
}
|
||||
|
||||
void Server::_listen_to_client(int id){
|
||||
int msg_len = 0;
|
||||
size_t msg_len = 0;
|
||||
int n_recv = 0;
|
||||
unsigned int total_received;
|
||||
unsigned int remaining;
|
||||
char *msg;
|
||||
size_t total_received;
|
||||
size_t remaining;
|
||||
ReceivedMsg msg;
|
||||
#ifdef DEBUG_FLEXBUF
|
||||
cout << "msg from " << id << " stored at " << &msg << endl;
|
||||
#endif
|
||||
while (true) {
|
||||
n_recv = recv(_clients[id], &msg_len, LENGTH_FIELD, MSG_WAITALL);
|
||||
if (!_handle_recv_len(id, n_recv,LENGTH_FIELD))
|
||||
timers[id].start();
|
||||
n_recv = recv(_clients[id], &msg_len, sizeof(msg_len), MSG_WAITALL);
|
||||
timers[id].stop();
|
||||
#ifdef DEBUG_COMM
|
||||
cout << "message of size " << msg_len << " coming in from " << id << endl;
|
||||
#endif
|
||||
msg.resize(msg_len);
|
||||
if (!_handle_recv_len(id, n_recv, sizeof(msg_len)))
|
||||
return;
|
||||
// printf("Server:: waiting for a message of len = %d\n", msg_len);
|
||||
msg = new char[msg_len];
|
||||
assert(msg != NULL);
|
||||
total_received = 0;
|
||||
remaining = 0;
|
||||
while (total_received != msg_len) {
|
||||
remaining = (msg_len-total_received)>_max_msg_sz ? _max_msg_sz : (msg_len-total_received);
|
||||
n_recv = recv(_clients[id], msg+total_received, remaining, NULL /* MSG_WAITALL*/);
|
||||
timers[id].start();
|
||||
n_recv = recv(_clients[id], msg.data()+total_received, remaining, 0 /* MSG_WAITALL*/);
|
||||
timers[id].stop();
|
||||
// printf("n_recv = %d\n", n_recv);
|
||||
if (!_handle_recv_len(id, n_recv,remaining)) {
|
||||
printf("returning\n");
|
||||
@@ -99,19 +119,21 @@ void Server::_listen_to_client(int id){
|
||||
// printf("total_received = %d\n", total_received);
|
||||
}
|
||||
// printf("Server:: received %d: \n", msg_len);
|
||||
_updatable->NewMsg(msg, msg_len, &_clients_addr[id]);
|
||||
_updatable->NewMsg(msg, &_clients_addr[id]);
|
||||
}
|
||||
printf("stop listenning to %d\n", id);
|
||||
}
|
||||
|
||||
bool Server::_handle_recv_len(int id, unsigned int actual_len, unsigned int expected_len) {
|
||||
bool Server::_handle_recv_len(int id, size_t actual_len, size_t expected_len) {
|
||||
// printf("Server:: received msg from %d len = %u\n",id, actual_len);
|
||||
if (actual_len == 0) {
|
||||
#ifdef DEBUG_COMM
|
||||
printf("Server:: [%d]: Error: n_recv==0 Connection closed\n", id);
|
||||
#endif
|
||||
_updatable->NodeAborted(&_clients_addr[id]);
|
||||
return false;
|
||||
// exit(1);
|
||||
} else if (actual_len == -1) {
|
||||
} else if (actual_len == -1U) {
|
||||
printf("Server:: [%d]: Error: n_recv==-1. \"%s\"\n",id, strerror(errno));
|
||||
_updatable->NodeAborted(&_clients_addr[id]);
|
||||
return false;
|
||||
@@ -120,4 +142,5 @@ bool Server::_handle_recv_len(int id, unsigned int actual_len, unsigned int expe
|
||||
// printf("Server:: [%d]: Error: n_recv < %d; n_recv=%d; \"%d\"\n",id, expected_len, actual_len,strerror(errno));
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Server.h
|
||||
*
|
||||
* Created on: Jan 24, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_INC_SERVER_H_
|
||||
@@ -10,13 +10,15 @@
|
||||
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
class ServerUpdatable {
|
||||
public:
|
||||
virtual void ClientsConnected()=0;
|
||||
virtual void NewMsg(char* msg, unsigned int len, struct sockaddr_in* from)=0;
|
||||
virtual void NewMsg(ReceivedMsg& msg, struct sockaddr_in* from)=0;
|
||||
virtual void NodeAborted(struct sockaddr_in* from) =0;
|
||||
};
|
||||
|
||||
@@ -25,6 +27,13 @@ public:
|
||||
Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size);
|
||||
~Server();
|
||||
|
||||
sockaddr_in* get_client_addr(int id) { return &_clients_addr[id]; }
|
||||
|
||||
boost::thread* starter;
|
||||
std::vector<boost::thread*> listeners;
|
||||
|
||||
vector<Timer> timers;
|
||||
|
||||
private:
|
||||
int _expected_clients;
|
||||
int *_clients;
|
||||
@@ -38,7 +47,7 @@ private:
|
||||
|
||||
void _start_server();
|
||||
void _listen_to_client(int id);
|
||||
bool _handle_recv_len(int id, unsigned int actual_len,unsigned int expected_len);
|
||||
bool _handle_recv_len(int id, size_t actual_len, size_t expected_len);
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* common.h
|
||||
*
|
||||
* Created on: Jan 27, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_INC_COMMON_H_
|
||||
@@ -13,23 +13,16 @@
|
||||
|
||||
//#include "utils.h"
|
||||
|
||||
#define LENGTH_FIELD (4)
|
||||
|
||||
/*
|
||||
* To change the buffer sizes in the kernel
|
||||
# echo 'net.core.wmem_max=12582912' >> /etc/sysctl.conf
|
||||
# echo 'net.core.rmem_max=12582912' >> /etc/sysctl.conf
|
||||
*/
|
||||
const int BUFFER_SIZE = 20000000;
|
||||
const int NETWORK_BUFFER_SIZE = 20000000;
|
||||
|
||||
typedef struct {
|
||||
std::string ip;
|
||||
int port;
|
||||
} endpoint_t;
|
||||
|
||||
typedef struct {
|
||||
const char* msg;
|
||||
unsigned int len;
|
||||
} Msg;
|
||||
|
||||
#endif /* NETWORK_INC_COMMON_H_ */
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
|
||||
|
||||
|
||||
|
||||
CPPFLAGS=-g -O3 -c -w --std=c++11 \
|
||||
-I/usr/include \
|
||||
-Iinc \
|
||||
$(LOCAL_CPPFLAGS)
|
||||
|
||||
LDFLAGS=-g -L/usr/lib/x86_64-linux-gnu/ \
|
||||
-lboost_system -lpthread -lboost_thread $(LOCAL_LDFLAGS)
|
||||
|
||||
OBJECTS=build/test.o \
|
||||
build/utils.o \
|
||||
build/Server.o \
|
||||
build/Client.o \
|
||||
build/Node.o
|
||||
|
||||
|
||||
all:bin/net
|
||||
|
||||
bin/net:$(OBJECTS)
|
||||
g++ -Wall $^ -o bin/net $(LDFLAGS)
|
||||
|
||||
|
||||
build/Server.o:src/Server.cpp
|
||||
$(CC) $(CPPFLAGS) $^ -o $@
|
||||
|
||||
build/Client.o:src/Client.cpp
|
||||
$(CC) $(CPPFLAGS) $^ -o $@
|
||||
|
||||
build/Node.o:src/Node.cpp
|
||||
$(CC) $(CPPFLAGS) $^ -o $@
|
||||
|
||||
build/test.o:test/test.cpp
|
||||
$(CC) $(CPPFLAGS) $^ -o $@
|
||||
|
||||
build/utils.o:src/utils.cpp
|
||||
$(CC) $(CPPFLAGS) $^ -o $@
|
||||
|
||||
|
||||
clean:
|
||||
rm build/* bin/net
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* mq.h
|
||||
*
|
||||
* Created on: Feb 14, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_INC_MQ_H_
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* utils.cpp
|
||||
*
|
||||
* Created on: Jan 31, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
@@ -22,7 +23,7 @@ void fill_random(void* buffer, unsigned int length)
|
||||
}
|
||||
|
||||
char cs(char* msg, unsigned int len, char result) {
|
||||
for(int i = 0; i < len; i++)
|
||||
for(size_t i = 0; i < len; i++)
|
||||
result += msg[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* utils.h
|
||||
*
|
||||
* Created on: Jan 31, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef NETWORK_TEST_UTILS_H_
|
||||
|
||||
16
BMR/prf.cpp
16
BMR/prf.cpp
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* prf.cpp
|
||||
*
|
||||
* Created on: Feb 18, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
|
||||
@@ -10,22 +10,14 @@
|
||||
#include "aes.h"
|
||||
#include "proto_utils.h"
|
||||
|
||||
void PRF_single(const Key* key, char* input, char* output)
|
||||
void PRF_single(const Key& key, char* input, char* output)
|
||||
{
|
||||
// printf("prf_single\n");
|
||||
// std::cout << *key;
|
||||
// phex(input, 16);
|
||||
AES_KEY aes_key;
|
||||
AES_128_Key_Expansion((const unsigned char*)(&(key->r)), &aes_key);
|
||||
AES_128_Key_Expansion((const unsigned char*)(&(key.r)), &aes_key);
|
||||
aes_key.rounds=10;
|
||||
AES_encryptC((block*)input, (block*)output, &aes_key);
|
||||
// phex(output, 16);
|
||||
}
|
||||
|
||||
void PRF_chunk(const Key* key, char* input, char* output, int number)
|
||||
{
|
||||
AES_KEY aes_key;
|
||||
AES_128_Key_Expansion((const unsigned char*)(&(key->r)), &aes_key);
|
||||
aes_key.rounds=10;
|
||||
AES_ecb_encrypt_chunk_in_out((block*)input, (block*)output, number, &aes_key);
|
||||
}
|
||||
|
||||
29
BMR/prf.h
29
BMR/prf.h
@@ -1,16 +1,37 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* prf.h
|
||||
*
|
||||
* Created on: Feb 18, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOL_INC_PRF_H_
|
||||
#define PROTOCOL_INC_PRF_H_
|
||||
|
||||
#include "Key.h"
|
||||
#include "aes.h"
|
||||
|
||||
void PRF_single(const Key* key, char* input, char* output);
|
||||
void PRF_chunk(const Key* key, char* input, char* output, int number);
|
||||
#include "Tools/aes.h"
|
||||
|
||||
void PRF_single(const Key& key, char* input, char* output);
|
||||
|
||||
inline void PRF_chunk(const Key& key, char* input, char* output, int number)
|
||||
{
|
||||
__m128i* in = (__m128i*)input;
|
||||
__m128i* out = (__m128i*)output;
|
||||
AES_KEY aes_key;
|
||||
AES_128_Key_Expansion((unsigned char*)&key.r, &aes_key);
|
||||
switch (number)
|
||||
{
|
||||
case 2:
|
||||
ecb_aes_128_encrypt<2>(out, in, (octet*)aes_key.rd_key);
|
||||
break;
|
||||
case 3:
|
||||
ecb_aes_128_encrypt<3>(out, in, (octet*)aes_key.rd_key);
|
||||
break;
|
||||
default:
|
||||
throw not_implemented();
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* PROTOCOL_INC_PRF_H_ */
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* utils.cpp
|
||||
*
|
||||
* Created on: Jan 31, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
|
||||
@@ -21,12 +21,9 @@ void fill_message_type(void* buffer, MSG_TYPE type)
|
||||
memcpy(buffer, &type, sizeof(MSG_TYPE));
|
||||
}
|
||||
|
||||
void fill_message_type(vector<char>& buffer, MSG_TYPE type)
|
||||
void fill_message_type(SendBuffer& buffer, MSG_TYPE type)
|
||||
{
|
||||
// compatibility
|
||||
buffer.resize(buffer.size() + 4);
|
||||
char* start = (char*)&type;
|
||||
copy(start, start + sizeof(MSG_TYPE), buffer.end() - 4);
|
||||
buffer.serialize(type);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* utils.h
|
||||
*
|
||||
* Created on: Jan 31, 2016
|
||||
* Author: bush
|
||||
*/
|
||||
|
||||
#ifndef PROTO_UTILS_H_
|
||||
@@ -11,23 +11,30 @@
|
||||
#include "msg_types.h"
|
||||
#include <time.h>
|
||||
#include <sys/time.h>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/avx_memcpy.h"
|
||||
#include "Tools/FlexBuffer.h"
|
||||
|
||||
#define LOOPBACK_STR "LOOPBACK"
|
||||
|
||||
void fill_random(void* buffer, unsigned int length);
|
||||
|
||||
void fill_message_type(void* buffer, MSG_TYPE type);
|
||||
class SendBuffer;
|
||||
|
||||
char cs(char* msg, unsigned int len, char result=0);
|
||||
void fill_message_type(void* buffer, MSG_TYPE type);
|
||||
void fill_message_type(SendBuffer& buffer, MSG_TYPE type);
|
||||
|
||||
void phex (const void *addr, int len);
|
||||
|
||||
//inline void xor_big(const char* input1, const char* input2, char* output);
|
||||
|
||||
|
||||
inline timeval* GET_TIME() {
|
||||
struct timeval* now = new struct timeval();
|
||||
int rc = gettimeofday(now, 0);
|
||||
inline timeval GET_TIME() {
|
||||
struct timeval now;
|
||||
int rc = gettimeofday(&now, 0);
|
||||
if (rc != 0) {
|
||||
perror("gettimeofday");
|
||||
}
|
||||
@@ -46,6 +53,29 @@ inline unsigned long PRINT_DIFF(struct timeval* before, struct timeval* after) {
|
||||
return diff;
|
||||
}
|
||||
|
||||
inline void phex(const FlexBuffer& buffer) { phex(buffer.data(), buffer.size()); }
|
||||
|
||||
inline void print_bit_array(const char* bits, int len)
|
||||
{
|
||||
for (int i = 0; i < len; i++)
|
||||
{
|
||||
if (i % 8 == 0)
|
||||
cout << " ";
|
||||
cout << (int)bits[i];
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
inline void print_bit_array(const vector<char>& bits)
|
||||
{
|
||||
print_bit_array(bits.data(), bits.size());
|
||||
}
|
||||
|
||||
inline void print_indices(const vector<int>& indices)
|
||||
{
|
||||
for (unsigned i = 0; i < indices.size(); i++)
|
||||
cout << indices[i] << " ";
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
#endif /* NETWORK_TEST_UTILS_H_ */
|
||||
|
||||
62
CHANGELOG.md
62
CHANGELOG.md
@@ -1,62 +0,0 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here.
|
||||
|
||||
## 0.0.3 (Mar 2, 2018)
|
||||
|
||||
- Added offline phases based on homomorphic encryption, used in the [SPDZ-2 paper](https://eprint.iacr.org/2012/642) and the [Overdrive paper](https://eprint.iacr.org/2017/1230).
|
||||
- On macOS, the minimum requirement is now Sierra.
|
||||
- Compilation with LLVM/clang is now possible (tested with 3.8).
|
||||
|
||||
## 0.0.2 (Sep 13, 2017)
|
||||
|
||||
### Support sockets based external client input and output to a SPDZ MPC program.
|
||||
|
||||
See the [ExternalIO directory](./ExternalIO/README.md) for more details and examples.
|
||||
|
||||
Note that [libsodium](https://download.libsodium.org/doc/) is now a dependency on the SPDZ build.
|
||||
|
||||
Added compiler instructions:
|
||||
|
||||
* LISTEN
|
||||
* ACCEPTCLIENTCONNECTION
|
||||
* CONNECTIPV4
|
||||
* WRITESOCKETSHARE
|
||||
* WRITESOCKETINT
|
||||
|
||||
Removed instructions:
|
||||
|
||||
* OPENSOCKET
|
||||
* CLOSESOCKET
|
||||
|
||||
Modified instructions:
|
||||
|
||||
* READSOCKETC
|
||||
* READSOCKETS
|
||||
* READSOCKETINT
|
||||
* WRITESOCKETC
|
||||
* WRITESOCKETS
|
||||
|
||||
Support secure external client input and output with new instructions:
|
||||
|
||||
* READCLIENTPUBLICKEY
|
||||
* INITSECURESOCKET
|
||||
* RESPSECURESOCKET
|
||||
|
||||
### Read/Write secret shares to disk to support persistence in a SPDZ MPC program.
|
||||
|
||||
Added compiler instructions:
|
||||
|
||||
* READFILESHARE
|
||||
* WRITEFILESHARE
|
||||
|
||||
### Other instructions
|
||||
|
||||
Added compiler instructions:
|
||||
|
||||
* DIGESTC - Clear truncated hash computation
|
||||
* PRINTINT - Print register value
|
||||
|
||||
## 0.0.1 (Sep 2, 2016)
|
||||
|
||||
### Initial Release
|
||||
|
||||
* See `README.md` and `tutorial.md`.
|
||||
10
CONFIG
10
CONFIG
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
ROOT = .
|
||||
|
||||
@@ -17,6 +17,10 @@ USE_GF2N_LONG = 0
|
||||
# AVX2 support (Haswell or later) changes the bit matrix transpose
|
||||
ARCH = -mtune=native -mavx
|
||||
|
||||
# defaults for BMR, change number of parties here
|
||||
CFLAGS = -DN_PARTIES=2 -DFREE_XOR -DKEY_SIGNAL -DSPDZ_AUTH -DNO_INPUT -DMAX_INLINE
|
||||
USE_GF2N_LONG = 1
|
||||
|
||||
#use CONFIG.mine to overwrite DIR settings
|
||||
-include CONFIG.mine
|
||||
|
||||
@@ -39,8 +43,10 @@ ifeq ($(OS), Linux)
|
||||
LDLIBS += -lrt
|
||||
endif
|
||||
|
||||
BOOST = -lboost_system -lboost_thread $(MY_BOOST)
|
||||
|
||||
CXX = g++
|
||||
CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = g++
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Check-Offline.cpp
|
||||
|
||||
2
Compiler/GC/__init__.py
Normal file
2
Compiler/GC/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
166
Compiler/GC/instructions.py
Normal file
166
Compiler/GC/instructions.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import Compiler.instructions_base as base
|
||||
import Compiler.instructions as spdz
|
||||
import Compiler.tools as tools
|
||||
import collections
|
||||
import itertools
|
||||
|
||||
class SecretBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'sb'
|
||||
class ClearBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'cb'
|
||||
|
||||
base.ArgFormats['sb'] = SecretBitsAF
|
||||
base.ArgFormats['sbw'] = SecretBitsAF
|
||||
base.ArgFormats['cb'] = ClearBitsAF
|
||||
base.ArgFormats['cbw'] = ClearBitsAF
|
||||
|
||||
opcodes = dict(
|
||||
XORS = 0x200,
|
||||
XORM = 0x201,
|
||||
ANDRS = 0x202,
|
||||
BITDECS = 0x203,
|
||||
BITCOMS = 0x204,
|
||||
CONVSINT = 0x205,
|
||||
LDMSDI = 0x206,
|
||||
STMSDI = 0x207,
|
||||
LDMSD = 0x208,
|
||||
STMSD = 0x209,
|
||||
LDBITS = 0x20a,
|
||||
XORCI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
CONVCINT = 0x213,
|
||||
REVEAL = 0x214,
|
||||
STMSDCI = 0x215,
|
||||
)
|
||||
|
||||
class xors(base.Instruction):
|
||||
code = opcodes['XORS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
class xorm(base.Instruction):
|
||||
code = opcodes['XORM']
|
||||
arg_format = ['int','sbw','sb','cb']
|
||||
|
||||
class xorc(base.Instruction):
|
||||
code = base.opcodes['XORC']
|
||||
arg_format = ['cbw','cb','cb']
|
||||
|
||||
class xorci(base.Instruction):
|
||||
code = opcodes['XORCI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class andrs(base.Instruction):
|
||||
code = opcodes['ANDRS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
class addc(base.Instruction):
|
||||
code = base.opcodes['ADDC']
|
||||
arg_format = ['cbw','cb','cb']
|
||||
|
||||
class addci(base.Instruction):
|
||||
code = base.opcodes['ADDCI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class mulci(base.Instruction):
|
||||
code = base.opcodes['MULCI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class bitdecs(base.VarArgsInstruction):
|
||||
code = opcodes['BITDECS']
|
||||
arg_format = tools.chain(['sb'], itertools.repeat('sbw'))
|
||||
|
||||
class bitcoms(base.VarArgsInstruction):
|
||||
code = opcodes['BITCOMS']
|
||||
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
|
||||
|
||||
class bitdecc(base.VarArgsInstruction):
|
||||
code = opcodes['BITDECC']
|
||||
arg_format = tools.chain(['cb'], itertools.repeat('cbw'))
|
||||
|
||||
class shrci(base.Instruction):
|
||||
code = base.opcodes['SHRCI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class ldbits(base.Instruction):
|
||||
code = opcodes['LDBITS']
|
||||
arg_format = ['sbw','i','i']
|
||||
|
||||
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMS']
|
||||
arg_format = ['sbw','int']
|
||||
|
||||
class stms(base.DirectMemoryWriteInstruction):
|
||||
code = base.opcodes['STMS']
|
||||
arg_format = ['sb','int']
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super(type(self), self).__init__(*args, **kwargs)
|
||||
# import inspect
|
||||
# self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
|
||||
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMC']
|
||||
arg_format = ['cbw','int']
|
||||
|
||||
class stmc(base.DirectMemoryWriteInstruction):
|
||||
code = base.opcodes['STMC']
|
||||
arg_format = ['cb','int']
|
||||
|
||||
class ldmsi(base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMSI']
|
||||
arg_format = ['sbw','ci']
|
||||
|
||||
class stmsi(base.WriteMemoryInstruction):
|
||||
code = base.opcodes['STMSI']
|
||||
arg_format = ['sb','ci']
|
||||
|
||||
class ldmsdi(base.ReadMemoryInstruction):
|
||||
code = opcodes['LDMSDI']
|
||||
arg_format = tools.cycle(['sbw','cb','int'])
|
||||
|
||||
class stmsdi(base.WriteMemoryInstruction):
|
||||
code = opcodes['STMSDI']
|
||||
arg_format = tools.cycle(['sb','cb'])
|
||||
|
||||
class ldmsd(base.ReadMemoryInstruction):
|
||||
code = opcodes['LDMSD']
|
||||
arg_format = tools.cycle(['sbw','int','int'])
|
||||
|
||||
class stmsd(base.WriteMemoryInstruction):
|
||||
code = opcodes['STMSD']
|
||||
arg_format = tools.cycle(['sb','int'])
|
||||
|
||||
class stmsdci(base.WriteMemoryInstruction):
|
||||
code = opcodes['STMSDCI']
|
||||
arg_format = tools.cycle(['cb','cb'])
|
||||
|
||||
class convsint(base.Instruction):
|
||||
code = opcodes['CONVSINT']
|
||||
arg_format = ['int','sbw','ci']
|
||||
|
||||
class convcint(base.Instruction):
|
||||
code = opcodes['CONVCINT']
|
||||
arg_format = ['cbw','ci']
|
||||
|
||||
class movs(base.Instruction):
|
||||
code = base.opcodes['MOVS']
|
||||
arg_format = ['sbw','sb']
|
||||
|
||||
class bit(base.Instruction):
|
||||
code = base.opcodes['BIT']
|
||||
arg_format = ['sbw']
|
||||
|
||||
class reveal(base.Instruction):
|
||||
code = opcodes['REVEAL']
|
||||
arg_format = ['int','cbw','sb']
|
||||
|
||||
class print_reg(base.IOInstruction):
|
||||
code = base.opcodes['PRINTREG']
|
||||
arg_format = ['cb','i']
|
||||
def __init__(self, reg, comment=''):
|
||||
super(print_reg, self).__init__(reg, self.str_to_int(comment))
|
||||
|
||||
class print_reg_plain(base.IOInstruction):
|
||||
code = base.opcodes['PRINTREGPLAIN']
|
||||
arg_format = ['cb']
|
||||
12
Compiler/GC/program.py
Normal file
12
Compiler/GC/program.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler import types, instructions
|
||||
|
||||
class Program(object):
|
||||
def __init__(self, progname):
|
||||
types.program = self
|
||||
instructions.program = self
|
||||
self.curr_tape = None
|
||||
execfile(progname)
|
||||
def malloc(self, *args):
|
||||
pass
|
||||
352
Compiler/GC/types.py
Normal file
352
Compiler/GC/types.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.types import MemValue, read_mem_value, regint, Array
|
||||
from Compiler.program import Tape, Program
|
||||
from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
|
||||
class bits(Tape.Register):
|
||||
n = 40
|
||||
size = 1
|
||||
PreOp = staticmethod(floatingpoint.PreOpN)
|
||||
MemValue = staticmethod(lambda value: MemValue(value))
|
||||
@staticmethod
|
||||
def PreOR(l):
|
||||
return [1 - x for x in \
|
||||
floatingpoint.PreOpN(operator.mul, \
|
||||
[1 - x for x in l])]
|
||||
@classmethod
|
||||
def get_type(cls, length):
|
||||
if length is None:
|
||||
return cls
|
||||
elif length == 1:
|
||||
return cls.bit_type
|
||||
if length not in cls.types:
|
||||
class bitsn(cls):
|
||||
n = length
|
||||
cls.types[length] = bitsn
|
||||
bitsn.__name__ = cls.__name__ + str(length)
|
||||
return cls.types[length]
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cls):
|
||||
return other
|
||||
elif isinstance(other, MemValue):
|
||||
return cls.conv(other.read())
|
||||
else:
|
||||
res = cls()
|
||||
res.load_other(other)
|
||||
return res
|
||||
hard_conv = conv
|
||||
@classmethod
|
||||
def compose(cls, items, bit_length):
|
||||
return cls.bit_compose(sum([item.bit_decompose(bit_length) for item in items], []))
|
||||
@classmethod
|
||||
def bit_compose(cls, bits):
|
||||
if len(bits) == 1:
|
||||
return bits[0]
|
||||
bits = list(bits)
|
||||
res = cls.new(n=len(bits))
|
||||
cls.bitcom(res, *bits)
|
||||
res.decomposed = bits
|
||||
return res
|
||||
def bit_decompose(self, bit_length=None):
|
||||
n = bit_length or self.n
|
||||
if n > self.n:
|
||||
raise Exception('wanted %d bits, only got %d' % (n, self.n))
|
||||
if n == 1:
|
||||
return [self]
|
||||
if self.decomposed is None or len(self.decomposed) < n:
|
||||
res = [self.bit_type() for i in range(n)]
|
||||
self.bitdec(self, *res)
|
||||
self.decomposed = res
|
||||
return res
|
||||
else:
|
||||
return self.decomposed[:n]
|
||||
@classmethod
|
||||
def load_mem(cls, address, mem_type=None):
|
||||
res = cls()
|
||||
if mem_type == 'sd':
|
||||
return cls.load_dynamic_mem(address)
|
||||
else:
|
||||
cls.load_inst[util.is_constant(address)](res, address)
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, (int, long))](self, address)
|
||||
def __init__(self, value=None, n=None):
|
||||
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape)
|
||||
self.set_length(n or self.n)
|
||||
if value is not None:
|
||||
self.load_other(value)
|
||||
self.decomposed = None
|
||||
def set_length(self, n):
|
||||
if n > self.max_length:
|
||||
print self.max_length
|
||||
raise Exception('too long: %d' % n)
|
||||
self.n = n
|
||||
def load_other(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
self.set_length(self.n or util.int_len(other))
|
||||
self.load_int(other)
|
||||
elif isinstance(other, regint):
|
||||
self.conv_regint(self.n, self, other)
|
||||
elif isinstance(self, type(other)) or isinstance(other, type(self)):
|
||||
self.mov(self, other)
|
||||
else:
|
||||
raise CompilerError('cannot convert from %s to %s' % \
|
||||
(type(other), type(self)))
|
||||
def __repr__(self):
|
||||
return '%s(%d/%d)' % \
|
||||
(super(bits, self).__repr__(), self.n, type(self).n)
|
||||
|
||||
class cbits(bits):
|
||||
max_length = 64
|
||||
reg_type = 'cb'
|
||||
is_clear = True
|
||||
load_inst = (None, inst.ldmc)
|
||||
store_inst = (None, inst.stmc)
|
||||
bitdec = inst.bitdecc
|
||||
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
self.load_other(regint(value))
|
||||
def store_in_dynamic_mem(self, address):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
def clear_op(self, other, c_inst, ci_inst, op):
|
||||
if isinstance(other, cbits):
|
||||
res = cbits(n=max(self.n, other.n))
|
||||
c_inst(res, self, other)
|
||||
return res
|
||||
else:
|
||||
if util.is_constant(other):
|
||||
if other >= 2**31 or other < -2**31:
|
||||
return op(self, cbits(other))
|
||||
res = cbits(n=max(self.n, len(bin(other)) - 2))
|
||||
ci_inst(res, self, other)
|
||||
return res
|
||||
else:
|
||||
return op(self, cbits(other))
|
||||
__add__ = lambda self, other: \
|
||||
self.clear_op(other, inst.addc, inst.addci, operator.add)
|
||||
__xor__ = lambda self, other: \
|
||||
self.clear_op(other, inst.xorc, inst.xorci, operator.xor)
|
||||
__radd__ = __add__
|
||||
__rxor__ = __xor__
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, cbits):
|
||||
return NotImplemented
|
||||
else:
|
||||
try:
|
||||
res = cbits(n=min(self.max_length, self.n+util.int_len(other)))
|
||||
inst.mulci(res, self, other)
|
||||
return res
|
||||
except TypeError:
|
||||
return NotImplemented
|
||||
def __rshift__(self, other):
|
||||
res = cbits(n=self.n-other)
|
||||
inst.shrci(res, self, other)
|
||||
return res
|
||||
def print_reg(self, desc=''):
|
||||
inst.print_reg(self, desc)
|
||||
def print_reg_plain(self):
|
||||
inst.print_reg_plain(self)
|
||||
output = print_reg_plain
|
||||
def reveal(self):
|
||||
return self
|
||||
|
||||
class sbits(bits):
|
||||
max_length = 128
|
||||
reg_type = 'sb'
|
||||
is_clear = False
|
||||
clear_type = cbits
|
||||
default_type = cbits
|
||||
load_inst = (inst.ldmsi, inst.ldms)
|
||||
store_inst = (inst.stmsi, inst.stms)
|
||||
bitdec = inst.bitdecs
|
||||
bitcom = inst.bitcoms
|
||||
conv_regint = inst.convsint
|
||||
mov = inst.movs
|
||||
types = {}
|
||||
def __init__(self, *args, **kwargs):
|
||||
bits.__init__(self, *args, **kwargs)
|
||||
@staticmethod
|
||||
def new(value=None, n=None):
|
||||
if n == 1:
|
||||
return sbit(value)
|
||||
else:
|
||||
return sbits.get_type(n)(value)
|
||||
@staticmethod
|
||||
def get_random_bit():
|
||||
res = sbit()
|
||||
inst.bit(res)
|
||||
return res
|
||||
@classmethod
|
||||
def load_dynamic_mem(cls, address):
|
||||
res = cls()
|
||||
if isinstance(address, (long, int)):
|
||||
inst.ldmsd(res, address, cls.n)
|
||||
else:
|
||||
inst.ldmsdi(res, address, cls.n)
|
||||
return res
|
||||
def store_in_dynamic_mem(self, address):
|
||||
if isinstance(address, (long, int)):
|
||||
inst.stmsd(self, address)
|
||||
else:
|
||||
inst.stmsdi(self, cbits.conv(address))
|
||||
def load_int(self, value):
|
||||
if abs(value) < 2**31:
|
||||
if (abs(value) > (1 << self.n)):
|
||||
raise Exception('public value %d longer than %d bits' \
|
||||
% (value, self.n))
|
||||
inst.ldbits(self, self.n, value)
|
||||
else:
|
||||
value %= 2**self.n
|
||||
if value >> 64 != 0:
|
||||
raise NotImplementedError('public value too large')
|
||||
self.load_other(regint(value))
|
||||
@read_mem_value
|
||||
def __add__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self.xor_int(other)
|
||||
else:
|
||||
if not isinstance(other, sbits):
|
||||
other = sbits(other)
|
||||
n = self.n
|
||||
else:
|
||||
n = max(self.n, other.n)
|
||||
res = self.new(n=n)
|
||||
inst.xors(n, res, self, other)
|
||||
return res
|
||||
__radd__ = __add__
|
||||
__sub__ = __add__
|
||||
__xor__ = __add__
|
||||
__rxor__ = __add__
|
||||
@read_mem_value
|
||||
def __rsub__(self, other):
|
||||
if isinstance(other, cbits):
|
||||
return other + self
|
||||
else:
|
||||
return self.xor_int(other)
|
||||
@read_mem_value
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self.mul_int(other)
|
||||
try:
|
||||
if min(self.n, other.n) != 1:
|
||||
raise NotImplementedError('high order multiplication')
|
||||
n = max(self.n, other.n)
|
||||
res = self.new(n=max(self.n, other.n))
|
||||
order = (self, other) if self.n != 1 else (other, self)
|
||||
inst.andrs(n, res, *order)
|
||||
return res
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
@read_mem_value
|
||||
def __rmul__(self, other):
|
||||
if isinstance(other, cbits):
|
||||
return other * self
|
||||
else:
|
||||
return self.mul_int(other)
|
||||
def xor_int(self, other):
|
||||
if other == 0:
|
||||
return self
|
||||
self_bits = self.bit_decompose()
|
||||
other_bits = util.bit_decompose(other, max(self.n, util.int_len(other)))
|
||||
extra_bits = [self.new(b, n=1) for b in other_bits[self.n:]]
|
||||
return self.bit_compose([~x if y else x \
|
||||
for x,y in zip(self_bits, other_bits)] \
|
||||
+ extra_bits)
|
||||
def mul_int(self, other):
|
||||
if other == 0:
|
||||
return 0
|
||||
elif other == 1:
|
||||
return self
|
||||
elif self.n == 1:
|
||||
bits = util.bit_decompose(other, util.int_len(other))
|
||||
zero = sbit(0)
|
||||
mul_bits = [self if b else zero for b in bits]
|
||||
return self.bit_compose(mul_bits)
|
||||
else:
|
||||
print self.n, other
|
||||
return NotImplemented
|
||||
def __lshift__(self, i):
|
||||
return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i])
|
||||
def __invert__(self):
|
||||
# res = type(self)(n=self.n)
|
||||
# inst.nots(res, self)
|
||||
# return res
|
||||
one = self.new(value=1, n=1)
|
||||
bits = [one + bit for bit in self.bit_decompose()]
|
||||
return self.bit_compose(bits)
|
||||
def __neg__(self):
|
||||
return self
|
||||
def reveal(self):
|
||||
if self.n > self.clear_type.max_length:
|
||||
raise Exception('too long to reveal')
|
||||
res = self.clear_type(n=self.n)
|
||||
inst.reveal(self.n, res, self)
|
||||
return res
|
||||
def equal(self, other, n=None):
|
||||
bits = (~(self + other)).bit_decompose()
|
||||
return reduce(operator.mul, bits)
|
||||
|
||||
class bit(object):
|
||||
n = 1
|
||||
|
||||
class sbit(bit, sbits):
|
||||
def if_else(self, x, y):
|
||||
return self * (x ^ y) ^ y
|
||||
|
||||
class cbit(bit, cbits):
|
||||
pass
|
||||
|
||||
sbits.bit_type = sbit
|
||||
cbits.bit_type = cbit
|
||||
|
||||
class bitsBlock(oram.Block):
|
||||
value_type = sbits
|
||||
def __init__(self, value, start, lengths, entries_per_block):
|
||||
oram.Block.__init__(self, value, lengths)
|
||||
length = sum(self.lengths)
|
||||
used_bits = entries_per_block * length
|
||||
self.value_bits = self.value.bit_decompose(used_bits)
|
||||
start_length = util.log2(entries_per_block)
|
||||
self.start_bits = util.bit_decompose(start, start_length)
|
||||
self.start_demux = oram.demux_list(self.start_bits)
|
||||
self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \
|
||||
for i in range(entries_per_block)]
|
||||
self.mul_entries = map(operator.mul, self.start_demux, self.entries)
|
||||
self.bits = sum(self.mul_entries).bit_decompose()
|
||||
self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths))
|
||||
self.anti_value = self.mul_value + self.value
|
||||
def set_slice(self, value):
|
||||
value = sbits.compose(util.tuplify(value), sum(self.lengths))
|
||||
for i,b in enumerate(self.start_bits):
|
||||
value = b.if_else(value << (2**i * sum(self.lengths)), value)
|
||||
self.value = value + self.anti_value
|
||||
return self
|
||||
|
||||
oram.block_types[sbits] = bitsBlock
|
||||
|
||||
class dyn_sbits(sbits):
|
||||
pass
|
||||
|
||||
class DynamicArray(Array):
|
||||
def __init__(self, *args):
|
||||
Array.__init__(self, *args)
|
||||
def _malloc(self):
|
||||
return Program.prog.malloc(self.length, 'sd', self.value_type)
|
||||
def _load(self, address):
|
||||
return self.value_type.load_dynamic_mem(cbits.conv(address))
|
||||
def _store(self, value, address):
|
||||
if isinstance(value, MemValue):
|
||||
value = value.read()
|
||||
if isinstance(value, sbits):
|
||||
self.value_type.conv(value).store_in_dynamic_mem(address)
|
||||
else:
|
||||
cbits.conv(value).store_in_dynamic_mem(address)
|
||||
|
||||
sbits.dynamic_array = DynamicArray
|
||||
cbits.dynamic_array = Array
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import compilerLib, program, instructions, types, library, floatingpoint
|
||||
import inspect
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import itertools, time
|
||||
from collections import defaultdict, deque
|
||||
@@ -151,7 +151,7 @@ def determine_scope(block, options):
|
||||
block.defined_registers = set(last_def.iterkeys())
|
||||
|
||||
class Merger:
|
||||
def __init__(self, block, options):
|
||||
def __init__(self, block, options, merge_classes, stop_class=stopopen_class):
|
||||
self.block = block
|
||||
self.instructions = block.instructions
|
||||
self.options = options
|
||||
@@ -159,7 +159,7 @@ class Merger:
|
||||
self.max_parallel_open = int(options.max_parallel_open)
|
||||
else:
|
||||
self.max_parallel_open = float('inf')
|
||||
self.dependency_graph()
|
||||
self.dependency_graph(merge_classes, stop_class)
|
||||
|
||||
def do_merge(self, merges_iter):
|
||||
""" Merge an iterable of nodes in G, returning the number of merged
|
||||
@@ -341,8 +341,7 @@ class Merger:
|
||||
preorder.extend(reversed(startinputs))
|
||||
return preorder
|
||||
|
||||
def longest_paths_merge(self, instruction_type=startopen_class,
|
||||
merge_stopopens=True):
|
||||
def longest_paths_merge(self, merge_stopopens=True):
|
||||
""" Attempt to merge instructions of type instruction_type (which are given in
|
||||
merge_nodes) using longest paths algorithm.
|
||||
|
||||
@@ -357,8 +356,6 @@ class Merger:
|
||||
instructions = self.instructions
|
||||
merge_nodes = self.open_nodes
|
||||
depths = self.depths
|
||||
if instruction_type is not startopen_class and merge_stopopens:
|
||||
raise CompilerError('Cannot merge stopopens whilst merging %s instructions' % instruction_type)
|
||||
if not merge_nodes and not self.input_nodes:
|
||||
return 0
|
||||
|
||||
@@ -379,8 +376,7 @@ class Merger:
|
||||
my_merge = (m for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b)
|
||||
|
||||
if merge_stopopens:
|
||||
my_stopopen = [G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b]
|
||||
|
||||
my_stopopen = filter(lambda x: x != -1, (G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b))
|
||||
mc, nodes[0,b] = self.do_merge(iter(my_merge))
|
||||
|
||||
if merge_stopopens:
|
||||
@@ -420,8 +416,16 @@ class Merger:
|
||||
|
||||
return len(merges)
|
||||
|
||||
def dependency_graph(self, merge_class=startopen_class):
|
||||
def dependency_graph(self, merge_classes, stop_class):
|
||||
""" Create the program dependency graph. """
|
||||
if len(merge_classes) != 1:
|
||||
if stop_class is not type(None):
|
||||
raise NotImplementedError('stop merging only implemented ' \
|
||||
'for single instruction')
|
||||
if int(self.options.max_parallel_open):
|
||||
raise NotImplementedError('parallel limit only implemented ' \
|
||||
'for single instruction')
|
||||
|
||||
block = self.block
|
||||
options = self.options
|
||||
open_nodes = set()
|
||||
@@ -453,13 +457,10 @@ class Merger:
|
||||
next_available_depth = {}
|
||||
self.sources = []
|
||||
self.real_depths = [0] * len(block.instructions)
|
||||
round_type = {}
|
||||
|
||||
def add_edge(i, j):
|
||||
from_merge = isinstance(block.instructions[i], merge_class)
|
||||
to_merge = isinstance(block.instructions[j], merge_class)
|
||||
G.add_edge(i, j)
|
||||
is_source = G.get_attr(i, 'is_source') and G.get_attr(j, 'is_source') and not from_merge
|
||||
G.set_attr(j, 'is_source', is_source)
|
||||
for d in (self.depths, self.real_depths):
|
||||
if d[j] < d[i]:
|
||||
d[j] = d[i]
|
||||
@@ -515,7 +516,7 @@ class Merger:
|
||||
for n,instr in enumerate(block.instructions):
|
||||
outputs,inputs = instr.get_def(), instr.get_used()
|
||||
|
||||
G.add_node(n, is_source=True)
|
||||
G.add_node(n)
|
||||
|
||||
# if options.debug:
|
||||
# col = colordict[instr.__class__.__name__]
|
||||
@@ -534,13 +535,19 @@ class Merger:
|
||||
else:
|
||||
write(reg, n)
|
||||
|
||||
if isinstance(instr, merge_class):
|
||||
if isinstance(instr, merge_classes):
|
||||
open_nodes.add(n)
|
||||
last_open.append(n)
|
||||
G.add_node(n, merges=[])
|
||||
if stop_class != type(None):
|
||||
last_open.append(n)
|
||||
# the following must happen after adding the edge
|
||||
self.real_depths[n] += 1
|
||||
depth = depths[n] + 1
|
||||
while depth in round_type:
|
||||
if round_type[depth] == type(instr):
|
||||
break
|
||||
depth += 1
|
||||
round_type[depth] = type(instr)
|
||||
if int(options.max_parallel_open):
|
||||
skipped_depths = set()
|
||||
while parallel_open[depth] >= int(options.max_parallel_open):
|
||||
@@ -548,10 +555,12 @@ class Merger:
|
||||
depth = next_available_depth.get(depth, depth + 1)
|
||||
for d in skipped_depths:
|
||||
next_available_depth[d] = depth
|
||||
else:
|
||||
self.real_depths[n] = depth
|
||||
parallel_open[depth] += len(instr.args) * instr.get_size()
|
||||
depths[n] = depth
|
||||
|
||||
if isinstance(instr, stopopen_class):
|
||||
if isinstance(instr, stop_class):
|
||||
startopen = last_open.popleft()
|
||||
add_edge(startopen, n)
|
||||
G.set_attr(startopen, 'stop', n)
|
||||
@@ -609,7 +618,7 @@ class Merger:
|
||||
(n, len(block.instructions)), time.asctime()
|
||||
|
||||
if len(open_nodes) > 1000:
|
||||
print "Program has %d %s instructions" % (len(open_nodes), merge_class)
|
||||
print "Program has %d %s instructions" % (len(open_nodes), merge_classes)
|
||||
|
||||
def merge_nodes(self, i, j):
|
||||
""" Merge node j into i, removing node j """
|
||||
|
||||
295
Compiler/circuit_oram.py
Normal file
295
Compiler/circuit_oram.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
from Compiler.path_oram import *
|
||||
from Compiler.util import bit_compose
|
||||
|
||||
def first_diff(a_bits, b_bits):
|
||||
length = len(a_bits)
|
||||
level_bits = [None] * length
|
||||
not_found = 1
|
||||
for i in range(length):
|
||||
# find first position where bits differ (i.e. first 0 in 1 - a XOR b)
|
||||
t = 1 - XOR(a_bits[i], b_bits[i])
|
||||
prev_nf = not_found
|
||||
not_found *= t
|
||||
level_bits[i] = prev_nf - not_found
|
||||
return level_bits, not_found
|
||||
|
||||
def find_deeper(a, b, path, start, length, compute_level=True):
|
||||
a_bits = a.value.bit_decompose(length)
|
||||
b_bits = b.value.bit_decompose(length)
|
||||
path_bits = [type(a_bits[0])(x) for x in path.bit_decompose(length)]
|
||||
a_bits.reverse()
|
||||
b_bits.reverse()
|
||||
path_bits.reverse()
|
||||
level_bits = [0] * length
|
||||
# make sure that winner is set at start if one input is empty
|
||||
any_empty = OR(a.empty, b.empty)
|
||||
a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)]
|
||||
b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)]
|
||||
diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]]
|
||||
diff_preor = type(a.value).PreOR([any_empty] + diff)
|
||||
diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])]
|
||||
winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty)
|
||||
winner_bits = [if_else(winner, bd, ad) for ad,bd in zip(a_diff, b_diff)]
|
||||
winner_preor = type(a.value).PreOR(winner_bits)
|
||||
level_bits = [x - y for x,y in zip(winner_preor, [0] + winner_preor)]
|
||||
return [0] * start + level_bits + [1 - sum(level_bits)], winner
|
||||
|
||||
def find_deepest(paths, search_path, start, length, compute_level=True):
|
||||
if len(paths) == 1:
|
||||
return None, paths[0], 1
|
||||
l = len(paths) / 2
|
||||
_, a, a_index = find_deepest(paths[:l], search_path, start, length, False)
|
||||
_, b, b_index = find_deepest(paths[l:], search_path, start, length, False)
|
||||
level, winner = find_deeper(a, b, search_path, start, length, compute_level)
|
||||
return level, if_else(winner, b, a), if_else(winner, b_index << l, a_index)
|
||||
|
||||
def ge_unary_public(a, b):
|
||||
return sum(a[b-1:])
|
||||
|
||||
def gu_step(high, low):
|
||||
greater = high[0] * (1 - high[1])
|
||||
not_greater = high[1]
|
||||
return if_else(not_greater, 0, high[0] + low[0]), \
|
||||
if_else(greater, 0, high[1] + low[1])
|
||||
|
||||
def greater_unary(a, b):
|
||||
if len(a) == 1:
|
||||
return a[0], b[0]
|
||||
else:
|
||||
l = len(a) / 2
|
||||
return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l]))
|
||||
|
||||
def comp_step(high, low):
|
||||
prod = high[0] * high[1]
|
||||
greater = high[0] - prod
|
||||
smaller = high[1] - prod
|
||||
deferred = 1 - greater - smaller
|
||||
indicator = greater, smaller, deferred
|
||||
return sum(map(operator.mul, indicator, (1, 0, low[0]))), \
|
||||
sum(map(operator.mul, indicator, (0, 1, low[1])))
|
||||
|
||||
def comp_binary(a, b):
|
||||
if len(a) != len(b):
|
||||
raise CompilerError('Arguments must have same length: %s %s' % (str(a), str(b)))
|
||||
if len(a) == 1:
|
||||
return a[0], b[0]
|
||||
else:
|
||||
l = len(a) / 2
|
||||
return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l]))
|
||||
|
||||
def unary_to_binary(l):
|
||||
return sum(x * (i + 1) for i,x in enumerate(l)).bit_decompose(log2(len(l) + 1))
|
||||
|
||||
class CircuitORAM(PathORAM):
|
||||
def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \
|
||||
stash_size=None, bucket_size=2, init_rounds=-1):
|
||||
self.bucket_oram = TrivialORAM
|
||||
self.bucket_size = bucket_size
|
||||
self.D = log2(size)
|
||||
self.logD = log2(self.D)
|
||||
self.L = self.D + 1
|
||||
print 'create oram of size %d with depth %d and %d buckets' \
|
||||
% (size, self.D, self.n_buckets())
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(self.D)
|
||||
if entry_size is not None:
|
||||
self.value_length = len(tuplify(entry_size))
|
||||
self.entry_size = tuplify(entry_size)
|
||||
else:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
self.index_size = log2(size)
|
||||
self.size = size
|
||||
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
|
||||
index_size=self.index_size)
|
||||
self.entry_type = empty_entry.types()
|
||||
self.ram = RAM(self.bucket_size * 2**(self.D+1), self.entry_type, \
|
||||
self.get_array)
|
||||
self.buckets = self.ram
|
||||
if init_rounds != -1:
|
||||
# put memory initialization in different timer
|
||||
stop_timer()
|
||||
start_timer(1)
|
||||
self.ram.init_mem(self.empty_entry(apply_type=False))
|
||||
if init_rounds != -1:
|
||||
stop_timer(1)
|
||||
start_timer()
|
||||
self.root = RefBucket(1, self)
|
||||
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
|
||||
stash_size = 20
|
||||
vt, es = self.internal_entry_size()
|
||||
self.stash = TrivialORAM(stash_size, vt, entry_size=es, \
|
||||
index_size=self.index_size)
|
||||
self.t = MemValue(regint(0))
|
||||
self.state = MemValue(self.value_type.get_type(self.D)(0))
|
||||
self.read_path = MemValue(value_type.clear_type(0))
|
||||
# these include a read value from the stash
|
||||
self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
|
||||
for l in self.entry_size]
|
||||
self.read_empty = Array(self.D + 2, self.value_type.bit_type)
|
||||
def get_ram_index(self, path, level):
|
||||
clear_type = self.value_type.clear_type
|
||||
return ((2**(self.D) + clear_type.conv(path)) >> (self.D - (level - 1)))
|
||||
def get_bucket_ram(self, path, level):
|
||||
if level == 0:
|
||||
return self.stash.ram
|
||||
else:
|
||||
return RefRAM(self.get_ram_index(path, level), self)
|
||||
def get_bucket_oram(self, path, level):
|
||||
if level == 0:
|
||||
return self.stash
|
||||
else:
|
||||
return RefTrivialORAM(self.get_ram_index(path, level), self)
|
||||
def prepare_deepest(self, path):
|
||||
deepest = [None] * (self.D + 2)
|
||||
deepest_index = [None] * (self.D + 2)
|
||||
src = Value()
|
||||
stash_empty = self.stash.ram.is_empty()
|
||||
level, _, index = find_deepest(self.stash.ram.get_value_array(0), path, 0, self.D)
|
||||
goal = if_else(stash_empty, ValueTuple([0] * len(level)), unary_to_binary(level))
|
||||
src = if_else(stash_empty, src, Value(0))
|
||||
src_index = if_else(stash_empty, 0, index)
|
||||
buckets = [self.get_bucket_ram(path, i) for i in range(self.L + 1)]
|
||||
bucket_deepest = [(goal, src, src_index, None)]
|
||||
for i in range(1, self.L):
|
||||
l, _, index = find_deepest(buckets[i].get_value_array(0), path, i - 1, self.D)
|
||||
bucket_deepest.append((unary_to_binary(l), Value(i), index, i))
|
||||
def op(left, right, void=None):
|
||||
goal, src, src_index, _ = left
|
||||
l, secret_i, index, i = right
|
||||
high, low = comp_binary(l, goal)
|
||||
replace = high * (1 - low) * (1 - buckets[i].is_empty())
|
||||
goal = if_else(replace, bit_compose(l), \
|
||||
bit_compose(goal)).bit_decompose(len(goal))
|
||||
src = if_else(replace, secret_i, src)
|
||||
src_index = if_else(replace, index, src_index)
|
||||
return goal, src, src_index, i
|
||||
preop_bucket_deepest = self.value_type.PreOp(op, bucket_deepest)
|
||||
for i in range(1, self.L + 1):
|
||||
goal, src, src_index, _ = preop_bucket_deepest[i-1]
|
||||
high, low = comp_binary(goal, bit_decompose(i, len(goal)))
|
||||
cond = 1 - low * (1 - high)
|
||||
deepest[i] = if_else(cond, src, Value())
|
||||
deepest_index[i] = if_else(cond, src_index, 0)
|
||||
return deepest, deepest_index
|
||||
def prepare_target(self, path, deepest):
|
||||
deepest, deepest_index = deepest
|
||||
dest = Value()
|
||||
src = Value()
|
||||
src_index = 0
|
||||
target = [None] * (self.L + 1)
|
||||
target_index = [None] * (self.L + 1)
|
||||
for i in range(self.L, -1 , -1):
|
||||
i_eq_src = src.equal(i, self.logD + 1)
|
||||
target[i] = if_else(i_eq_src, dest, Value())
|
||||
target_index[i] = if_else(i_eq_src, src_index, 0)
|
||||
dest = if_else(i_eq_src, Value(), dest)
|
||||
src = if_else(i_eq_src, Value(), src)
|
||||
if i == 0:
|
||||
break
|
||||
cond = or_op(dest.empty * self.get_bucket_ram(path, i).has_empty_entry(), \
|
||||
(1 - target[i].empty)) * (1 - deepest[i].empty)
|
||||
src = if_else(cond, deepest[i], src)
|
||||
src_index = if_else(cond, deepest_index[i], src_index)
|
||||
dest = if_else(cond, Value(i), dest)
|
||||
return target, target_index
|
||||
def evict_once(self, path):
|
||||
deepest = self.prepare_deepest(path)
|
||||
target = self.prepare_target(path, deepest)
|
||||
evictor = self.evict_once_fast(path, target)
|
||||
next(evictor)
|
||||
towrite = next(evictor)
|
||||
yield
|
||||
self.add_evicted(path, towrite)
|
||||
yield
|
||||
def evict_once_fast(self, path, target):
|
||||
target, target_index = target
|
||||
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
|
||||
index_size=self.index_size)
|
||||
hold = empty_entry
|
||||
dest = Value()
|
||||
towrite = [None] * (self.L + 1)
|
||||
for i in range(self.L + 1):
|
||||
cond = (1 - hold.is_empty) * (dest.equal(i, self.logD + 1))
|
||||
towrite[i] = if_else(cond, hold, empty_entry)
|
||||
hold = if_else(cond, empty_entry, hold)
|
||||
dest = if_else(cond, Value(), dest)
|
||||
cond = 1 - target[i].empty
|
||||
bucket = self.get_bucket_oram(path, i)
|
||||
if i != self.L:
|
||||
index = target_index[i].bit_decompose(bucket.size)
|
||||
hold = if_else(cond, bucket.read_and_remove_by_public(index), hold)
|
||||
dest = if_else(cond, target[i], dest)
|
||||
if i == 1:
|
||||
yield
|
||||
yield towrite
|
||||
def add_evicted(self, path, towrite):
|
||||
# make sure to add after removing
|
||||
for i in range(1, self.L + 1):
|
||||
self.get_bucket_oram(path, i).add(towrite[i])
|
||||
def evict_rounds(self):
|
||||
get_path = lambda x: bit_compose(reversed(x.bit_decompose(self.D)))
|
||||
paths = [get_path(2 * self.t + i) for i in range(2)]
|
||||
for path in paths:
|
||||
for _ in self.evict_once(path):
|
||||
yield
|
||||
self.t.iadd(1)
|
||||
def evict(self):
|
||||
raise CompilerException('Using this function is likely an error. Use recursive_evict() instead.')
|
||||
Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-%d' % self.size)
|
||||
for i,_ in enumerate(self.evict_rounds()):
|
||||
Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-round-%d-%d' % (i, self.size))
|
||||
def recursive_evict(self):
|
||||
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-%d' % self.size)
|
||||
for i,_ in enumerate(self.recursive_evict_rounds()):
|
||||
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size))
|
||||
def recursive_evict_rounds(self):
|
||||
for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
|
||||
yield
|
||||
def bucket_indices_on_path_to(self, leaf):
|
||||
# root is at 1, different to PathORAM
|
||||
for level in range(self.D + 1):
|
||||
base = self.get_ram_index(leaf, level + 1) * self.bucket_size
|
||||
yield [base + i for i in range(self.bucket_size)]
|
||||
def output(self):
|
||||
print_ln('stash')
|
||||
self.stash.output()
|
||||
@for_range(1, 2**(self.D+1))
|
||||
def f(i):
|
||||
print_ln('node %s', self.value_type.clear_type(i))
|
||||
RefRAM(i, self).output()
|
||||
self.index.output()
|
||||
def __repr__(self):
|
||||
return repr(self.stash) + '\n' + repr(RefBucket(1, self))
|
||||
|
||||
class DebugCircuitORAM(CircuitORAM):
|
||||
""" Debugging only. Tree ORAM using index revealing the access
|
||||
pattern. """
|
||||
index_structure = LocalIndexStructure
|
||||
|
||||
threshold = 2**10
|
||||
|
||||
def OptimalCircuitORAM(size, value_type, *args, **kwargs):
|
||||
if size <= threshold:
|
||||
print size, 'below threshold', threshold
|
||||
return LinearORAM(size, value_type, *args, **kwargs)
|
||||
else:
|
||||
print size, 'above threshold', threshold
|
||||
return RecursiveCircuitORAM(size, value_type, *args, **kwargs)
|
||||
|
||||
class RecursiveCircuitIndexStructure(PackedIndexStructure):
|
||||
""" Secure index using secure tree ORAM. """
|
||||
storage = staticmethod(OptimalCircuitORAM)
|
||||
|
||||
class RecursiveCircuitORAM(CircuitORAM):
|
||||
""" Secure tree ORAM using secure index. """
|
||||
index_structure = RecursiveCircuitIndexStructure
|
||||
|
||||
class AtLeastOneRecursionPackedCircuitORAM(PackedIndexStructure):
|
||||
storage = RecursiveCircuitORAM
|
||||
|
||||
class AtLeastOneRecursionPackedCircuitORAMWithEmpty(PackedORAMWithEmpty):
|
||||
storage = RecursiveCircuitORAM
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
"""
|
||||
Functions for secure comparison of GF(p) types.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.program import Program
|
||||
from Compiler.config import *
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.oram import *
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
class CompilerError(Exception):
|
||||
"""Base class for compiler exceptions."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from math import log, floor, ceil
|
||||
from Compiler.instructions import *
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import heapq
|
||||
from Compiler.exceptions import *
|
||||
@@ -19,7 +19,7 @@ class SparseDiGraph(object):
|
||||
""" max_nodes: maximum no of nodes
|
||||
default_attributes: dict of node attributes and default values """
|
||||
if default_attributes is None:
|
||||
default_attributes = { 'merges': None, 'stop': -1, 'start': -1, 'is_source': True }
|
||||
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
|
||||
self.default_attributes = default_attributes
|
||||
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
|
||||
self.n = max_nodes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import sys
|
||||
import math
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
""" This module is for classes of actual assembly instructions.
|
||||
|
||||
@@ -335,6 +335,14 @@ class crash(base.IOInstruction):
|
||||
code = base.opcodes['CRASH']
|
||||
arg_format = []
|
||||
|
||||
class start_grind(base.IOInstruction):
|
||||
code = base.opcodes['STARTGRIND']
|
||||
arg_format = []
|
||||
|
||||
class stop_grind(base.IOInstruction):
|
||||
code = base.opcodes['STOPGRIND']
|
||||
arg_format = []
|
||||
|
||||
@base.gf2n
|
||||
class use_prep(base.Instruction):
|
||||
r""" Input usage. """
|
||||
@@ -1175,6 +1183,12 @@ class divint(base.IntegerInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['DIVINT']
|
||||
|
||||
@base.vectorize
|
||||
class bitdecint(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['BITDECINT']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
|
||||
|
||||
###
|
||||
### Clear comparison instructions
|
||||
###
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import itertools
|
||||
from random import randint
|
||||
@@ -54,6 +54,8 @@ opcodes = dict(
|
||||
JOIN_TAPE = 0x1A,
|
||||
CRASH = 0x1B,
|
||||
USE_PREP = 0x1C,
|
||||
STARTGRIND = 0x1D,
|
||||
STOPGRIND = 0x1E,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -132,6 +134,7 @@ opcodes = dict(
|
||||
EQC = 0x97,
|
||||
JMPI = 0x98,
|
||||
# Integers
|
||||
BITDECINT = 0x99,
|
||||
LDINT = 0x9A,
|
||||
ADDINT = 0x9B,
|
||||
SUBINT = 0x9C,
|
||||
@@ -535,7 +538,11 @@ class Instruction(object):
|
||||
return ""
|
||||
|
||||
def has_var_args(self):
|
||||
return False
|
||||
try:
|
||||
len(self.arg_format)
|
||||
return False
|
||||
except:
|
||||
return True
|
||||
|
||||
def is_vec(self):
|
||||
return False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat
|
||||
from Compiler.instructions import *
|
||||
|
||||
214
Compiler/oram.py
214
Compiler/oram.py
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import random
|
||||
import math
|
||||
@@ -187,9 +187,10 @@ def demux_list(x):
|
||||
res = map(operator.mul, a, b)
|
||||
return res
|
||||
|
||||
def demux_array(x):
|
||||
def demux_array(x, res=None):
|
||||
n = len(x)
|
||||
res = Array(2**n, type(x[0]))
|
||||
if res is None:
|
||||
res = Array(2**n, type(x[0]))
|
||||
if n == 1:
|
||||
res[0] = 1 - x[0]
|
||||
res[1] = x[0]
|
||||
@@ -282,10 +283,18 @@ class ValueTuple(tuple):
|
||||
class Entry(object):
|
||||
""" An (O)RAM entry with empty bit, index, and value. """
|
||||
@staticmethod
|
||||
def get_empty(value_type, value_length, apply_type=True):
|
||||
t = value_type if apply_type else lambda x: x
|
||||
bt = value_type if apply_type else lambda x: x
|
||||
return Entry(t(0), tuple(t(0) for i in range(value_length)), bt(True), t)
|
||||
def get_empty(value_type, entry_size, apply_type=True, index_size=None):
|
||||
res = {}
|
||||
for i,tt in enumerate((value_type, value_type.default_type)):
|
||||
if apply_type:
|
||||
apply = lambda length, x: value_type.get_type(length)(x)
|
||||
else:
|
||||
apply = lambda length, x: x
|
||||
res[i] = Entry(apply(index_size, 0), \
|
||||
tuple(apply(l, 0) for l in entry_size), \
|
||||
apply(1, True), value_type)
|
||||
res[0].defaults = res[1]
|
||||
return res[0]
|
||||
def __init__(self, v, x=None, empty=None, value_type=None):
|
||||
self.created_non_empty = False
|
||||
if x is None:
|
||||
@@ -296,7 +305,7 @@ class Entry(object):
|
||||
else:
|
||||
if empty is None:
|
||||
self.created_non_empty = True
|
||||
empty = value_type(False)
|
||||
empty = value_type.bit_type(False)
|
||||
self.is_empty = empty
|
||||
self.v = v
|
||||
if not isinstance(x, (tuple, list)):
|
||||
@@ -360,14 +369,15 @@ class RefRAM(object):
|
||||
crash()
|
||||
self.size = oram.bucket_size
|
||||
self.entry_type = oram.entry_type
|
||||
self.l = [t.dynamic_array(self.size, t, array.address + \
|
||||
index * oram.bucket_size) \
|
||||
for t,array in zip(self.entry_type,oram.ram.l)]
|
||||
self.l = [oram.get_array(self.size, t, array.address + \
|
||||
index * oram.bucket_size) \
|
||||
for t,array in zip(self.entry_type,oram.ram.l)]
|
||||
self.index = index
|
||||
def init_mem(self, empty_entry):
|
||||
print 'init ram'
|
||||
for a,value in zip(self.l, empty_entry.values()):
|
||||
a.assign_all(value)
|
||||
for a,value in zip(self.l, empty_entry.defaults.values()):
|
||||
# don't use threads if n_threads explicitly set to 1
|
||||
a.assign_all(value, n_threads != 1)
|
||||
def get_empty_bits(self):
|
||||
return self.l[0]
|
||||
def get_indices(self):
|
||||
@@ -401,7 +411,8 @@ class RefRAM(object):
|
||||
return tree_reduce(operator.mul, list(self.get_empty_bits()))
|
||||
def reveal(self):
|
||||
Program.prog.curr_tape.start_new_basicblock()
|
||||
res = RAM(self.size, [t.clear_type for t in self.entry_type], self.index)
|
||||
res = RAM(self.size, [t.clear_type for t in self.entry_type], \
|
||||
lambda *args: Array(*args), self.index)
|
||||
for i,a in enumerate(self.l):
|
||||
for j,x in enumerate(a):
|
||||
res.l[i][j] = x.reveal()
|
||||
@@ -422,33 +433,40 @@ class RefRAM(object):
|
||||
|
||||
class RAM(RefRAM):
|
||||
""" List of entries in memory. """
|
||||
def __init__(self, size, entry_type, index=0):
|
||||
def __init__(self, size, entry_type, get_array, index=0):
|
||||
#print_reg(cint(0), 'r in')
|
||||
self.size = size
|
||||
self.entry_type = entry_type
|
||||
self.l = [t.dynamic_array(self.size, t) for t in entry_type]
|
||||
self.l = [get_array(self.size, t) for t in entry_type]
|
||||
self.index = index
|
||||
|
||||
class AbstractORAM(object):
|
||||
""" Implements reading and writing using read_and_remove and add. """
|
||||
@staticmethod
|
||||
def get_array(size, t, *args, **kwargs):
|
||||
return t.dynamic_array(size, t, *args, **kwargs)
|
||||
def read(self, index):
|
||||
return self._read(self.value_type.hard_conv(index))
|
||||
def write(self, index, value):
|
||||
new_value = [self.value_type.hard_conv(v) \
|
||||
for v in (value if isinstance(value, (tuple, list)) \
|
||||
new_value = [self.value_type.get_type(length).hard_conv(v) \
|
||||
for length,v in zip(self.entry_size, value \
|
||||
if isinstance(value, (tuple, list)) \
|
||||
else (value,))]
|
||||
return self._write(self.value_type.hard_conv(index), *new_value)
|
||||
return self._write(self.index_type.hard_conv(index), *new_value)
|
||||
def access(self, index, new_value, write, new_empty=False):
|
||||
return self._access(self.value_type.hard_conv(index),
|
||||
self.value_type.hard_conv(write),
|
||||
self.value_type.hard_conv(new_empty),
|
||||
*[self.value_type.hard_conv(v) for v in tuplify(new_value)])
|
||||
return self._access(self.index_type.hard_conv(index),
|
||||
self.value_type.bit_type.hard_conv(write),
|
||||
self.value_type.bit_type.hard_conv(new_empty),
|
||||
*[self.value_type.get_type(length).hard_conv(v) \
|
||||
for length,v in zip(self.entry_size, \
|
||||
tuplify(new_value))])
|
||||
def read_and_maybe_remove(self, index):
|
||||
return self.read_and_remove(self.value_type.hard_conv(index)), \
|
||||
return self.read_and_remove(self.index_type.hard_conv(index)), \
|
||||
self.state.read()
|
||||
@method_block
|
||||
def _read(self, index):
|
||||
return self.access(index, (self.value_type(0),) * self.value_length, \
|
||||
return self.access(index, tuple(self.value_type.get_type(l)(0) \
|
||||
for l in self.entry_size), \
|
||||
False)
|
||||
@method_block
|
||||
def _write(self, index, *value):
|
||||
@@ -457,7 +475,7 @@ class AbstractORAM(object):
|
||||
def _access(self, index, write, new_empty, *new_value):
|
||||
Program.prog.curr_tape.\
|
||||
start_new_basicblock(name='abstract-access-remove-%d' % self.size)
|
||||
index = MemValue(self.value_type.hard_conv(index))
|
||||
index = MemValue(self.index_type.hard_conv(index))
|
||||
read_value, read_empty = self.read_and_remove(index)
|
||||
if len(read_value) != self.value_length:
|
||||
raise Exception('read_and_remove() of %s returns wrong length of ' \
|
||||
@@ -472,9 +490,10 @@ class AbstractORAM(object):
|
||||
if len(new_value) != self.value_length:
|
||||
raise Exception('wrong length of new value')
|
||||
value = tuple(MemValue(i) for i in if_else(write, new_value, read_value))
|
||||
empty = self.value_type.hard_conv(new_empty)
|
||||
empty = self.value_type.bit_type.hard_conv(new_empty)
|
||||
self.add(Entry(index, value, if_else(write, empty, read_empty), \
|
||||
value_type=self.value_type))
|
||||
value_type=self.value_type), evict=False)
|
||||
self.recursive_evict()
|
||||
return read_value, read_empty
|
||||
@method_block
|
||||
def delete(self, index, for_real=True):
|
||||
@@ -490,19 +509,25 @@ class AbstractORAM(object):
|
||||
class EmptyException(Exception):
|
||||
pass
|
||||
|
||||
class RefTrivialORAM(object):
|
||||
class EndRecursiveEviction(object):
|
||||
recursive_evict = lambda self: None
|
||||
recursive_evict_rounds = lambda self: itertools.repeat([None])
|
||||
|
||||
class RefTrivialORAM(EndRecursiveEviction):
|
||||
""" Trivial ORAM reference. """
|
||||
contiguous = False
|
||||
def empty_entry(self, apply_type=True):
|
||||
return Entry.get_empty(self.value_type, self.value_length, apply_type)
|
||||
return Entry.get_empty(self.value_type, self.entry_size, \
|
||||
apply_type, self.index_size)
|
||||
def __init__(self, index, oram):
|
||||
self.ram = RefRAM(index, oram)
|
||||
self.index_size = oram.index_size
|
||||
self.value_type, self.value_length = oram.internal_value_type()
|
||||
self.value_type, self.entry_size = oram.internal_entry_size()
|
||||
self.size = oram.bucket_size
|
||||
def init_mem(self):
|
||||
print 'init trivial oram'
|
||||
self.ram.init_mem(self.empty_entry())
|
||||
self.ram.init_mem(self.empty_entry(apply_type=False))
|
||||
def search(self, read_index):
|
||||
if use_binary_search and self.value_type == sgf2n:
|
||||
return self.binary_search(read_index)
|
||||
@@ -695,7 +720,7 @@ class RefTrivialORAM(object):
|
||||
Program.prog.security)
|
||||
return [prefix_empty[i+1] - prefix_empty[i] \
|
||||
for i in range(len(self.ram))]
|
||||
def add(self, new_entry, state=None):
|
||||
def add(self, new_entry, state=None, evict=None):
|
||||
# if self.last_index != new_entry.v:
|
||||
# raise Exception('index mismatch: %s / %s' %
|
||||
# (str(self.last_index), str(new_entry.v)))
|
||||
@@ -761,14 +786,17 @@ class TrivialORAM(RefTrivialORAM, AbstractORAM):
|
||||
entry_size=None, contiguous=True, init_rounds=-1):
|
||||
self.index_size = index_size or log2(size)
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
if entry_size is None:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
else:
|
||||
self.value_length = len(tuplify(entry_size))
|
||||
self.entry_size = tuplify(entry_size)
|
||||
self.contiguous = contiguous
|
||||
entry_type = self.empty_entry().types()
|
||||
self.size = size
|
||||
self.ram = RAM(size, entry_type)
|
||||
self.ram = RAM(size, entry_type, self.get_array)
|
||||
if init_rounds != -1:
|
||||
# put memory initialization in different timer
|
||||
stop_timer()
|
||||
@@ -789,9 +817,16 @@ def get_n_threads(n_loops):
|
||||
|
||||
class LinearORAM(TrivialORAM):
|
||||
""" Contiguous ORAM that stores entries in order. """
|
||||
@staticmethod
|
||||
def get_array(size, t, *args, **kwargs):
|
||||
return Array(size, t, *args, **kwargs)
|
||||
def __init__(self, *args, **kwargs):
|
||||
TrivialORAM.__init__(self, *args, **kwargs)
|
||||
self.index_vector = self.get_array(2 ** self.index_size, \
|
||||
self.index_type.bit_type)
|
||||
def read_and_maybe_remove(self, index):
|
||||
return self.read(index), 0
|
||||
def add(self, entry, state=None):
|
||||
def add(self, entry, state=None, evict=None):
|
||||
if entry.created_non_empty is True:
|
||||
self.write(entry.v, entry.x)
|
||||
else:
|
||||
@@ -802,14 +837,14 @@ class LinearORAM(TrivialORAM):
|
||||
def _read(self, index):
|
||||
maybe_start_timer(6)
|
||||
empty_entry = self.empty_entry(False)
|
||||
index_vector = \
|
||||
demux_array(bit_decompose(index, self.index_size))
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
self.value_length + 1, [self.value_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type.get_type(l) for l in self.entry_size])
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
access_here = index_vector[i]
|
||||
access_here = self.index_vector[i]
|
||||
return access_here * ValueTuple((entry.empty(),) + entry.x)
|
||||
not_found = f()[0]
|
||||
read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x
|
||||
@@ -819,13 +854,13 @@ class LinearORAM(TrivialORAM):
|
||||
def _write(self, index, *new_value):
|
||||
maybe_start_timer(7)
|
||||
empty_entry = self.empty_entry(False)
|
||||
index_vector = \
|
||||
demux_array(bit_decompose(index, self.index_size))
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
new_value = make_array(new_value)
|
||||
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
access_here = index_vector[i]
|
||||
access_here = self.index_vector[i]
|
||||
nv = ValueTuple(new_value)
|
||||
delta_entry = \
|
||||
Entry(0, access_here * (nv - entry.x), \
|
||||
@@ -841,7 +876,7 @@ class LinearORAM(TrivialORAM):
|
||||
new_empty = MemValue(new_empty)
|
||||
write = MemValue(write)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
self.value_length + 1, [self.value_type] + \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
@@ -892,16 +927,24 @@ class RefBucket(object):
|
||||
child.output()
|
||||
|
||||
def random_block(length, value_type):
|
||||
return sum(value_type.get_random_bit() << i for i in range(length))
|
||||
return sum(value_type.bit_type.get_random_bit() << i for i in range(length))
|
||||
|
||||
class List(object):
|
||||
class List(EndRecursiveEviction):
|
||||
""" Debugging only. List which accepts secret values as indices
|
||||
and *reveals* them. """
|
||||
def __init__(self, size, value_type, value_length=1, init_rounds=None):
|
||||
def __init__(self, size, value_type, value_length=1, \
|
||||
init_rounds=None, entry_size=None):
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(log2(size))
|
||||
self.value_length = value_length
|
||||
self.l = [value_type.dynamic_array(size, value_type) \
|
||||
if entry_size is None:
|
||||
self.l = [value_type.dynamic_array(size, value_type) \
|
||||
for i in range(value_length)]
|
||||
else:
|
||||
self.l = [value_type.dynamic_array(size, \
|
||||
value_type.get_type(length)) \
|
||||
for length in entry_size]
|
||||
self.value_length = len(entry_size)
|
||||
for l in self.l:
|
||||
l.assign_all(0)
|
||||
__getitem__ = lambda self,index: [self.l[i][regint(reveal(index))] \
|
||||
@@ -916,8 +959,9 @@ class List(object):
|
||||
read_and_remove = lambda self,i: (self[i], None)
|
||||
def read_and_maybe_remove(self, *args, **kwargs):
|
||||
return self.read_and_remove(*args, **kwargs), 0
|
||||
add = lambda self,entry: self.__setitem__(entry.v.read(), \
|
||||
[v.read() for v in entry.x])
|
||||
add = lambda self,entry,**kwargs: self.__setitem__(entry.v.read(), \
|
||||
[v.read() for v in entry.x])
|
||||
recursive_evict = lambda *args,**kwargs: None
|
||||
def batch_init(self, values):
|
||||
for i,value in enumerate(values):
|
||||
index = self.value_type.hard_conv(i)
|
||||
@@ -939,7 +983,7 @@ class LocalIndexStructure(List):
|
||||
def f(i):
|
||||
self.l[0][i] = random_block(entry_size, value_type)
|
||||
print 'index size:', size
|
||||
def update(self, index, value):
|
||||
def update(self, index, value, evict=None):
|
||||
read_value = self[index]
|
||||
#print 'read', index, read_value
|
||||
#print self.l
|
||||
@@ -977,13 +1021,18 @@ class TreeORAM(AbstractORAM):
|
||||
self.value_type = value_type
|
||||
if entry_size is not None:
|
||||
self.value_length = len(tuplify(entry_size))
|
||||
self.entry_size = tuplify(entry_size)
|
||||
else:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
self.index_size = log2(size)
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.size = size
|
||||
empty_entry = Entry.get_empty(*self.internal_value_type())
|
||||
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
|
||||
index_size=self.D)
|
||||
self.entry_type = empty_entry.types()
|
||||
self.ram = RAM(self.n_buckets() * self.bucket_size, self.entry_type)
|
||||
self.ram = RAM(self.n_buckets() * self.bucket_size, self.entry_type, \
|
||||
self.get_array)
|
||||
if init_rounds != -1:
|
||||
# put memory initialization in different timer
|
||||
stop_timer()
|
||||
@@ -996,7 +1045,7 @@ class TreeORAM(AbstractORAM):
|
||||
self.index = self.index_structure(size, self.D, value_type, init_rounds, True)
|
||||
|
||||
self.read_value = Array(self.value_length, value_type)
|
||||
self.read_non_empty = MemValue(self.value_type(0))
|
||||
self.read_non_empty = MemValue(self.value_type.bit_type(0))
|
||||
self.state = MemValue(self.value_type(0))
|
||||
@method_block
|
||||
def add_to_root(self, state, is_empty, v, *x):
|
||||
@@ -1011,7 +1060,7 @@ class TreeORAM(AbstractORAM):
|
||||
entry = bucket.bucket.pop()
|
||||
#print 'evict', entry
|
||||
#print 'from', bucket
|
||||
b = if_else(entry.empty(), self.value_type.get_random_bit(), \
|
||||
b = if_else(entry.empty(), self.value_type.bit_type.get_random_bit(), \
|
||||
get_bit(entry.x[0], self.D - 1 - d, self.D))
|
||||
block = cond_swap(b, entry, self.root.bucket.empty_entry())
|
||||
#print 'empty', entry.empty()
|
||||
@@ -1042,7 +1091,7 @@ class TreeORAM(AbstractORAM):
|
||||
new_path = regint.get_random(self.D)
|
||||
l_star = self.value_type(new_path)
|
||||
self.state.write(l_star)
|
||||
return self.index.update(u, l_star).reveal()
|
||||
return self.index.update(u, l_star, evict=False).reveal()
|
||||
@method_block
|
||||
def read_and_remove_levels(self, u, read_path):
|
||||
u = MemValue(u)
|
||||
@@ -1050,7 +1099,7 @@ class TreeORAM(AbstractORAM):
|
||||
levels = self.D + 1
|
||||
parallel = get_parallel(self.index_size, *self.internal_value_type())
|
||||
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
|
||||
self.value_length + 1, [self.value_type] + \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
def process(level):
|
||||
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
|
||||
@@ -1074,6 +1123,8 @@ class TreeORAM(AbstractORAM):
|
||||
crash()
|
||||
def internal_value_type(self):
|
||||
return self.value_type, self.value_length + 1
|
||||
def internal_entry_size(self):
|
||||
return self.value_type, [self.D] + list(self.entry_size)
|
||||
def n_buckets(self):
|
||||
return 2**(self.D+1)
|
||||
@method_block
|
||||
@@ -1097,7 +1148,7 @@ class TreeORAM(AbstractORAM):
|
||||
Program.prog.curr_tape.\
|
||||
start_new_basicblock(name='read_and_remove-%d-end' % self.size)
|
||||
return [MemValue(v) for v in read_value], MemValue(read_empty)
|
||||
def add(self, entry, state=None):
|
||||
def add(self, entry, state=None, evict=None):
|
||||
if state is None:
|
||||
state = self.state.read()
|
||||
#print_reg(cint(0), 'add')
|
||||
@@ -1141,6 +1192,9 @@ class TreeORAM(AbstractORAM):
|
||||
#print 'd, 2^d', d, 1 << d
|
||||
self.evict2(s1 + (1 << d), s2 + (1 << d), d)
|
||||
self.check()
|
||||
def recursive_evict(self):
|
||||
self.evict()
|
||||
self.index.recursive_evict()
|
||||
|
||||
def batch_init(self, values):
|
||||
""" Batch initalization. Obliviously shuffles and adds N entries to
|
||||
@@ -1349,11 +1403,11 @@ class PackedIndexStructure(object):
|
||||
self.entry_size = tuplify(entry_size)
|
||||
self.value_type = value_type
|
||||
for demux_bits in range(max_demux_bits + 1):
|
||||
self.log_entries_per_element = min(size, \
|
||||
self.log_entries_per_element = min(log2(size), \
|
||||
int(math.floor(math.log(float(get_value_size(value_type)) / \
|
||||
sum(self.entry_size), 2))))
|
||||
self.log_elements_per_block = \
|
||||
max(0, min(demux_bits, log2(size * sum(self.entry_size)) - \
|
||||
max(0, min(demux_bits, log2(size) - \
|
||||
self.log_entries_per_element))
|
||||
if self.log_entries_per_element < 0:
|
||||
self.entries_per_block = 1
|
||||
@@ -1388,14 +1442,17 @@ class PackedIndexStructure(object):
|
||||
print 'log(elements per block):', self.log_elements_per_block
|
||||
print 'elements per block:', self.elements_per_block
|
||||
print 'used bits:', self.used_bits
|
||||
entry_size = [self.used_bits] * self.elements_per_block
|
||||
if real_size > 1:
|
||||
# no need to init underlying ORAM, will be initialized implicitely
|
||||
self.l = self.storage(real_size, value_type, self.elements_per_block, \
|
||||
init_rounds=0)
|
||||
self.l = self.storage(real_size, value_type, \
|
||||
entry_size=entry_size, init_rounds=0)
|
||||
self.small = False
|
||||
else:
|
||||
self.l = List(1, value_type, self.elements_per_block)
|
||||
self.l = List(1, value_type, self.elements_per_block, \
|
||||
entry_size=entry_size)
|
||||
self.small = True
|
||||
self.index_type = self.l.index_type
|
||||
if init_rounds:
|
||||
if init_rounds > 0:
|
||||
real_init_rounds = init_rounds * real_size / size
|
||||
@@ -1484,20 +1541,20 @@ class PackedIndexStructure(object):
|
||||
return self.MultiSlicer(self, index)
|
||||
else:
|
||||
return self.Slicer(self, index)
|
||||
def update(self, index, value):
|
||||
def update(self, index, value, evict=True):
|
||||
""" Updating index return current value. Has to be done in one
|
||||
step to avoid exponential blow-up in ORAM recursion. """
|
||||
return self.access(index, value, True)
|
||||
def access(self, index, value, write):
|
||||
return self.access(index, value, True, evict=evict)
|
||||
def access(self, index, value, write, evict=True):
|
||||
slicer = self.get_slicer(index)
|
||||
block = self.l.read_and_maybe_remove(slicer.a)[0][0]
|
||||
read_value = slicer.read(block)
|
||||
value = if_else(write, ValueTuple(tuplify(value)), \
|
||||
ValueTuple(read_value))
|
||||
self.l.add(Entry(MemValue(self.value_type(slicer.a)), \
|
||||
self.l.add(Entry(MemValue(self.l.index_type(slicer.a)), \
|
||||
ValueTuple(MemValue(v) \
|
||||
for v in slicer.write(value)), \
|
||||
value_type=self.value_type))
|
||||
value_type=self.value_type), evict=evict)
|
||||
return untuplify(read_value)
|
||||
def __getitem__(self, index):
|
||||
slicer = self.get_slicer(index)
|
||||
@@ -1507,7 +1564,9 @@ class PackedIndexStructure(object):
|
||||
# no need for reading first
|
||||
self.l[index] = self.get_slicer(index).write(value)
|
||||
else:
|
||||
self.access(index, value, True)
|
||||
self.access(index, value, True, False)
|
||||
self.l.recursive_evict()
|
||||
recursive_evict = lambda self: self.l.recursive_evict()
|
||||
|
||||
def batch_init(self, values):
|
||||
""" Initialize m values with indices 0, ..., m-1 """
|
||||
@@ -1551,8 +1610,8 @@ class PackedORAMWithEmpty(AbstractORAM, PackedIndexStructure):
|
||||
return res[1:], 1 - res[0]
|
||||
def read_and_maybe_remove(self, index):
|
||||
return self.read(index), 0
|
||||
def add(self, entry, state=None):
|
||||
self.access(entry.v, entry.x, True, entry.empty())
|
||||
def add(self, entry, state=None, evict=True):
|
||||
self.access(entry.v, entry.x, True, entry.empty(), evict=evict)
|
||||
|
||||
class LocalPackedIndexStructure(PackedIndexStructure):
|
||||
""" Debugging only. Packed tree ORAM index revealing the access
|
||||
@@ -1621,30 +1680,39 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
|
||||
storage = OptimalORAM
|
||||
|
||||
def test_oram(oram_type, N, value_type=sint, iterations=100):
|
||||
stop_grind()
|
||||
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
|
||||
value_type = value_type.get_type(32)
|
||||
index_type = value_type.get_type(log2(N))
|
||||
start_grind()
|
||||
print 'initialized'
|
||||
print_ln('initialized')
|
||||
stop_timer()
|
||||
# synchronize
|
||||
start_timer(2)
|
||||
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
||||
value_type(0).reveal()
|
||||
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
||||
stop_timer(2)
|
||||
start_timer()
|
||||
#oram[value_type(0)] = -1
|
||||
#iterations = N
|
||||
@for_range(iterations)
|
||||
def f(i):
|
||||
oram[value_type(i % N)] = value_type(i % N)
|
||||
time()
|
||||
oram[index_type(i % N)] = value_type(i % N)
|
||||
#value, empty = oram.read_and_remove(value_type(i))
|
||||
#print 'first write'
|
||||
time()
|
||||
oram[value_type(i % N)].reveal().print_reg('writ')
|
||||
oram[index_type(i % N)].reveal().print_reg('writ')
|
||||
#print 'first read'
|
||||
@for_range(iterations)
|
||||
def f(i):
|
||||
x = oram[value_type(i % N)]
|
||||
time()
|
||||
x = oram[index_type(i % N)]
|
||||
x.reveal().print_reg('read')
|
||||
# print 'second read'
|
||||
print_ln('%s accesses', 3 * iterations)
|
||||
return oram
|
||||
|
||||
def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
if '_Array' not in dir():
|
||||
from oram import *
|
||||
@@ -147,13 +147,17 @@ class PathORAM(TreeORAM):
|
||||
self.value_type = value_type
|
||||
if entry_size is not None:
|
||||
self.value_length = len(tuplify(entry_size))
|
||||
self.entry_size = tuplify(entry_size)
|
||||
else:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
self.index_size = log2(size)
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.size = size
|
||||
self.entry_type = Entry.get_empty(*self.internal_value_type()).types()
|
||||
self.entry_type = Entry.get_empty(*self.internal_entry_size()).types()
|
||||
|
||||
self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type)
|
||||
self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type,
|
||||
self.get_array)
|
||||
if init_rounds != -1:
|
||||
# put memory initialization in different timer
|
||||
stop_timer()
|
||||
@@ -197,15 +201,15 @@ class PathORAM(TreeORAM):
|
||||
|
||||
# temp storage for the path + stash in eviction
|
||||
self.temp_size = stash_size + self.bucket_size*(self.D+1)
|
||||
self.temp_storage = RAM(self.temp_size, self.entry_type)
|
||||
self.temp_storage = RAM(self.temp_size, self.entry_type, self.get_array)
|
||||
self.temp_levels = [0] * self.temp_size # Array(self.temp_size, 'c')
|
||||
for i in range(self.temp_size):
|
||||
self.temp_levels[i] = 0
|
||||
|
||||
# these include a read value from the stash
|
||||
self.read_value = [Array(self.D + 2, self.value_type)
|
||||
for i in range(self.value_length)]
|
||||
self.read_empty = Array(self.D + 2, self.value_type)
|
||||
self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
|
||||
for l in self.entry_size]
|
||||
self.read_empty = Array(self.D + 2, self.value_type.bit_type)
|
||||
|
||||
self.state = MemValue(self.value_type(0))
|
||||
self.eviction_count = MemValue(cint(0))
|
||||
@@ -242,7 +246,8 @@ class PathORAM(TreeORAM):
|
||||
for j, ram_index in enumerate(ram_indices):
|
||||
self.temp_storage[i*self.bucket_size + j] = self.buckets[ram_index]
|
||||
self.temp_levels[i*self.bucket_size + j] = i
|
||||
self.buckets[ram_index] = Entry.get_empty(self.value_type, 1)
|
||||
ies = self.internal_entry_size()
|
||||
self.buckets[ram_index] = Entry.get_empty(*ies)
|
||||
|
||||
# load the stash
|
||||
for i in range(len(self.stash.ram)):
|
||||
@@ -253,7 +258,7 @@ class PathORAM(TreeORAM):
|
||||
entry = self.stash.ram[i]
|
||||
self.temp_storage[i + self.bucket_size*(self.D+1)] = entry
|
||||
|
||||
te = Entry.get_empty(self.value_type, 1)
|
||||
te = Entry.get_empty(*self.internal_entry_size())
|
||||
self.stash.ram[i] = te
|
||||
|
||||
self.path_regs = [None] * self.bucket_size*(self.D+1)
|
||||
@@ -268,11 +273,11 @@ class PathORAM(TreeORAM):
|
||||
#self.sizes = [Counter(0, max_val=4) for i in range(self.D + 1)]
|
||||
if self.use_shuffle_evict:
|
||||
if self.bucket_size == 4:
|
||||
self.size_bits = [[self.value_type(i) for i in (0, 0, 0, 1)] for j in range(self.D+1)]
|
||||
self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0, 0, 1)] for j in range(self.D+1)]
|
||||
elif self.bucket_size == 2 or self.bucket_size == 3:
|
||||
self.size_bits = [[self.value_type(i) for i in (0, 0)] for j in range(self.D+1)]
|
||||
self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0)] for j in range(self.D+1)]
|
||||
else:
|
||||
self.size_bits = [[self.value_type(0) for i in range(self.bucket_size)] for j in range(self.D+1)]
|
||||
self.size_bits = [[self.value_type.bit_type(0) for i in range(self.bucket_size)] for j in range(self.D+1)]
|
||||
self.stash_size = Counter(0, max_val=len(self.stash.ram))
|
||||
|
||||
leaf = self.state.read().reveal()
|
||||
@@ -308,7 +313,7 @@ class PathORAM(TreeORAM):
|
||||
|
||||
empty_entry = self.empty_entry(False)
|
||||
skip = 1
|
||||
found = Array(self.bucket_size, self.value_type)
|
||||
found = Array(self.bucket_size, self.value_type.bit_type)
|
||||
entries = [self.buckets[j] for j in ram_indices]
|
||||
indices = [e.v for e in entries]
|
||||
empty_bits = [e.empty() for e in entries]
|
||||
@@ -341,8 +346,8 @@ class PathORAM(TreeORAM):
|
||||
self.read_empty[self.D+1] = empty
|
||||
|
||||
def empty_entry(self, apply_type=True):
|
||||
vtype, vlength = self.internal_value_type()
|
||||
return Entry.get_empty(vtype, vlength, apply_type)
|
||||
vtype, entry_size = self.internal_entry_size()
|
||||
return Entry.get_empty(vtype, entry_size, apply_type, self.index_size)
|
||||
|
||||
def shuffle_evict(self, leaf):
|
||||
""" Evict using oblivious shuffling etc """
|
||||
@@ -409,25 +414,25 @@ class PathORAM(TreeORAM):
|
||||
if self.bucket_size == 4:
|
||||
c = s[0]*s[1]
|
||||
if self.value_type == sgf2n:
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] + s[2] + c), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1] + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type(c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][3] = [1 - self.value_type(s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] + s[2] + c), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
|
||||
else:
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] - c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1] + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type(c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][3] = [1 - self.value_type(s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
|
||||
elif self.bucket_size == 2:
|
||||
if evict_debug:
|
||||
print_str('%s,%s,', s[0].reveal(), s[1].reveal())
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
|
||||
elif self.bucket_size == 3:
|
||||
c = s[0]*s[1]
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type(s[0] + s[1] - c), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type(s[1]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type(c), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
|
||||
empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c), self.value_type.clear_type(j)]
|
||||
|
||||
if evict_debug:
|
||||
print_ln()
|
||||
@@ -471,7 +476,7 @@ class PathORAM(TreeORAM):
|
||||
merged_entries = [e for e in merged_entries if e is not None]
|
||||
|
||||
# need to copy entries/levels to memory for re-positioning
|
||||
entries_ram = RAM(self.temp_size, self.entry_type)
|
||||
entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)
|
||||
levels_array = Array(self.temp_size, cint)
|
||||
|
||||
for i,entrylev in enumerate(merged_entries):
|
||||
@@ -570,10 +575,10 @@ class PathORAM(TreeORAM):
|
||||
|
||||
def adjust_lca(self, lca_bits, lev, not_empty, prnt=False):
|
||||
""" Adjust LCA based on bucket capacities (and original clear level, lev) """
|
||||
found = self.value_type(0)
|
||||
assigned = self.value_type(0)
|
||||
try_add_here = self.value_type(0)
|
||||
new_lca = [self.value_type(0)] * (self.D + 1)
|
||||
found = self.value_type.bit_type(0)
|
||||
assigned = self.value_type.bit_type(0)
|
||||
try_add_here = self.value_type.bit_type(0)
|
||||
new_lca = [self.value_type.bit_type(0)] * (self.D + 1)
|
||||
|
||||
upper = min(lev + self.sigma, self.D)
|
||||
lower = max(lev - self.tau, 0)
|
||||
@@ -639,7 +644,7 @@ class PathORAM(TreeORAM):
|
||||
a_bits = bit_decompose(a, self.D)
|
||||
b_bits = bit_decompose(b, self.D)
|
||||
found = [None] * self.D
|
||||
not_found = self.value_type(not_empty) #1
|
||||
not_found = self.value_type.bit_type(not_empty) #1
|
||||
if limit is None:
|
||||
limit = self.D
|
||||
|
||||
@@ -722,13 +727,13 @@ class PathORAM(TreeORAM):
|
||||
|
||||
return levstar, a
|
||||
|
||||
def add(self, entry, state=None):
|
||||
def add(self, entry, state=None, evict=True):
|
||||
if state is None:
|
||||
state = self.state.read()
|
||||
l = state
|
||||
x = tuple((self.value_type(i.read())) for i in entry.x)
|
||||
x = tuple(i.read() for i in entry.x)
|
||||
|
||||
e = Entry(self.value_type(entry.v.read()), (l,) + x, entry.empty())
|
||||
e = Entry(entry.v.read(), (l,) + x, entry.empty())
|
||||
|
||||
#self.temp_storage[self.temp_size-1] = e * 1
|
||||
#self.temp_levels[self.temp_size-1] = 0
|
||||
@@ -738,7 +743,8 @@ class PathORAM(TreeORAM):
|
||||
except Exception:
|
||||
print self
|
||||
raise
|
||||
self.evict()
|
||||
if evict:
|
||||
self.evict()
|
||||
|
||||
class LocalPathORAM(PathORAM):
|
||||
""" Debugging only. Path ORAM using index revealing the access
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from random import randint
|
||||
import math
|
||||
@@ -361,7 +361,7 @@ def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
|
||||
if config is None:
|
||||
config = configure_waksman(random_perm(n))
|
||||
for i,c in enumerate(config):
|
||||
config[i] = [value_type(b) for b in c]
|
||||
config[i] = [value_type.bit_type(b) for b in c]
|
||||
waksman(x, config, reverse=reverse)
|
||||
waksman(x, config, reverse=reverse)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
@@ -66,6 +66,8 @@ class Program(object):
|
||||
self.free_threads = set()
|
||||
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
|
||||
self.types = {}
|
||||
self.to_merge = [Compiler.instructions.startopen_class]
|
||||
self.stop_class = Compiler.instructions.stopopen_class
|
||||
Program.prog = self
|
||||
|
||||
self.reset_values()
|
||||
@@ -125,6 +127,7 @@ class Program(object):
|
||||
self.name = progname
|
||||
if len(args) > 1:
|
||||
self.name += '-' + '-'.join(args[1:])
|
||||
self.progname = progname
|
||||
|
||||
def new_tape(self, function, args=[], name=None):
|
||||
if name is None:
|
||||
@@ -534,7 +537,9 @@ class Tape:
|
||||
(block.name, i, len(self.basicblocks), \
|
||||
len(block.instructions))
|
||||
# the next call is necessary for allocation later even without merging
|
||||
merger = al.Merger(block, options)
|
||||
merger = al.Merger(block, options, \
|
||||
tuple(self.program.to_merge), \
|
||||
self.program.stop_class)
|
||||
if options.dead_code_elimination:
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Eliminate dead code...'
|
||||
@@ -545,8 +550,8 @@ class Tape:
|
||||
block.defined_registers = set()
|
||||
continue
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Merging open instructions...'
|
||||
numrounds = merger.longest_paths_merge()
|
||||
print 'Merging instructions...'
|
||||
numrounds = merger.longest_paths_merge(self.program.stop_class != type(None))
|
||||
if numrounds > 0:
|
||||
print 'Program requires %d rounds of communication' % numrounds
|
||||
numinv = sum(len(i.args) for i in block.instructions if isinstance(i, Compiler.instructions.startopen_class))
|
||||
@@ -555,7 +560,9 @@ class Tape:
|
||||
if options.dead_code_elimination:
|
||||
block.instructions = filter(lambda x: x is not None, block.instructions)
|
||||
if not (options.merge_opens and self.merge_opens):
|
||||
print 'Not merging open instructions in tape %s' % self.name
|
||||
print 'Not merging instructions in tape %s' % self.name
|
||||
else:
|
||||
print 'Rounds determined by', self.program.to_merge
|
||||
|
||||
# add jumps
|
||||
offset = 0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import itertools
|
||||
|
||||
@@ -7,3 +7,9 @@ class chain(object):
|
||||
self.args = args
|
||||
def __iter__(self):
|
||||
return itertools.chain(*self.args)
|
||||
|
||||
class cycle(object):
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
def __iter__(self):
|
||||
return itertools.cycle(*self.args)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
from Compiler.program import Tape
|
||||
from Compiler.exceptions import *
|
||||
@@ -791,13 +791,8 @@ class regint(_register, _int):
|
||||
return cint(self).mod2m(*args, **kwargs)
|
||||
|
||||
def bit_decompose(self, bit_length=None):
|
||||
res = []
|
||||
x = self
|
||||
two = regint(2)
|
||||
for i in range(bit_length or program.bit_length):
|
||||
y = x / two
|
||||
res.append(x - two * y)
|
||||
x = y
|
||||
res = [regint() for i in range(bit_length or program.bit_length)]
|
||||
bitdecint(self, *res)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
@@ -819,6 +814,9 @@ class regint(_register, _int):
|
||||
class _secret(_register):
|
||||
__slots__ = []
|
||||
|
||||
PreOR = staticmethod(lambda l: floatingpoint.PreORC(l))
|
||||
PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l))
|
||||
|
||||
@vectorized_classmethod
|
||||
@set_instruction_type
|
||||
def protect_memory(cls, start, end):
|
||||
@@ -971,6 +969,9 @@ class sint(_secret, _int):
|
||||
clear_type = cint
|
||||
reg_type = 's'
|
||||
|
||||
PreOp = staticmethod(floatingpoint.PreOpL)
|
||||
PreOR = staticmethod(floatingpoint.PreOR)
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_random_int(cls, bits):
|
||||
res = sint()
|
||||
@@ -1164,6 +1165,10 @@ class sgf2n(_secret, _gf2n):
|
||||
clear_type = cgf2n
|
||||
reg_type = 'sg'
|
||||
|
||||
@classmethod
|
||||
def get_type(cls, length):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
@@ -1265,8 +1270,10 @@ class sgf2n(_secret, _gf2n):
|
||||
masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal()
|
||||
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
|
||||
|
||||
sint.basic_type = sint
|
||||
sgf2n.basic_type = sgf2n
|
||||
for t in (sint, sgf2n):
|
||||
t.bit_type = t
|
||||
t.basic_type = t
|
||||
t.default_type = t
|
||||
|
||||
|
||||
class sgf2nint(sgf2n):
|
||||
@@ -2305,13 +2312,13 @@ class Array(object):
|
||||
self[i] = value[source_index]
|
||||
source_index.iadd(1)
|
||||
return
|
||||
self._store(self.value_type.conv(value), self.get_address(index))
|
||||
self._store(value, self.get_address(index))
|
||||
|
||||
def _load(self, address):
|
||||
return self.value_type.load_mem(address)
|
||||
|
||||
def _store(self, value, address):
|
||||
value.store_in_mem(address)
|
||||
self.value_type.conv(value).store_in_mem(address)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
@@ -2335,10 +2342,10 @@ class Array(object):
|
||||
self[i] = j
|
||||
return self
|
||||
|
||||
def assign_all(self, value):
|
||||
mem_value = self.value_type.MemValue(value)
|
||||
n_loops = 8 if len(self) > 2**20 else 1
|
||||
@library.for_range_multithread(n_loops, 1024, len(self))
|
||||
def assign_all(self, value, use_threads=True):
|
||||
mem_value = MemValue(value)
|
||||
n_threads = 8 if use_threads and len(self) > 2**20 else 1
|
||||
@library.for_range_multithread(n_threads, 1024, len(self))
|
||||
def f(i):
|
||||
self[i] = mem_value
|
||||
return self
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# (C) 2018 University of Bristol. See License.txt
|
||||
# (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _Exceptions
|
||||
#define _Exceptions
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
(C) 2018 University of Bristol. See License.txt.
|
||||
(C) 2018 University of Bristol, Bar-Ilan University. See License.txt.
|
||||
|
||||
The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* (C) 2018 University of Bristol. See License.txt
|
||||
* (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
*
|
||||
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
|
||||
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* (C) 2018 University of Bristol. See License.txt
|
||||
* (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
*
|
||||
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
|
||||
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* AddableVector.h
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#include "Ciphertext.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _Ciphertext
|
||||
#define _Ciphertext
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "DiscreteGauss.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _DiscreteGauss
|
||||
#define _DiscreteGauss
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "FHE/FFT.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _FFT
|
||||
#define _FFT
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/FFT.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _FFT_Data
|
||||
#define _FFT_Data
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "FHE_Keys.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _FHE_Keys
|
||||
#define _FHE_Keys
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "FHE_Params.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _FHE_Params
|
||||
#define _FHE_Params
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* Generator.h
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "FHE/Matrix.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#ifndef _myHNF
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
|
||||
#include "FHE/NTL-Subs.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
#ifndef _NTL_Subs
|
||||
#define _NTL_Subs
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// (C) 2018 University of Bristol. See License.txt
|
||||
// (C) 2018 University of Bristol, Bar-Ilan University. See License.txt
|
||||
|
||||
/*
|
||||
* NoiseBound.cpp
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user