Initial release.

This commit is contained in:
Marcel Keller
2016-09-02 19:16:51 +01:00
commit 81e35b3549
149 changed files with 30008 additions and 0 deletions

72
.gitignore vendored Normal file
View File

@@ -0,0 +1,72 @@
# Offline data, runtime logs #
##############################
Player-Data/*
Prep-Data/*
logs/*
Language-Definition/main.pdf
# Personal CONFIG file #
##############################
CONFIG.mine
# Compiled source #
###################
Programs/Bytecode/*
Programs/Schedules/*
Programs/Public-Input/*
*.com
*.class
*.dll
*.exe
*.x
*.o
*.so
*.pyc
*.bc
*.sch
*.a
# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip
# Latex #
#########
*.aux
*.lof
*.log
*.lot
*.fls
*.out
*.toc
*.fmt
*.bbl
*.bcf
*.blg
# Logs and databases #
######################
*.log
*.sql
*.sqlite
# OS generated files #
######################
*~
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "SimpleOT"]
path = SimpleOT
url = git@github.com:pascholl/SimpleOT.git

509
Auth/MAC_Check.cpp Normal file
View File

@@ -0,0 +1,509 @@
// (C) 2016 University of Bristol. See License.txt
#include "Auth/MAC_Check.h"
#include "Auth/Subroutines.h"
#include "Exceptions/Exceptions.h"
#include "Tools/random.h"
#include "Tools/time-func.h"
#include "Tools/int.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include <algorithm>
enum mc_timer { SEND, RECV_ADD, BCAST, RECV_SUM, SEED, COMMIT, WAIT_SUMMER, RECV, SUM, SELECT, MAX_TIMER };
const char* mc_timer_names[] = {
"sending",
"receiving and adding",
"broadcasting",
"receiving summed values",
"random seed",
"commit and open",
"wait for summer thread",
"receiving",
"summing",
"waiting for select()"
};
template<class T>
MAC_Check<T>::MAC_Check(const T& ai, int opening_sum, int max_broadcast, int send_player) :
base_player(send_player), opening_sum(opening_sum), max_broadcast(max_broadcast)
{
popen_cnt=0;
alphai=ai;
values_opened=0;
timers.resize(MAX_TIMER);
}
template<class T>
MAC_Check<T>::~MAC_Check()
{
for (unsigned int i = 0; i < timers.size(); i++)
if (timers[i].elapsed() > 0)
cerr << T::type_string() << " " << mc_timer_names[i] << ": "
<< timers[i].elapsed() << endl;
for (unsigned int i = 0; i < player_timers.size(); i++)
if (player_timers[i].elapsed() > 0)
cerr << T::type_string() << " waiting for " << i << ": "
<< player_timers[i].elapsed() << endl;
}
template<class T, int t>
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check<T>& MC)
{
MC.player_timers.resize(P.num_players());
vector<octetStream>& oss = MC.oss;
oss.resize(P.num_players());
vector<int> senders;
senders.reserve(P.num_players());
for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players;
relative_sender < last_sum_players; relative_sender += sum_players)
{
int sender = positive_modulo(send_player + relative_sender, P.num_players());
senders.push_back(sender);
}
for (int j = 0; j < (int)senders.size(); j++)
P.request_receive(senders[j], oss[j]);
for (int j = 0; j < (int)senders.size(); j++)
{
int sender = senders[j];
MC.player_timers[sender].start();
P.wait_receive(sender, oss[j], true);
MC.player_timers[sender].stop();
if ((unsigned)oss[j].get_length() < values.size() * T::size())
{
stringstream ss;
ss << "Not enough information received, expected "
<< values.size() * T::size() << " bytes, got "
<< oss[j].get_length();
throw Processor_Error(ss.str());
}
MC.timers[SUM].start();
for (unsigned int i=0; i<values.size(); i++)
{
values[i].template add<t>(oss[j].consume(T::size()));
}
MC.timers[SUM].stop();
}
}
template<class T>
void MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
AddToMacs(S);
for (unsigned int i=0; i<S.size(); i++)
{ values[i]=S[i].get_share(); }
os.reset_write_head();
int sum_players = P.num_players();
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
while (true)
{
int last_sum_players = sum_players;
sum_players = (sum_players - 2 + opening_sum) / opening_sum;
if (sum_players == 0)
break;
if (my_relative_num >= sum_players && my_relative_num < last_sum_players)
{
for (unsigned int i=0; i<S.size(); i++)
{ values[i].pack(os); }
int receiver = positive_modulo(base_player + my_relative_num % sum_players, P.num_players());
timers[SEND].start();
P.send_to(receiver,os,true);
timers[SEND].stop();
}
if (my_relative_num < sum_players)
{
timers[RECV_ADD].start();
if (T::t() == 2)
add_openings<T,2>(values, P, sum_players, last_sum_players, base_player, *this);
else
add_openings<T,0>(values, P, sum_players, last_sum_players, base_player, *this);
timers[RECV_ADD].stop();
}
}
if (P.my_num() == base_player)
{
os.reset_write_head();
for (unsigned int i=0; i<S.size(); i++)
{ values[i].pack(os); }
timers[BCAST].start();
for (int i = 1; i < max_broadcast && i < P.num_players(); i++)
{
P.send_to((base_player + i) % P.num_players(), os, true);
}
timers[BCAST].stop();
AddToValues(values);
}
else if (my_relative_num * max_broadcast < P.num_players())
{
int sender = (base_player + my_relative_num / max_broadcast) % P.num_players();
ReceiveValues(values, P, sender);
timers[BCAST].start();
for (int i = 0; i < max_broadcast; i++)
{
int relative_receiver = (my_relative_num * max_broadcast + i);
if (relative_receiver < P.num_players())
{
int receiver = (base_player + relative_receiver) % P.num_players();
P.send_to(receiver, os, true);
}
}
timers[BCAST].stop();
}
values_opened += S.size();
}
template<class T>
void MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
S.size();
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
if (my_relative_num * max_broadcast >= P.num_players())
{
int sender = (base_player + my_relative_num / max_broadcast) % P.num_players();
ReceiveValues(values, P, sender);
}
else
GetValues(values);
popen_cnt += values.size();
CheckIfNeeded(P);
/* not compatible with continuous communication
send_player++;
if (send_player==P.num_players())
{ send_player=0; }
*/
}
template<class T>
void MAC_Check<T>::AddToMacs(const vector<Share<T> >& shares)
{
for (unsigned int i = 0; i < shares.size(); i++)
macs.push_back(shares[i].get_mac());
}
template<class T>
void MAC_Check<T>::AddToValues(vector<T>& values)
{
vals.insert(vals.end(), values.begin(), values.end());
}
template<class T>
void MAC_Check<T>::ReceiveValues(vector<T>& values, const Player& P, int sender)
{
timers[RECV_SUM].start();
P.receive_player(sender, os, true);
timers[RECV_SUM].stop();
for (unsigned int i = 0; i < values.size(); i++)
values[i].unpack(os);
AddToValues(values);
}
template<class T>
void MAC_Check<T>::GetValues(vector<T>& values)
{
int size = values.size();
if (popen_cnt + size > int(vals.size()))
{
stringstream ss;
ss << "wanted " << values.size() << " values from " << popen_cnt << ", only " << vals.size() << " in store";
throw out_of_range(ss.str());
}
values.clear();
typename vector<T>::iterator first = vals.begin() + popen_cnt;
values.insert(values.end(), first, first + size);
}
template<class T>
void MAC_Check<T>::CheckIfNeeded(const Player& P)
{
if (WaitingForCheck() >= POPEN_MAX)
Check(P);
}
template <class T>
void MAC_Check<T>::AddToCheck(const T& mac, const T& value, const Player& P)
{
CheckIfNeeded(P);
macs.push_back(mac);
vals.push_back(value);
}
template<class T>
void MAC_Check<T>::Check(const Player& P)
{
if (WaitingForCheck() == 0)
return;
//cerr << "In MAC Check : " << popen_cnt << endl;
octet seed[SEED_SIZE];
timers[SEED].start();
Create_Random_Seed(seed,P,SEED_SIZE);
timers[SEED].stop();
PRNG G;
G.SetSeed(seed);
Share<T> sj;
T a,gami,h,temp;
a.assign_zero();
gami.assign_zero();
vector<T> tau(P.num_players());
for (int i=0; i<popen_cnt; i++)
{ h.almost_randomize(G);
temp.mul(h,vals[i]);
a.add(a,temp);
temp.mul(h,macs[i]);
gami.add(gami,temp);
}
vals.erase(vals.begin(), vals.begin() + popen_cnt);
macs.erase(macs.begin(), macs.begin() + popen_cnt);
temp.mul(alphai,a);
tau[P.my_num()].sub(gami,temp);
//cerr << "\tCommit and Open" << endl;
timers[COMMIT].start();
Commit_And_Open(tau,P);
timers[COMMIT].stop();
//cerr << "\tFinal Check" << endl;
T t;
t.assign_zero();
for (int i=0; i<P.num_players(); i++)
{ t.add(t,tau[i]); }
if (!t.is_zero()) { throw mac_fail(); }
popen_cnt=0;
}
template<class T>
int mc_base_id(int function_id, int thread_num)
{
return (function_id << 28) + ((T::field_type() + 1) << 24) + (thread_num << 16);
}
template<class T>
Separate_MAC_Check<T>::Separate_MAC_Check(const T& ai, Names& Nms,
int thread_num, int opening_sum, int max_broadcast, int send_player) :
MAC_Check<T>(ai, opening_sum, max_broadcast, send_player),
check_player(Nms, mc_base_id<T>(1, thread_num))
{
}
template<class T>
void Separate_MAC_Check<T>::Check(const Player& P)
{
P.my_num();
MAC_Check<T>::Check(check_player);
}
template<class T>
void* run_summer_thread(void* summer)
{
((Summer<T>*) summer)->run();
return 0;
}
template <class T>
Parallel_MAC_Check<T>::Parallel_MAC_Check(const T& ai, Names& Nms,
int thread_num, int opening_sum, int max_broadcast, int base_player) :
Separate_MAC_Check<T>(ai, Nms, thread_num, opening_sum, max_broadcast, base_player),
send_player(Nms, mc_base_id<T>(2, thread_num)),
send_base_player(base_player)
{
int sum_players = Nms.num_players();
Player* summer_send_player = &send_player;
for (int i = 0; ; i++)
{
int last_sum_players = sum_players;
sum_players = (sum_players - 2 + opening_sum) / opening_sum;
int next_sum_players = (sum_players - 2 + opening_sum) / opening_sum;
if (sum_players == 0)
break;
Player* summer_receive_player = summer_send_player;
summer_send_player = new Player(Nms, mc_base_id<T>(3, thread_num));
summers.push_back(new Summer<T>(sum_players, last_sum_players, next_sum_players,
summer_send_player, summer_receive_player, *this));
pthread_create(&(summers[i]->thread), 0, run_summer_thread<T>, summers[i]);
}
receive_player = summer_send_player;
}
template<class T>
Parallel_MAC_Check<T>::~Parallel_MAC_Check()
{
for (unsigned int i = 0; i < summers.size(); i++)
{
summers[i]->input_queue.stop();
pthread_join(summers[i]->thread, 0);
delete summers[i];
}
}
template<class T>
void Parallel_MAC_Check<T>::POpen_Begin(vector<T>& values,
const vector<Share<T> >& S, const Player& P)
{
values.size();
this->AddToMacs(S);
int my_relative_num = positive_modulo(P.my_num() - send_base_player, P.num_players());
int sum_players = (P.num_players() - 2 + this->opening_sum) / this->opening_sum;
int receiver = positive_modulo(send_base_player + my_relative_num % sum_players, P.num_players());
// use queue rather sending to myself
if (receiver == P.my_num())
{
for (unsigned int i = 0; i < S.size(); i++)
values[i] = S[i].get_share();
summers.front()->share_queue.push(values);
}
else
{
this->os.reset_write_head();
for (unsigned int i=0; i<S.size(); i++)
S[i].get_share().pack(this->os);
this->timers[SEND].start();
send_player.send_to(receiver,this->os,true);
this->timers[SEND].stop();
}
for (unsigned int i = 0; i < summers.size(); i++)
summers[i]->input_queue.push(S.size());
this->values_opened += S.size();
send_base_player = (send_base_player + 1) % send_player.num_players();
}
template<class T>
void Parallel_MAC_Check<T>::POpen_End(vector<T>& values,
const vector<Share<T> >& S, const Player& P)
{
int last_size = 0;
this->timers[WAIT_SUMMER].start();
summers.back()->output_queue.pop(last_size);
this->timers[WAIT_SUMMER].stop();
if (int(values.size()) != last_size)
{
stringstream ss;
ss << "stopopen wants " << values.size() << " values, but I have " << last_size << endl;
throw Processor_Error(ss.str().c_str());
}
if (this->base_player == P.my_num())
{
value_queue.pop(values);
if (int(values.size()) != last_size)
throw Processor_Error("wrong number of local values");
else
this->AddToValues(values);
}
this->MAC_Check<T>::POpen_End(values, S, *receive_player);
this->base_player = (this->base_player + 1) % send_player.num_players();
}
template<class T>
Direct_MAC_Check<T>::Direct_MAC_Check(const T& ai, Names& Nms, int num) : Separate_MAC_Check<T>(ai, Nms, num) {
open_counter = 0;
}
template<class T>
Direct_MAC_Check<T>::~Direct_MAC_Check() {
cerr << T::type_string() << " open counter: " << open_counter << endl;
}
template<class T>
void Direct_MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
values.size();
this->os.reset_write_head();
for (unsigned int i=0; i<S.size(); i++)
S[i].get_share().pack(this->os);
this->timers[SEND].start();
P.send_all(this->os,true);
this->timers[SEND].stop();
this->AddToMacs(S);
for (unsigned int i=0; i<S.size(); i++)
this->vals.push_back(S[i].get_share());
}
template<class T, int t>
void direct_add_openings(vector<T>& values, const Player& P, vector<octetStream>& os)
{
for (unsigned int i=0; i<values.size(); i++)
for (int j=0; j<P.num_players(); j++)
if (j!=P.my_num())
values[i].template add<t>(os[j].consume(T::size()));
}
template<class T>
void Direct_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
S.size();
oss.resize(P.num_players());
this->GetValues(values);
this->timers[RECV].start();
for (int j=0; j<P.num_players(); j++)
if (j!=P.my_num())
P.receive_player(j,oss[j],true);
this->timers[RECV].stop();
open_counter++;
if (T::t() == 2)
direct_add_openings<T,2>(values, P, oss);
else
direct_add_openings<T,0>(values, P, oss);
for (unsigned int i = 0; i < values.size(); i++)
this->vals[this->popen_cnt+i] = values[i];
this->popen_cnt += values.size();
this->CheckIfNeeded(P);
}
template class MAC_Check<gfp>;
template class Direct_MAC_Check<gfp>;
template class Parallel_MAC_Check<gfp>;
template class MAC_Check<gf2n>;
template class Direct_MAC_Check<gf2n>;
template class Parallel_MAC_Check<gf2n>;
#ifdef USE_GF2N_LONG
template class MAC_Check<gf2n_short>;
template class Direct_MAC_Check<gf2n_short>;
template class Parallel_MAC_Check<gf2n_short>;
#endif

141
Auth/MAC_Check.h Normal file
View File

@@ -0,0 +1,141 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _MAC_Check
#define _MAC_Check
/* Class for storing MAC Check data and doing the Check */
#include <vector>
#include <deque>
using namespace std;
#include "Math/Share.h"
#include "Networking/Player.h"
#include "Networking/ServerSocket.h"
#include "Auth/Summer.h"
#include "Tools/time-func.h"
/* The MAX number of things we will partially open before running
* a MAC Check
*
* Keep this at much less than 1MB of data to be able to cope with
* multi-threaded players
*
*/
#define POPEN_MAX 1000000
template<class T>
class MAC_Check
{
protected:
/* POpen Data */
int popen_cnt;
vector<T> macs;
vector<T> vals;
int base_player;
int opening_sum;
int max_broadcast;
octetStream os;
/* MAC Share */
T alphai;
void AddToMacs(const vector< Share<T> >& shares);
void AddToValues(vector<T>& values);
void ReceiveValues(vector<T>& values, const Player& P, int sender);
void GetValues(vector<T>& values);
void CheckIfNeeded(const Player& P);
int WaitingForCheck()
{ return max(macs.size(), vals.size()); }
public:
int values_opened;
vector<Timer> timers;
vector<Timer> player_timers;
vector<octetStream> oss;
MAC_Check(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0);
virtual ~MAC_Check();
/* Run protocols to partially open data and check the MACs are
* all OK.
* - Implicit assume that the amount of data being sent does
* not overload the OS
* Begin and End expect the same arrays values and S passed to them
* and they expect values to be of the same size as S.
*/
virtual void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
virtual void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
void AddToCheck(const T& mac, const T& value, const Player& P);
virtual void Check(const Player& P);
int number() const { return values_opened; }
const T& get_alphai() const { return alphai; }
};
template<class T, int t>
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check<T>& MC);
template<class T>
class Separate_MAC_Check: public MAC_Check<T>
{
// Different channel for checks
Player check_player;
protected:
// No sense to expose this
Separate_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0);
virtual ~Separate_MAC_Check() {};
public:
virtual void Check(const Player& P);
};
template<class T>
class Parallel_MAC_Check: public Separate_MAC_Check<T>
{
// Different channel for every round
Player send_player;
// Managed by Summer
Player* receive_player;
vector< Summer<T>* > summers;
int send_base_player;
WaitQueue< vector<T> > value_queue;
public:
Parallel_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0);
virtual ~Parallel_MAC_Check();
virtual void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
virtual void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
friend class Summer<T>;
};
template<class T>
class Direct_MAC_Check: public Separate_MAC_Check<T>
{
int open_counter;
vector<octetStream> oss;
public:
Direct_MAC_Check(const T& ai, Names& Nms, int thread_num);
~Direct_MAC_Check();
void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
};
#endif

235
Auth/Subroutines.cpp Normal file
View File

@@ -0,0 +1,235 @@
// (C) 2016 University of Bristol. See License.txt
#include "Auth/Subroutines.h"
#include "Tools/random.h"
#include "Exceptions/Exceptions.h"
#include "Tools/Commit.h"
/* To ease readability as I re-write this program the following conventions
* will be used.
* For a variable v index by a player i
* Comm_v[i] is the commitment string for player i
* Open_v[i] is the opening data for player i
*/
// Special version for octetStreams
void Commit(vector< vector<octetStream> >& Comm_data,
vector<octetStream>& Open_data,
const vector< vector<octetStream> >& data,const Player& P,int num_runs)
{
int my_number=P.my_num();
for (int i=0; i<num_runs; i++)
{ Comm_data[i].resize(P.num_players());
Commit(Comm_data[i][my_number],Open_data[i],data[i][my_number],my_number);
P.Broadcast_Receive(Comm_data[i]);
}
}
// Special version for octetStreams
void Open(vector< vector<octetStream> >& data,
const vector< vector<octetStream> >& Comm_data,
const vector<octetStream>& My_Open_data,
const Player& P,int num_runs,int dont)
{
int my_number=P.my_num();
int num_players=P.num_players();
vector<octetStream> Open_data(num_players);
for (int i=0; i<num_runs; i++)
{ if (i!=dont)
{ Open_data[my_number]=My_Open_data[i];
P.Broadcast_Receive(Open_data);
for (int j=0; j<num_players; j++)
{ if (j!=my_number)
{ if (!Open(data[i][j],Comm_data[i][j],Open_data[j],j))
{ throw invalid_commitment(); }
}
}
}
}
}
void Open(vector< vector<octetStream> >& data,
const vector< vector<octetStream> >& Comm_data,
const vector<octetStream>& My_Open_data,
const vector<int> open,
const Player& P,int num_runs)
{
int my_number=P.my_num();
int num_players=P.num_players();
vector<octetStream> Open_data(num_players);
for (int i=0; i<num_runs; i++)
{ if (open[i]==1)
{ Open_data[my_number]=My_Open_data[i];
P.Broadcast_Receive(Open_data);
for (int j=0; j<num_players; j++)
{ if (j!=my_number)
{ if (!Open(data[i][j],Comm_data[i][j],Open_data[j],j))
{ throw invalid_commitment(); }
}
}
}
}
}
void Commit_To_Challenge(vector<unsigned int>& e,
vector<octetStream>& Comm_e,vector<octetStream>& Open_e,
const Player& P,int num_runs)
{
PRNG G;
G.ReSeed();
e.resize(P.num_players());
Comm_e.resize(P.num_players());
Open_e.resize(P.num_players());
e[P.my_num()]=G.get_uint()%num_runs;
octetStream ee; ee.store(e[P.my_num()]);
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num());
P.Broadcast_Receive(Comm_e);
}
int Open_Challenge(vector<unsigned int>& e,vector<octetStream>& Open_e,
const vector<octetStream>& Comm_e,
const Player& P,int num_runs)
{
// Now open the challenge commitments and determine which run was for real
P.Broadcast_Receive(Open_e);
int challenge=0;
octetStream ee;
for (int i = 0; i < P.num_players(); i++)
{ if (i != P.my_num())
{ if (!Open(ee,Comm_e[i],Open_e[i],i))
{ throw invalid_commitment(); }
ee.get(e[i]);
}
challenge+=e[i];
}
challenge = challenge % num_runs;
return challenge;
}
template<class T>
void Create_Random(T& ans,const Player& P)
{
PRNG G;
G.ReSeed();
vector<T> e(P.num_players());
vector<octetStream> Comm_e(P.num_players());
vector<octetStream> Open_e(P.num_players());
e[P.my_num()].randomize(G);
octetStream ee;
e[P.my_num()].pack(ee);
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num());
P.Broadcast_Receive(Comm_e);
P.Broadcast_Receive(Open_e);
ans.assign_zero();
for (int i = 0; i < P.num_players(); i++)
{ if (i != P.my_num())
{ if (!Open(ee,Comm_e[i],Open_e[i],i))
{ throw invalid_commitment(); }
e[i].unpack(ee);
}
ans.add(ans,e[i]);
}
}
void Create_Random_Seed(octet* seed,const Player& P,int len)
{
PRNG G;
G.ReSeed();
vector<octetStream> e(P.num_players());
vector<octetStream> Comm_e(P.num_players());
vector<octetStream> Open_e(P.num_players());
G.get_octetStream(e[P.my_num()],len);
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],e[P.my_num()],P.my_num());
P.Broadcast_Receive(Comm_e);
P.Broadcast_Receive(Open_e);
memset(seed,0,len*sizeof(octet));
for (int i = 0; i < P.num_players(); i++)
{ if (i != P.my_num())
{ if (!Open(e[i],Comm_e[i],Open_e[i],i))
{ throw invalid_commitment(); }
}
for (int j=0; j<len; j++)
{ seed[j]=seed[j]^(e[i].get_data()[j]); }
}
}
template<class T>
void Commit_And_Open(vector<T>& data,const Player& P)
{
vector<octetStream> Comm_data(P.num_players());
vector<octetStream> Open_data(P.num_players());
octetStream ee;
data[P.my_num()].pack(ee);
Commit(Comm_data[P.my_num()],Open_data[P.my_num()],ee,P.my_num());
P.Broadcast_Receive(Comm_data);
P.Broadcast_Receive(Open_data);
for (int i = 0; i < P.num_players(); i++)
{ if (i != P.my_num())
{ if (!Open(ee,Comm_data[i],Open_data[i],i))
{ throw invalid_commitment(); }
data[i].unpack(ee);
}
}
}
void Commit_To_Seeds(vector<PRNG>& G,
vector< vector<octetStream> >& seeds,
vector< vector<octetStream> >& Comm_seeds,
vector<octetStream>& Open_seeds,
const Player& P,int num_runs)
{
seeds.resize(num_runs);
Comm_seeds.resize(num_runs);
Open_seeds.resize(num_runs);
for (int i=0; i<num_runs; i++)
{ G[i].ReSeed();
seeds[i].resize(P.num_players());
Comm_seeds[i].resize(P.num_players());
Open_seeds[i].resize(P.num_players());
seeds[i][P.my_num()].reset_write_head();
seeds[i][P.my_num()].append(G[i].get_seed(),SEED_SIZE);
}
Commit(Comm_seeds,Open_seeds,seeds,P,num_runs);
}
template void Commit_And_Open(vector<gf2n>& data,const Player& P);
template void Create_Random(gf2n& ans,const Player& P);
#ifdef USE_GF2N_LONG
template void Commit_And_Open(vector<gf2n_short>& data,const Player& P);
template void Create_Random(gf2n_short& ans,const Player& P);
#endif
template void Commit_And_Open(vector<gfp>& data,const Player& P);
template void Create_Random(gfp& ans,const Player& P);

140
Auth/Subroutines.h Normal file
View File

@@ -0,0 +1,140 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Subroutines
#define _Subroutines
/* Defines subroutines for use in both KeyGen and Offline phase
* Mainly focused around commiting and decommitting to various
* bits of data
*/
#include "Tools/random.h"
#include "Networking/Player.h"
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Tools/Commit.h"
/* Run just the Open Protocol for data[i][j] of type octetStream
* 0 <= i < num_runs
* 0 <= j < num_players
* On output data[i][j] contains all the data
* If dont!=-1 then dont open this run
*/
void Open(vector< vector<octetStream> >& data,
const vector< vector<octetStream> >& Comm_data,
const vector<octetStream>& My_Open_data,
const Player& P,int num_runs,int dont=-1);
/* This one takes a vector open which contains 0 and 1
* If 1 then we open this value, otherwise we do not
*/
void Open(vector< vector<octetStream> >& data,
const vector< vector<octetStream> >& Comm_data,
const vector<octetStream>& My_Open_data,
const vector<int> open,
const Player& P,int num_runs);
/* This runs the Commit and Open Protocol for data[i][j] of type T
* 0 <= i < num_runs
* 0 <= j < num_players
* On input data[i][j] is only defined for j=my_number
*/
template<class T>
void Commit_And_Open(vector< vector<T> >& data,const Player& P,int num_runs);
template<class T>
void Commit_And_Open(vector<T>& data,const Player& P);
template<class T>
void Transmit_Data(vector< vector<T> >& data,const Player& P,int num_runs);
/* Functions to Commit and Open a Challenge Value */
void Commit_To_Challenge(vector<unsigned int>& e,
vector<octetStream>& Comm_e,vector<octetStream>& Open_e,
const Player& P,int num_runs);
int Open_Challenge(vector<unsigned int>& e,vector<octetStream>& Open_e,
const vector<octetStream>& Comm_e,
const Player& P,int num_runs);
/* Function to create a shared random value for T=gfp/gf2n */
template<class T>
void Create_Random(T& ans,const Player& P);
/* Produce a random seed of length len */
void Create_Random_Seed(octet* seed,const Player& P,int len);
/* Functions to Commit to Seed Values
* This also initialises the PRNG's in G
*/
void Commit_To_Seeds(vector<PRNG>& G,
vector< vector<octetStream> >& seeds,
vector< vector<octetStream> >& Comm_seeds,
vector<octetStream>& Open_seeds,
const Player& P,int num_runs);
/* Run just the Commit Protocol for data[i][j] of type T
* 0 <= i < num_runs
* 0 <= j < num_players
* On input data[i][j] is only defined for j=my_number
*/
template<class T>
void Commit(vector< vector<octetStream> >& Comm_data,
vector<octetStream>& Open_data,
const vector< vector<T> >& data,const Player& P,int num_runs)
{
octetStream os;
int my_number=P.my_num();
for (int i=0; i<num_runs; i++)
{ os.reset_write_head();
data[i][my_number].pack(os);
Comm_data[i].resize(P.num_players());
Commit(Comm_data[i][my_number],Open_data[i],os,my_number);
P.Broadcast_Receive(Comm_data[i]);
}
}
/* Run just the Open Protocol for data[i][j] of type T
* 0 <= i < num_runs
* 0 <= j < num_players
* On output data[i][j] contains all the data
* If dont!=-1 then dont open this run
*/
template<class T>
void Open(vector< vector<T> >& data,
const vector< vector<octetStream> >& Comm_data,
const vector<octetStream>& My_Open_data,
const Player& P,int num_runs,int dont=-1)
{
octetStream os;
int my_number=P.my_num();
int num_players=P.num_players();
vector<octetStream> Open_data(num_players);
for (int i=0; i<num_runs; i++)
{ if (i!=dont)
{ Open_data[my_number]=My_Open_data[i];
P.Broadcast_Receive(Open_data);
for (int j=0; j<num_players; j++)
{ if (j!=my_number)
{ if (!Open(os,Comm_data[i][j],Open_data[j],j))
{ throw invalid_commitment(); }
os.reset_read_head();
data[i][j].unpack(os);
}
}
}
}
}
#endif

94
Auth/Summer.cpp Normal file
View File

@@ -0,0 +1,94 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Summer.cpp
*
*/
#include "Auth/Summer.h"
#include "Auth/MAC_Check.h"
#include "Tools/int.h"
template<class T>
Summer<T>::Summer(int sum_players, int last_sum_players, int next_sum_players,
Player* send_player, Player* receive_player, Parallel_MAC_Check<T>& MC) :
sum_players(sum_players), last_sum_players(last_sum_players), next_sum_players(next_sum_players),
base_player(0), MC(MC),send_player(send_player), receive_player(receive_player),
thread(0), stop(false), size(0)
{
cout << "Setting up summation by " << sum_players << " players" << endl;
}
template<class T>
Summer<T>::~Summer()
{
delete send_player;
if (timer.elapsed())
cout << T::type_string() << " summation by " << sum_players << " players: "
<< timer.elapsed() << endl;
}
template<class T>
void Summer<T>::run()
{
octetStream os;
while (true)
{
int size = 0;
if (!input_queue.pop(size))
break;
int my_relative_num = positive_modulo(send_player->my_num() - base_player, send_player->num_players());
if (my_relative_num < sum_players)
{
// first summer takes inputs from queue
if (last_sum_players == send_player->num_players())
share_queue.pop(values);
else
{
values.resize(size);
receive_player->receive_player(receive_player->my_num(),os,true);
for (int i = 0; i < size; i++)
values[i].unpack(os);
}
timer.start();
if (T::t() == 2)
add_openings<T,2>(values, *receive_player, sum_players,
last_sum_players, base_player, MC);
else
add_openings<T,0>(values, *receive_player, sum_players,
last_sum_players, base_player, MC);
timer.stop();
os.reset_write_head();
for (int i = 0; i < size; i++)
values[i].pack(os);
if (sum_players > 1)
{
int receiver = positive_modulo(base_player + my_relative_num % next_sum_players,
send_player->num_players());
send_player->send_to(receiver, os, true);
}
else
{
send_player->send_all(os);
MC.value_queue.push(values);
}
}
if (sum_players == 1)
output_queue.push(size);
base_player = (base_player + 1) % send_player->num_players();
}
}
template class Summer<gfp>;
template class Summer<gf2n>;
#ifdef USE_GF2N_LONG
template class Summer<gf2n_short>;
#endif

47
Auth/Summer.h Normal file
View File

@@ -0,0 +1,47 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Summer.h
*
*/
#ifndef OFFLINE_SUMMER_H_
#define OFFLINE_SUMMER_H_
#include "Networking/Player.h"
#include "Tools/WaitQueue.h"
#include "Tools/time-func.h"
#include <pthread.h>
#include <vector>
using namespace std;
template<class T>
class Parallel_MAC_Check;
template<class T>
class Summer
{
int sum_players, last_sum_players, next_sum_players;
int base_player;
Parallel_MAC_Check<T>& MC;
Player* send_player;
Player* receive_player;
Timer timer;
public:
vector<T> values;
pthread_t thread;
WaitQueue<int> input_queue, output_queue;
bool stop;
int size;
WaitQueue< vector<T> > share_queue;
Summer(int sum_players, int last_sum_players, int next_sum_players,
Player* send_player, Player* receive_player, Parallel_MAC_Check<T>& MC);
~Summer();
void run();
};
#endif /* OFFLINE_SUMMER_H_ */

171
Auth/fake-stuff.cpp Normal file
View File

@@ -0,0 +1,171 @@
// (C) 2016 University of Bristol. See License.txt
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include <fstream>
template<class T>
void make_share(vector<Share<T> >& Sa,const T& a,int N,const T& key,PRNG& G)
{
T mac,x,y;
mac.mul(a,key);
Share<T> S;
S.set_share(a);
S.set_mac(mac);
for (int i=0; i<N-1; i++)
{ x.randomize(G);
y.randomize(G);
Sa[i].set_share(x);
Sa[i].set_mac(y);
S.sub(S,Sa[i]);
}
Sa[N-1]=S;
}
template<class T>
void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key)
{
value.assign(0);
mac.assign(0);
for (int i=0; i<N; i++)
{
value.add(Sa[i].get_share());
mac.add(Sa[i].get_mac());
}
T res;
res.mul(value, key);
if (!res.equal(mac))
{
cout << "Value: " << value << endl;
cout << "Input MAC: " << mac << endl;
cout << "Actual MAC: " << res << endl;
cout << "MAC key: " << key << endl;
throw mac_fail();
}
}
template void make_share(vector<Share<gf2n> >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G);
template void make_share(vector<Share<gfp> >& Sa,const gfp& a,int N,const gfp& key,PRNG& G);
template void check_share(vector<Share<gf2n> >& Sa,gf2n& value,gf2n& mac,int N,const gf2n& key);
template void check_share(vector<Share<gfp> >& Sa,gfp& value,gfp& mac,int N,const gfp& key);
#ifdef USE_GF2N_LONG
template void make_share(vector<Share<gf2n_short> >& Sa,const gf2n_short& a,int N,const gf2n_short& key,PRNG& G);
template void check_share(vector<Share<gf2n_short> >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key);
#endif
// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40)
void expand_byte(gf2n_short& a,int b)
{
gf2n_short x,xp;
x.assign(32+1);
xp.assign_one();
a.assign_zero();
while (b!=0)
{ if ((b&1)==1)
{ a.add(a,xp); }
xp.mul(x);
b>>=1;
}
}
// Have previously worked out the linear equations we need to solve
void collapse_byte(int& b,const gf2n_short& aa)
{
word w=aa.get();
int e35=(w>>35)&1;
int e30=(w>>30)&1;
int e25=(w>>25)&1;
int e20=(w>>20)&1;
int e15=(w>>15)&1;
int e10=(w>>10)&1;
int e5=(w>>5)&1;
int e0=w&1;
int a[8];
a[7]=e35;
a[6]=e30^a[7];
a[5]=e25^a[7];
a[4]=e20^a[5]^a[6]^a[7];
a[3]=e15^a[7];
a[2]=e10^a[3]^a[6]^a[7];
a[1]=e5^a[3]^a[5]^a[7];
a[0]=e0^a[1]^a[2]^a[3]^a[4]^a[5]^a[6]^a[7];
b=0;
for (int i=7; i>=0; i--)
{ b=b<<1;
b+=a[i];
}
}
void generate_keys(const string& directory, int nplayers)
{
PRNG G;
G.ReSeed();
gf2n mac2;
gfp macp;
mac2.assign_zero();
macp.assign_zero();
ofstream outf;
for (int i = 0; i < nplayers; i++)
{
stringstream filename;
filename << directory << "Player-MAC-Keys-P" << i;
mac2.randomize(G);
macp.randomize(G);
cout << "Writing to " << filename.str().c_str() << endl;
outf.open(filename.str().c_str());
outf << nplayers << endl;
macp.output(outf,true);
outf << " ";
mac2.output(outf,true);
outf << endl;
outf.close();
}
}
void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers)
{
gfp sharep;
gf2n share2;
keyp.assign_zero();
key2.assign_zero();
int i, tmpN = 0;
ifstream inpf;
for (i = 0; i < nplayers; i++)
{
stringstream filename;
filename << directory << "Player-MAC-Keys-P" << i;
inpf.open(filename.str().c_str());
if (inpf.fail())
{
inpf.close();
cout << "Error: No MAC key share found for player " << i << std::endl;
exit(1);
}
else
{
inpf >> tmpN; // not needed here
sharep.input(inpf,true);
share2.input(inpf,true);
inpf.close();
}
std::cout << " Key " << i << "\t p: " << sharep << "\n\t 2: " << share2 << std::endl;
keyp.add(sharep);
key2.add(share2);
}
std::cout << "Final MAC keys :\t p: " << keyp << "\n\t\t 2: " << key2 << std::endl;
}

64
Auth/fake-stuff.h Normal file
View File

@@ -0,0 +1,64 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _fake_stuff
#define _fake_stuff
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include <fstream>
using namespace std;
template<class T>
void make_share(vector<Share<T> >& Sa,const T& a,int N,const T& key,PRNG& G);
template<class T>
void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key);
void expand_byte(gf2n_short& a,int b);
void collapse_byte(int& b,const gf2n_short& a);
// Generate MAC key shares
void generate_keys(const string& directory, int nplayers);
// Read MAC key shares and compute keys
void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers);
template <class T>
class Files
{
public:
ofstream* outf;
int N;
T key;
PRNG G;
Files(int N, const T& key, const string& prefix) : N(N), key(key)
{
outf = new ofstream[N];
for (int i=0; i<N; i++)
{
stringstream filename;
filename << prefix << "-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail())
throw file_error(filename.str().c_str());
}
G.ReSeed();
}
~Files()
{
delete[] outf;
}
void output_shares(const T& a)
{
vector<Share<T> > Sa(N);
make_share(Sa,a,N,key,G);
for (int j=0; j<N; j++)
Sa[j].output(outf[j],false);
}
};
#endif

46
CONFIG Normal file
View File

@@ -0,0 +1,46 @@
# (C) 2016 University of Bristol. See License.txt
ROOT = .
OPTIM= -O3
#PROF = -pg
#DEBUG = -DDEBUG
#MEMPROTECT = -DMEMPROTECT
# set this to your preferred local storage directory
PREP_DIR = '-DPREP_DIR="Player-Data/"'
# set for 128-bit GF(2^n) and/or OT preprocessing
USE_GF2N_LONG = 0
# set to -march=<architecture> for optimization
# AVX2 support (Haswell or later) changes the bit matrix transpose
ARCH = -mtune=native
#use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine
ifeq ($(USE_GF2N_LONG),1)
GF2N_LONG = -DUSE_GF2N_LONG
endif
# MAX_MOD_SZ must be at least ceil(len(p)/len(word))+1
# Default is 3, which suffices for 128-bit p
# MOD = -DMAX_MOD_SZ=3
LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread
ifeq ($(USE_NTL),1)
LDLIBS := -lntl $(LDLIBS)
endif
OS := $(shell uname -s)
ifeq ($(OS), Linux)
LDLIBS += -lrt
endif
CXX = g++
CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH)
CPPFLAGS = $(CFLAGS)
LD = g++

203
Check-Offline.cpp Normal file
View File

@@ -0,0 +1,203 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Check-Offline.cpp
*
*/
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include "Auth/fake-stuff.h"
#include "Tools/ezOptionParser.h"
#include "Exceptions/Exceptions.h"
#include "Math/Setup.h"
#include "Processor/Data_Files.h"
#include <sstream>
#include <fstream>
#include <vector>
using namespace std;
string PREP_DATA_PREFIX;
template<class T>
void check_mult_triples(const T& key,int N,vector<Data_Files*>& dataF,DataFieldType field_type)
{
T a,b,c,mac,res;
vector<Share<T> > Sa(N),Sb(N),Sc(N);
int n = 0;
try {
while (!dataF[0]->eof<T>(DATA_TRIPLE))
{
for (int i = 0; i < N; i++)
dataF[i]->get_three(field_type, DATA_TRIPLE, Sa[i], Sb[i], Sc[i]);
check_share(Sa, a, mac, N, key);
check_share(Sb, b, mac, N, key);
check_share(Sc, c, mac, N, key);
res.mul(a, b);
if (!res.equal(c))
{
cout << n << ": " << c << " != " << a << " * " << b << endl;
throw bad_value();
}
n++;
}
cout << n << " triples of type " << T::type_string() << endl;
}
catch (exception& e)
{
cout << "Error with triples of type " << T::type_string() << endl;
}
}
template<class T>
void check_bits(const T& key,int N,vector<Data_Files*>& dataF,DataFieldType field_type)
{
T a,b,c,mac,res;
vector<Share<T> > Sa(N),Sb(N),Sc(N);
int n = 0;
while (!dataF[0]->eof<T>(DATA_BIT))
{
for (int i = 0; i < N; i++)
dataF[i]->get_one(field_type, DATA_BIT, Sa[i]);
check_share(Sa, a, mac, N, key);
if (!(a.is_zero() || a.is_one()))
{
cout << n << ": " << a << " neither 0 or 1" << endl;
throw bad_value();
}
n++;
}
cout << n << " bits of type " << T::type_string() << endl;
}
template<class T>
void check_inputs(const T& key,int N,vector<Data_Files*>& dataF)
{
T a, mac, x;
vector< Share<T> > Sa(N);
for (int player = 0; player < N; player++)
{
int n = 0;
while (!dataF[0]->input_eof<T>(player))
{
for (int i = 0; i < N; i++)
dataF[i]->get_input(Sa[i], x, player);
check_share(Sa, a, mac, N, key);
if (!a.equal(x))
throw bad_value();
n++;
}
cout << n << " input masks for player " << player << " of type " << T::type_string() << endl;
}
}
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
gfp::init_field(gfp::pr(), false);
opt.syntax = "./Check-Offline.x <nparties> [OPTIONS]\n";
opt.example = "./Check-Offline.x 3 -lgp 64 -lg2 128\n";
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(p) field (default: 128)", // Help description.
"-lgp", // Flag token.
"--lgp" // Flag token.
);
opt.add(
"40", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(2^n) field (default: 40)", // Help description.
"-lg2", // Flag token.
"--lg2" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Read GF(p) triples in Montgomery representation (default: not set)", // Help description.
"-m", // Flag token.
"--usemont" // Flag token.
);
opt.parse(argc, argv);
string usage;
int lgp, lg2, nparties;
bool use_montgomery = false;
opt.get("--lgp")->getInt(lgp);
opt.get("--lg2")->getInt(lg2);
if (opt.isSet("--usemont"))
use_montgomery = true;
if (opt.firstArgs.size() == 2)
nparties = atoi(opt.firstArgs[1]->c_str());
else if (opt.lastArgs.size() == 1)
nparties = atoi(opt.lastArgs[0]->c_str());
else
{
cerr << "ERROR: invalid number of arguments\n";
opt.getUsage(usage);
cout << usage;
return 1;
}
PREP_DATA_PREFIX = get_prep_dir(nparties, lgp, lg2);
read_setup(PREP_DATA_PREFIX);
if (!use_montgomery)
{
// no montgomery
gfp::init_field(gfp::pr(), false);
}
/* Find number players and MAC keys etc*/
char filename[1024];
gfp keyp,pp; keyp.assign_zero();
gf2n key2,p2; key2.assign_zero();
int N=1;
ifstream inpf;
for (int i= 0; i < nparties; i++)
{
sprintf(filename, (PREP_DATA_PREFIX + "Player-MAC-Keys-P%d").c_str(), i);
inpf.open(filename);
if (inpf.fail()) { throw file_error(filename); }
inpf >> N;
pp.input(inpf,true);
p2.input(inpf,true);
cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl;
keyp.add(pp);
key2.add(p2);
inpf.close();
}
cout << "--------------\n";
cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl;
vector<Data_Files*> dataF(N);
for (int i = 0; i < N; i++)
dataF[i] = new Data_Files(i, N, PREP_DATA_PREFIX);
check_mult_triples(key2, N, dataF, DATA_GF2N);
check_mult_triples(keyp, N, dataF, DATA_MODP);
check_inputs(key2, N, dataF);
check_inputs(keyp, N, dataF);
check_bits(key2, N, dataF, DATA_GF2N);
check_bits(keyp, N, dataF, DATA_MODP);
for (int i = 0; i < N; i++)
delete dataF[i];
}

31
Compiler/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
# (C) 2016 University of Bristol. See License.txt
import compilerLib, program, instructions, types, library, floatingpoint
import inspect
from config import *
from compilerLib import run
# add all instructions to the program VARS dictionary
compilerLib.VARS = {}
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
instr_classes += [t[1] for t in inspect.getmembers(types, inspect.isclass)\
if t[1].__module__ == types.__name__]
instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\
if t[1].__module__ == library.__name__]
for op in instr_classes:
compilerLib.VARS[op.__name__] = op
# add open and input separately due to name conflict
compilerLib.VARS['open'] = instructions.asm_open
compilerLib.VARS['vopen'] = instructions.vasm_open
compilerLib.VARS['gopen'] = instructions.gasm_open
compilerLib.VARS['vgopen'] = instructions.vgasm_open
compilerLib.VARS['input'] = instructions.asm_input
compilerLib.VARS['ginput'] = instructions.gasm_input
compilerLib.VARS['comparison'] = comparison
compilerLib.VARS['floatingpoint'] = floatingpoint

653
Compiler/allocator.py Normal file
View File

@@ -0,0 +1,653 @@
# (C) 2016 University of Bristol. See License.txt
import itertools, time
from collections import defaultdict, deque
from Compiler.exceptions import *
from Compiler.config import *
from Compiler.instructions import *
from Compiler.instructions_base import *
from Compiler.util import *
import Compiler.graph
import Compiler.program
import heapq, itertools
import operator
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
It is based on the precondition that every register is only defined once."""
def __init__(self, n):
self.free = defaultdict(set)
self.alloc = {}
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
self.defined = {}
self.dealloc = set()
self.n = n
def alloc_reg(self, reg, persistent_allocation):
base = reg.vectorbase
if base in self.alloc:
# already allocated
return
reg_type = reg.reg_type
size = base.size
if not persistent_allocation and self.free[reg_type, size]:
res = self.free[reg_type, size].pop()
else:
if self.usage[reg_type] < self.n:
res = self.usage[reg_type]
self.usage[reg_type] += size
else:
raise RegisterOverflowError()
self.alloc[base] = res
if base.vector:
for i,r in enumerate(base.vector):
r.i = self.alloc[base] + i
else:
base.i = self.alloc[base]
def dealloc_reg(self, reg, inst):
self.dealloc.add(reg)
base = reg.vectorbase
if base.vector and not inst.is_vec():
for i in base.vector:
if i not in self.dealloc:
# not all vector elements ready for deallocation
return
self.free[reg.reg_type, base.size].add(self.alloc[base])
if inst.is_vec() and base.vector:
for i in base.vector:
self.defined[i] = inst
else:
self.defined[reg] = inst
def process(self, program, persistent_allocation=False):
for k,i in enumerate(reversed(program)):
unused_regs = []
for j in i.get_def():
if j.vectorbase in self.alloc:
if j in self.defined:
raise CompilerError("Double write on register %s " \
"assigned by '%s' in %s" % \
(j,i,format_trace(i.caller)))
else:
# unused register
self.alloc_reg(j, persistent_allocation)
unused_regs.append(j)
if unused_regs and len(unused_regs) == len(i.get_def()):
# only report if all assigned registers are unused
print "Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller))
for j in i.get_used():
self.alloc_reg(j, persistent_allocation)
for j in i.get_def():
self.dealloc_reg(j, i)
if k % 1000000 == 0 and k > 0:
print "Allocated registers for %d instructions at" % k, time.asctime()
# print "Successfully allocated registers"
# print "modp usage: %d clear, %d secret" % \
# (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp])
# print "GF2N usage: %d clear, %d secret" % \
# (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N])
return self.usage
def determine_scope(block):
last_def = defaultdict(lambda: -1)
used_from_scope = set()
def find_in_scope(reg, scope):
if scope is None:
return False
elif reg in scope.defined_registers:
return True
else:
return find_in_scope(reg, scope.scope)
def read(reg, n):
if last_def[reg] == -1:
if find_in_scope(reg, block.scope):
used_from_scope.add(reg)
reg.can_eliminate = False
else:
print 'Warning: read before write at register', reg
print '\tline %d: %s' % (n, instr)
print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')
print '\tregister trace: %s' % format_trace(reg.caller, '\t\t')
def write(reg, n):
if last_def[reg] != -1:
print 'Warning: double write at register', reg
print '\tline %d: %s' % (n, instr)
print '\ttrace: %s' % format_trace(instr.caller, '\t\t')
last_def[reg] = n
for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used()
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
for reg in outputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
write(i, n)
else:
write(reg, n)
block.used_from_scope = used_from_scope
block.defined_registers = set(last_def.iterkeys())
class Merger:
def __init__(self, block, options):
self.block = block
self.instructions = block.instructions
self.options = options
if options.max_parallel_open:
self.max_parallel_open = int(options.max_parallel_open)
else:
self.max_parallel_open = float('inf')
self.dependency_graph()
def do_merge(self, merges_iter):
""" Merge an iterable of nodes in G, returning the number of merged
instructions and the index of the merged instruction. """
instructions = self.instructions
mergecount = 0
try:
n = next(merges_iter)
except StopIteration:
return mergecount, None
def expand_vector_args(inst):
new_args = []
for arg in inst.args:
if inst.is_vec():
arg.create_vector_elements()
for reg in arg:
new_args.append(reg)
else:
new_args.append(arg)
return new_args
for i in merges_iter:
if isinstance(instructions[n], startinput_class):
instructions[n].args[1] += instructions[i].args[1]
elif isinstance(instructions[n], (stopinput, gstopinput)):
if instructions[n].get_size() != instructions[i].get_size():
raise NotImplemented()
else:
instructions[n].args += instructions[i].args[1:]
else:
if instructions[n].get_size() != instructions[i].get_size():
# merge as non-vector instruction
instructions[n].args = expand_vector_args(instructions[n]) + \
expand_vector_args(instructions[i])
if instructions[n].is_vec():
instructions[n].size = 1
else:
instructions[n].args += instructions[i].args
# join arg_formats if not special iterators
# if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \
# not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)):
# instructions[n].arg_format += instructions[i].arg_format
instructions[i] = None
self.merge_nodes(n, i)
mergecount += 1
return mergecount, n
def compute_max_depths(self, depth_of):
""" Compute the maximum 'depth' at which every instruction can be placed.
This is the minimum depth of any merge_node succeeding an instruction.
Similar to DAG shortest paths algorithm. Traverses the graph in reverse
topological order, updating the max depth of each node's predecessors.
"""
G = self.G
merge_nodes_set = self.open_nodes
top_order = Compiler.graph.topological_sort(G)
max_depth_of = [None] * len(G)
max_depth = max(depth_of)
for i in range(len(max_depth_of)):
if i in merge_nodes_set:
max_depth_of[i] = depth_of[i] - 1
else:
max_depth_of[i] = max_depth
for u in reversed(top_order):
for v in G.pred[u]:
if v not in merge_nodes_set:
max_depth_of[v] = min(max_depth_of[u], max_depth_of[v])
return max_depth_of
def merge_inputs(self):
merges = defaultdict(list)
remaining_input_nodes = []
def do_merge(nodes):
if len(nodes) > 1000:
print 'Merging %d inputs...' % len(nodes)
self.do_merge(iter(nodes))
for n in self.input_nodes:
inst = self.instructions[n]
merge = merges[inst.args[0],inst.__class__]
if len(merge) == 0:
remaining_input_nodes.append(n)
merge.append(n)
if len(merge) >= self.max_parallel_open:
do_merge(merge)
merge[:] = []
for merge in merges.itervalues():
if merge:
do_merge(merge)
self.input_nodes = remaining_input_nodes
def compute_preorder(self, merges, rev_depth_of):
# find flexible nodes that can be on several levels
# and find sources on level 0
G = self.G
merge_nodes_set = self.open_nodes
depth_of = self.depths
instructions = self.instructions
flex_nodes = defaultdict(dict)
starters = []
for n in xrange(len(G)):
if n not in merge_nodes_set and \
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
#print n, depth_of[n], rev_depth_of[n]
flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n)
elif len(G.pred[n]) == 0 and \
not isinstance(self.instructions[n], RawInputInstruction):
starters.append(n)
if n % 10000000 == 0 and n > 0:
print "Processed %d nodes at" % n, time.asctime()
inputs = defaultdict(list)
for node in self.input_nodes:
player = self.instructions[node].args[0]
inputs[player].append(node)
first_inputs = [l[0] for l in inputs.itervalues()]
other_inputs = []
i = 0
while True:
i += 1
found = False
for l in inputs.itervalues():
if i < len(l):
other_inputs.append(l[i])
found = True
if not found:
break
other_inputs.reverse()
preorder = []
# magical preorder for topological search
max_depth = max(merges)
if max_depth > 10000:
print "Computing pre-ordering ..."
for i in xrange(max_depth, 0, -1):
preorder.append(G.get_attr(merges[i], 'stop'))
for j in flex_nodes[i-1].itervalues():
preorder.extend(j)
preorder.extend(flex_nodes[0].get(i, []))
preorder.append(merges[i])
if i % 100000 == 0 and i > 0:
print "Done level %d at" % i, time.asctime()
preorder.extend(other_inputs)
preorder.extend(starters)
preorder.extend(first_inputs)
if max_depth > 10000:
print "Done at", time.asctime()
return preorder
def compute_continuous_preorder(self, merges, rev_depth_of):
print 'Computing pre-ordering for continuous computation...'
preorder = []
sources_for = defaultdict(list)
stops_in = defaultdict(list)
startinputs = []
stopinputs = []
for source in self.sources:
sources_for[rev_depth_of[source]].append(source)
for merge in merges.itervalues():
stop = self.G.get_attr(merge, 'stop')
stops_in[rev_depth_of[stop]].append(stop)
for node in self.input_nodes:
if isinstance(self.instructions[node], startinput_class):
startinputs.append(node)
else:
stopinputs.append(node)
max_round = max(rev_depth_of)
for i in xrange(max_round, 0, -1):
preorder.extend(reversed(stops_in[i]))
preorder.extend(reversed(sources_for[i]))
# inputs at the beginning
preorder.extend(reversed(stopinputs))
preorder.extend(reversed(sources_for[0]))
preorder.extend(reversed(startinputs))
return preorder
def longest_paths_merge(self, instruction_type=startopen_class,
merge_stopopens=True):
""" Attempt to merge instructions of type instruction_type (which are given in
merge_nodes) using longest paths algorithm.
Returns the no. of rounds of communication required after merging (assuming 1 round/instruction).
If merge_stopopens is True, will also merge associated stop_open instructions.
If reorder_between_opens is True, will attempt to place non-opens between start/stop opens.
Doesn't use networkx.
"""
G = self.G
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
if instruction_type is not startopen_class and merge_stopopens:
raise CompilerError('Cannot merge stopopens whilst merging %s instructions' % instruction_type)
if not merge_nodes and not self.input_nodes:
return 0
# merge opens at same depth
merges = defaultdict(list)
for node in merge_nodes:
merges[depths[node]].append(node)
# after merging, the first element in merges[i] remains for each depth i,
# all others are removed from instructions and G
last_nodes = [None, None]
for i in sorted(merges):
merge = merges[i]
if len(merge) > 1000:
print 'Merging %d opens in round %d/%d' % (len(merge), i, len(merges))
nodes = defaultdict(lambda: None)
for b in (False, True):
my_merge = (m for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b)
if merge_stopopens:
my_stopopen = [G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b]
mc, nodes[0,b] = self.do_merge(iter(my_merge))
if merge_stopopens:
mc, nodes[1,b] = self.do_merge(iter(my_stopopen))
# add edges to retain order of gf2n/modp start/stop opens
for j in (0,1):
node2 = nodes[j,True]
nodep = nodes[j,False]
if nodep is not None and node2 is not None:
G.add_edge(nodep, node2)
# add edge to retain order of opens over rounds
if last_nodes[j] is not None:
G.add_edge(last_nodes[j], node2 if nodep is None else nodep)
last_nodes[j] = nodep if node2 is None else node2
merges[i] = last_nodes[0]
self.merge_inputs()
# compute preorder for topological sort
if merge_stopopens and self.options.reorder_between_opens:
if self.options.continuous or not merge_nodes:
rev_depths = self.compute_max_depths(self.real_depths)
preorder = self.compute_continuous_preorder(merges, rev_depths)
else:
rev_depths = self.compute_max_depths(self.depths)
preorder = self.compute_preorder(merges, rev_depths)
else:
preorder = None
if len(instructions) > 100000:
print "Topological sort ..."
order = Compiler.graph.topological_sort(G, preorder)
instructions[:] = [instructions[i] for i in order if instructions[i] is not None]
if len(instructions) > 100000:
print "Done at", time.asctime()
return len(merges)
def dependency_graph(self, merge_class=startopen_class):
""" Create the program dependency graph. """
block = self.block
options = self.options
open_nodes = set()
self.open_nodes = open_nodes
self.input_nodes = []
colordict = defaultdict(lambda: 'gray', startopen='red', stopopen='red',\
ldi='lightblue', ldm='lightblue', stm='blue',\
mov='yellow', mulm='orange', mulc='orange',\
triple='green', square='green', bit='green',\
asm_input='lightgreen')
G = Compiler.graph.SparseDiGraph(len(block.instructions))
self.G = G
reg_nodes = {}
last_def = defaultdict(lambda: -1)
last_mem_write = None
last_mem_read = None
warned_about_mem = []
last_mem_write_of = defaultdict(list)
last_mem_read_of = defaultdict(list)
last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
depths = [0] * len(block.instructions)
self.depths = depths
parallel_open = defaultdict(lambda: 0)
next_available_depth = {}
self.sources = []
self.real_depths = [0] * len(block.instructions)
def add_edge(i, j):
from_merge = isinstance(block.instructions[i], merge_class)
to_merge = isinstance(block.instructions[j], merge_class)
G.add_edge(i, j)
is_source = G.get_attr(i, 'is_source') and G.get_attr(j, 'is_source') and not from_merge
G.set_attr(j, 'is_source', is_source)
for d in (self.depths, self.real_depths):
if d[j] < d[i]:
d[j] = d[i]
def read(reg, n):
if last_def[reg] != -1:
add_edge(last_def[reg], n)
def write(reg, n):
last_def[reg] = n
def handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind):
this = last_access_this_kind[addr,reg_type]
other = last_access_other_kind[addr,reg_type]
if this and other:
if this[-1] < other[0]:
del this[:]
this.append(n)
for inst in other:
add_edge(inst, n)
def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
addr = instr.args[1]
reg_type = instr.args[0].reg_type
if isinstance(addr, int):
for i in range(min(instr.get_size(), 100)):
addr_i = addr + i
handle_mem_access(addr_i, reg_type, last_access_this_kind,
last_access_other_kind)
if not warned_about_mem and (instr.get_size() > 100):
print 'WARNING: Order of memory instructions ' \
'not preserved due to long vector, errors possible'
warned_about_mem.append(True)
else:
handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind)
if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction):
print 'WARNING: Order of memory instructions ' \
'not preserved, errors possible'
# hack
warned_about_mem.append(True)
def keep_order(instr, n, t, arg_index=None):
if arg_index is None:
player = None
else:
player = instr.args[arg_index]
if last[t][player] is not None:
add_edge(last[t][player], n)
last[t][player] = n
for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used()
G.add_node(n, is_source=True)
# if options.debug:
# col = colordict[instr.__class__.__name__]
# G.add_node(n, color=col, label=str(instr))
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
for reg in outputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
write(i, n)
else:
write(reg, n)
if isinstance(instr, merge_class):
open_nodes.add(n)
last_open.append(n)
G.add_node(n, merges=[])
# the following must happen after adding the edge
self.real_depths[n] += 1
depth = depths[n] + 1
if int(options.max_parallel_open):
skipped_depths = set()
while parallel_open[depth] >= int(options.max_parallel_open):
skipped_depths.add(depth)
depth = next_available_depth.get(depth, depth + 1)
for d in skipped_depths:
next_available_depth[d] = depth
parallel_open[depth] += len(instr.args) * instr.get_size()
depths[n] = depth
if isinstance(instr, stopopen_class):
startopen = last_open.popleft()
add_edge(startopen, n)
G.set_attr(startopen, 'stop', n)
G.set_attr(n, 'start', last_open)
G.add_node(n, merges=[])
if isinstance(instr, ReadMemoryInstruction):
if options.preserve_mem_order:
last_mem_read = n
if last_mem_write:
add_edge(last_mem_write, n)
else:
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
elif isinstance(instr, WriteMemoryInstruction):
if options.preserve_mem_order:
last_mem_write = n
if last_mem_read:
add_edge(last_mem_read, n)
else:
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
# keep I/O instructions in order
elif isinstance(instr, IOInstruction):
if last_print_str is not None:
add_edge(last_print_str, n)
last_print_str = n
elif isinstance(instr, PublicFileIOInstruction):
keep_order(instr, n, instr.__class__)
elif isinstance(instr, RawInputInstruction):
keep_order(instr, n, instr.__class__, 0)
self.input_nodes.append(n)
G.add_node(n, merges=[])
player = instr.args[0]
if isinstance(instr, stopinput):
add_edge(last[startinput_class][player], n)
elif isinstance(instr, gstopinput):
add_edge(last[gstartinput][player], n)
elif isinstance(instr, startprivateoutput_class):
keep_order(instr, n, startprivateoutput_class, 2)
elif isinstance(instr, stopprivateoutput_class):
keep_order(instr, n, stopprivateoutput_class, 1)
elif isinstance(instr, prep_class):
keep_order(instr, n, instr.args[0])
if not G.pred[n]:
self.sources.append(n)
if n % 100000 == 0 and n > 0:
print "Processed dependency of %d/%d instructions at" % \
(n, len(block.instructions)), time.asctime()
if len(open_nodes) > 1000:
print "Program has %d %s instructions" % (len(open_nodes), merge_class)
def merge_nodes(self, i, j):
""" Merge node j into i, removing node j """
G = self.G
if j in G[i]:
G.remove_edge(i, j)
if i in G[j]:
G.remove_edge(j, i)
G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))
G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))
G.get_attr(i, 'merges').append(j)
G.remove_node(j)
def eliminate_dead_code(self):
instructions = self.instructions
G = self.G
merge_nodes = self.open_nodes
count = 0
open_count = 0
for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)):
# remove if instruction has result that isn't used
unused_result = not G.degree(i) and len(inst.get_def()) \
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
and not isinstance(inst, (DoNotEliminateInstruction))
stop_node = G.get_attr(i, 'stop')
unused_startopen = stop_node != -1 and instructions[stop_node] is None
if unused_result or unused_startopen:
G.remove_node(i)
merge_nodes.discard(i)
instructions[i] = None
count += 1
if unused_startopen:
open_count += len(inst.args)
if count > 0:
print 'Eliminated %d dead instructions, among which %d opens' % (count, open_count)
def print_graph(self, filename):
f = open(filename, 'w')
print >>f, 'digraph G {'
for i in range(self.G.n):
for j in self.G[i]:
print >>f, '"%d: %s" -> "%d: %s";' % \
(i, self.instructions[i], j, self.instructions[j])
print >>f, '}'
f.close()
def print_depth(self, filename):
f = open(filename, 'w')
for i in range(self.G.n):
print >>f, '%d: %s' % (self.depths[i], self.instructions[i])
f.close()

548
Compiler/comparison.py Normal file
View File

@@ -0,0 +1,548 @@
# (C) 2016 University of Bristol. See License.txt
"""
Functions for secure comparison of GF(p) types.
Most protocols come from [1], with a few subroutines described in [2].
Function naming of comparison routines is as in [1,2], with k always
representing the integer bit length, and kappa the statistical security
parameter.
Most of these routines were implemented before the cint/sint classes, so use
the old-fasioned Register class and assembly instructions instead of operator
overloading.
The PreMulC function has a few variants, depending on whether
preprocessing is only triples/bits, or inverse tuples or "special"
comparison-specific preprocessing is also available.
[1] https://www1.cs.fau.de/filepool/publications/octavian_securescm/smcint-scn10.pdf
[2] https://www1.cs.fau.de/filepool/publications/octavian_securescm/SecureSCM-D.9.2.pdf
"""
# Use constant rounds protocols instead of log rounds
const_rounds = False
# Set use_inv to use preprocessed inverse tuples for more efficient
# online phase comparisons.
use_inv = True
# If do_precomp is not set, use_inv uses standard inverse tuples, otherwise if
# both are set, use a list of "special" tuples of the form
# (r[i], r[i]^-1, r[i] * r[i-1]^-1)
do_precomp = True
import instructions_base
def set_variant(options):
""" Set flags based on the command-line option provided """
global const_rounds, do_precomp, use_inv
variant = options.comparison
if variant == 'log':
const_rounds = False
elif variant == 'plain':
const_rounds = True
use_inv = False
elif variant == 'inv':
const_rounds = True
use_inv = True
do_precomp = True
elif variant == 'sinv':
const_rounds = True
use_inv = True
do_precomp = False
elif variant is not None:
raise CompilerError('Unknown comparison variant: %s' % variant)
def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """
t1 = program.curr_block.new_reg('c')
ldi(t1, 2 ** (n % 30))
for i in range(n / 30):
t2 = program.curr_block.new_reg('c')
mulci(t2, t1, 2 ** 30)
t1 = t2
movc(c, t1)
inverse_of_two = {}
def divide_by_two(res, x):
""" Faster clear division by two using a cached value of 2^-1 mod p """
from program import Program
import types
tape = Program.prog.curr_tape
if len(inverse_of_two) == 0 or tape not in inverse_of_two:
inverse_of_two[tape] = types.cint(1) / 2
mulc(res, x, inverse_of_two[tape])
def LTZ(s, a, k, kappa):
"""
s = (a ?< 0)
k: bit length of a
"""
t = program.curr_block.new_reg('s')
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
def Trunc(d, a, k, m, kappa, signed):
"""
d = a >> m
k: bit length of a
m: compile-time integer
signed: True/False, describes a
"""
a_prime = program.curr_block.new_reg('s')
t = program.curr_block.new_reg('s')
c = [program.curr_block.new_reg('c') for i in range(3)]
c2m = program.curr_block.new_reg('c')
if m == 0:
movs(d, a)
return
elif m == 1:
Mod2(a_prime, a, k, kappa, signed)
else:
Mod2m(a_prime, a, k, m, kappa, signed)
subs(t, a, a_prime)
ldi(c[1], 1)
ld2i(c2m, m)
divc(c[2], c[1], c2m)
mulm(d, t, c[2])
def TruncRoundNearest(a, k, m, kappa):
"""
Returns a / 2^m, rounded to the nearest integer.
k: bit length of m
m: compile-time integer
"""
from types import sint, cint
from library import reveal, load_int_to_secret
if m == 1:
lsb = sint()
Mod2(lsb, a, k, kappa, False)
return (a + lsb) / 2
r_dprime = sint()
r_prime = sint()
r = [sint() for i in range(m)]
u = sint()
PRandM(r_dprime, r_prime, r, k, m, kappa)
c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
c_prime = c % (cint(1) << (m - 1))
if const_rounds:
BitLTC1(u, c_prime, r[:-1], kappa)
else:
BitLTL(u, c_prime, r[:-1], kappa)
bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2
xor = bit + u - 2 * bit * u
prod = xor * r[-1]
# u_prime = xor * u + (1 - xor) * r[-1]
u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
d = (a - a_prime) / (cint(1) << m)
rounding = xor + r[-1] - 2 * prod
return d + rounding
def Mod2m(a_prime, a, k, m, kappa, signed):
"""
a_prime = a % 2^m
k: bit length of a
m: compile-time integer
signed: True/False, describes a
"""
if m >= k:
movs(a_prime, a)
return
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
r = [program.curr_block.new_reg('s') for i in range(m)]
c = program.curr_block.new_reg('c')
c_prime = program.curr_block.new_reg('c')
v = program.curr_block.new_reg('s')
u = program.curr_block.new_reg('s')
t = [program.curr_block.new_reg('s') for i in range(6)]
c2m = program.curr_block.new_reg('c')
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, r, k, m, kappa)
ld2i(c2m, m)
mulm(t[0], r_dprime, c2m)
if signed:
ld2i(c2k1, k - 1)
addm(t[1], a, c2k1)
else:
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
startopen(t[3])
stopopen(c)
modc(c_prime, c, c2m)
if const_rounds:
BitLTC1(u, c_prime, r, kappa)
else:
BitLTL(u, c_prime, r, kappa)
mulm(t[4], u, c2m)
submr(t[5], c_prime, r_prime)
adds(a_prime, t[5], t[4])
return r_dprime, r_prime, c, c_prime, u, t, c2k1
def PRandM(r_dprime, r_prime, b, k, m, kappa):
"""
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
r_prime = random secret integer in range [0, 2^m - 1]
b = array containing bits of r_prime
"""
t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)]
t[0][1] = b[-1]
PRandInt(r_dprime, k + kappa - m)
# r_dprime is always multiplied by 2^m
program.curr_tape.require_bit_length(k + kappa)
bit(b[-1])
for i in range(1,m):
adds(t[i][0], t[i-1][1], t[i-1][1])
bit(b[-i-1])
adds(t[i][1], t[i][0], b[-i-1])
movs(r_prime, t[m-1][1])
def PRandInt(r, k):
"""
r = random secret integer in range [0, 2^k - 1]
"""
t = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(3)]
t[2][k-1] = r
bit(t[2][0])
for i in range(1,k):
adds(t[0][i], t[2][i-1], t[2][i-1])
bit(t[1][i])
adds(t[2][i], t[0][i], t[1][i])
def BitLTC1(u, a, b, kappa):
"""
u = a <? b
a: array of clear bits
b: array of secret bits (same length as a)
"""
k = len(b)
p = [program.curr_block.new_reg('s') for i in range(k)]
if instructions_base.get_global_vector_size() == 1:
b_vec = program.curr_block.new_reg('s', size=k)
for i in range(k):
movs(b_vec[i], b[i])
a_bits = program.curr_block.new_reg('c', size=k)
d = program.curr_block.new_reg('s', size=k)
s = program.curr_block.new_reg('s', size=k)
t = [program.curr_block.new_reg('s', size=k) for j in range(5)]
c = [program.curr_block.new_reg('c', size=k) for j in range(4)]
else:
a_bits = [program.curr_block.new_reg('c') for i in range(k)]
d = [program.curr_block.new_reg('s') for i in range(k)]
s = [program.curr_block.new_reg('s') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(5)]
c = [[program.curr_block.new_reg('c') for i in range(k)] for j in range(4)]
c[1] = [program.curr_block.new_reg('c') for i in range(k)]
# warning: computer scientists count from 0
modci(a_bits[0], a, 2)
c[1][0] = a
for i in range(1,k):
subc(c[0][i], c[1][i-1], a_bits[i-1])
divide_by_two(c[1][i], c[0][i])
modci(a_bits[i], c[1][i], 2)
if instructions_base.get_global_vector_size() == 1:
vmulci(k, c[2], a_bits, 2)
vmulm(k, t[0], b_vec, c[2])
vaddm(k, t[1], b_vec, a_bits)
vsubs(k, d, t[1], t[0])
vaddsi(k, t[2], d, 1)
t[2].create_vector_elements()
pre_input = t[2].vector[:]
else:
for i in range(k):
mulci(c[2][i], a_bits[i], 2)
mulm(t[0][i], b[i], c[2][i])
addm(t[1][i], b[i], a_bits[i])
subs(d[i], t[1][i], t[0][i])
addsi(t[2][i], d[i], 1)
pre_input = t[2][:]
pre_input.reverse()
if use_inv:
if instructions_base.get_global_vector_size() == 1:
PreMulC_with_inverses_and_vectors(p, pre_input)
else:
if do_precomp:
PreMulC_with_inverses(p, pre_input)
else:
raise NotImplementedError('Vectors not compatible with -c sinv')
else:
PreMulC_without_inverses(p, pre_input)
p.reverse()
for i in range(k-1):
subs(s[i], p[i], p[i+1])
subsi(s[k-1], p[k-1], 1)
subcfi(c[3][0], a_bits[0], 1)
mulm(t[4][0], s[0], c[3][0])
for i in range(1,k):
subcfi(c[3][i], a_bits[i], 1)
mulm(t[3][i], s[i], c[3][i])
adds(t[4][i], t[4][i-1], t[3][i])
Mod2(u, t[4][k-1], k, kappa, False)
return p, a_bits, d, s, t, c, b, pre_input
def carry(b, a, compute_p):
""" Carry propogation:
return (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
"""
if a is None:
return b
if b is None:
return a
t = [program.curr_block.new_reg('s') for i in range(3)]
if compute_p:
muls(t[0], a[0], b[0])
muls(t[1], a[0], b[1])
adds(t[2], a[1], t[1])
return t[0], t[2]
# from WP9 report
# length of a is even
def CarryOutAux(d, a, kappa):
k = len(a)
if k > 1 and k % 2 == 1:
a.append(None)
k += 1
u = [None]*(k/2)
a = a[::-1]
if k > 1:
for i in range(k/2):
u[i] = carry(a[2*i+1], a[2*i], i != k/2-1)
CarryOutAux(d, u[:k/2][::-1], kappa)
else:
movs(d, a[0][1])
# carry out with carry-in bit c
def CarryOut(res, a, b, c, kappa):
"""
res = last carry bit in addition of a and b
a: array of clear bits
b: array of secret bits (same length as a)
c: initial carry-in bit
"""
k = len(a)
d = [program.curr_block.new_reg('s') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
s = [program.curr_block.new_reg('s') for i in range(3)]
for i in range(k):
mulm(t[0][i], b[i], a[i])
mulsi(t[1][i], t[0][i], 2)
addm(t[2][i], b[i], a[i])
subs(t[3][i], t[2][i], t[1][i])
d[i] = [t[3][i], t[0][i]]
mulsi(s[0], d[-1][0], c)
adds(s[1], d[-1][1], s[0])
d[-1][1] = s[1]
CarryOutAux(res, d[::-1], kappa)
def BitLTL(res, a, b, kappa):
"""
res = a <? b (logarithmic rounds version)
a: clear integer register
b: array of secret bits (same length as a)
"""
k = len(b)
a_bits = [program.curr_block.new_reg('c') for i in range(k)]
c = [[program.curr_block.new_reg('c') for i in range(k)] for j in range(2)]
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
t = [program.curr_block.new_reg('s') for i in range(1)]
modci(a_bits[0], a, 2)
c[1][0] = a
for i in range(1,k):
subc(c[0][i], c[1][i-1], a_bits[i-1])
divide_by_two(c[1][i], c[0][i])
modci(a_bits[i], c[1][i], 2)
for i in range(len(b)):
subsfi(s[0][i], b[i], 1)
CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa)
subsfi(res, t[0], 1)
return a_bits, s[0]
def PreMulC_with_inverses_and_vectors(p, a):
"""
p[i] = prod_{j=0}^{i-1} a[i]
Variant for vector registers using preprocessed inverses.
"""
k = len(p)
a_vec = program.curr_block.new_reg('s', size=k)
r = program.curr_block.new_reg('s', size=k)
w = program.curr_block.new_reg('s', size=k)
w_tmp = program.curr_block.new_reg('s', size=k)
z = program.curr_block.new_reg('s', size=k)
m = program.curr_block.new_reg('c', size=k)
t = [program.curr_block.new_reg('s', size=k) for i in range(1)]
c = [program.curr_block.new_reg('c') for i in range(k)]
# warning: computer scientists count from 0
if do_precomp:
vinverse(k, r, z)
else:
vprep(k, 'PreMulC', r, z, w_tmp)
for i in range(1,k):
if do_precomp:
muls(w[i], r[i], z[i-1])
else:
movs(w[i], w_tmp[i])
movs(a_vec[i], a[i])
movs(w[0], r[0])
movs(a_vec[0], a[0])
vmuls(k, t[0], w, a_vec)
vstartopen(k, t[0])
vstopopen(k, m)
PreMulC_end(p, a, c, m, z)
def PreMulC_with_inverses(p, a):
"""
Variant using preprocessed inverses or special inverses.
The latter are triples of the form (a_i, a_i^{-1}, a_i * a_{i-1}^{-1}).
See also make_PreMulC() in Fake-Offline.cpp.
"""
k = len(a)
r = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(3)]
w = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
z = [program.curr_block.new_reg('s') for i in range(k)]
m = [program.curr_block.new_reg('c') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(1)]
c = [program.curr_block.new_reg('c') for i in range(k)]
# warning: computer scientists count from 0
for i in range(k):
if do_precomp:
inverse(r[0][i], z[i])
else:
prep('PreMulC', r[0][i], z[i], w[1][i])
if do_precomp:
for i in range(1,k):
muls(w[1][i], r[0][i], z[i-1])
w[1][0] = r[0][0]
for i in range(k):
muls(t[0][i], w[1][i], a[i])
startopen(t[0][i])
stopopen(m[i])
PreMulC_end(p, a, c, m, z)
def PreMulC_without_inverses(p, a):
"""
Plain variant with no extra preprocessing.
"""
k = len(a)
r = [program.curr_block.new_reg('s') for i in range(k)]
s = [program.curr_block.new_reg('s') for i in range(k)]
u = [program.curr_block.new_reg('c') for i in range(k)]
v = [program.curr_block.new_reg('s') for i in range(k)]
w = [program.curr_block.new_reg('s') for i in range(k)]
z = [program.curr_block.new_reg('s') for i in range(k)]
m = [program.curr_block.new_reg('c') for i in range(k)]
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(2)]
#tt = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
u_inv = [program.curr_block.new_reg('c') for i in range(k)]
c = [program.curr_block.new_reg('c') for i in range(k)]
# warning: computer scientists count from 0
for i in range(k):
triple(s[i], r[i], t[0][i])
#adds(tt[0][i], t[0][i], a[i])
#subs(tt[1][i], tt[0][i], a[i])
#startopen(tt[1][i])
startopen(t[0][i])
stopopen(u[i])
for i in range(k-1):
muls(v[i], r[i+1], s[i])
w[0] = r[0]
one = program.curr_block.new_reg('c')
ldi(one, 1)
for i in range(k):
divc(u_inv[i], one, u[i])
# avoid division by zero, just for benchmarking
#divc(u_inv[i], u[i], one)
for i in range(1,k):
mulm(w[i], v[i-1], u_inv[i-1])
for i in range(1,k):
mulm(z[i], s[i], u_inv[i])
for i in range(k):
muls(t[1][i], w[i], a[i])
startopen(t[1][i])
stopopen(m[i])
PreMulC_end(p, a, c, m, z)
def PreMulC_end(p, a, c, m, z):
"""
Helper function for all PreMulC variants. Local operation.
"""
k = len(a)
c[0] = m[0]
for j in range(1,k):
mulc(c[j], c[j-1], m[j])
if isinstance(p, list):
mulm(p[j], z[j], c[j])
if isinstance(p, list):
p[0] = a[0]
else:
mulm(p, z[-1], c[-1])
def PreMulC(a):
p = [type(a[0])() for i in range(len(a))]
instructions_base.set_global_instruction_type(a[0].instruction_type)
if use_inv:
PreMulC_with_inverses(p, a)
else:
PreMulC_without_inverses(p, a)
instructions_base.reset_global_instruction_type()
return p
def KMulC(a):
"""
Return just the product of all items in a
"""
from types import sint, cint
p = sint()
if use_inv:
PreMulC_with_inverses(p, a)
else:
PreMulC_without_inverses(p, a)
return p
def Mod2(a_0, a, k, kappa, signed):
"""
a_0 = a % 2
k: bit length of a
"""
if k <= 1:
movs(a_0, a)
return
r_dprime = program.curr_block.new_reg('s')
r_prime = program.curr_block.new_reg('s')
r_0 = program.curr_block.new_reg('s')
c = program.curr_block.new_reg('c')
c_0 = program.curr_block.new_reg('c')
tc = program.curr_block.new_reg('c')
t = [program.curr_block.new_reg('s') for i in range(6)]
c2k1 = program.curr_block.new_reg('c')
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
mulsi(t[0], r_dprime, 2)
if signed:
ld2i(c2k1, k - 1)
addm(t[1], a, c2k1)
else:
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
startopen(t[3])
stopopen(c)
modci(c_0, c, 2)
mulci(tc, c_0, 2)
mulm(t[4], r_0, tc)
addm(t[5], r_0, c_0)
subs(a_0, t[5], t[4])
# hack for circular dependency
from instructions import *

79
Compiler/compilerLib.py Normal file
View File

@@ -0,0 +1,79 @@
# (C) 2016 University of Bristol. See License.txt
from Compiler.program import Program
from Compiler.config import *
from Compiler.exceptions import *
import instructions, instructions_base, types, comparison, library
import random
import time
import sys
def run(filename, options, param=-1, merge_opens=True, emulate=True, \
reallocate=True, assemblymode=False, debug=False):
""" Compile a file and output a Program object.
If merge_opens is set to True, will attempt to merge any parallelisable open
instructions. """
prog = Program(filename, options, param, assemblymode)
instructions.program = prog
instructions_base.program = prog
types.program = prog
comparison.program = prog
prog.EMULATE = emulate
prog.DEBUG = debug
VARS['program'] = prog
comparison.set_variant(options)
print 'Compiling file', prog.infile
# no longer needed, but may want to support assembly in future (?)
if assemblymode:
prog.restart_main_thread()
for i in xrange(INIT_REG_MAX):
VARS['c%d'%i] = prog.curr_block.new_reg('c')
VARS['s%d'%i] = prog.curr_block.new_reg('s')
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
if i % 10000000 == 0 and i > 0:
print "Initialized %d register variables at" % i, time.asctime()
# first pass determines how many assembler registers are used
prog.FIRST_PASS = True
execfile(prog.infile, VARS)
if instructions_base.Instruction.count != 0:
print 'instructions count', instructions_base.Instruction.count
instructions_base.Instruction.count = 0
prog.FIRST_PASS = False
prog.reset_values()
# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
execfile(prog.infile, VARS)
# optimize the tapes
for tape in prog.tapes:
tape.optimize(options)
# check program still does the same thing after optimizations
if emulate:
clearmem = list(prog.mem_c)
sharedmem = list(prog.mem_s)
prog.emulate()
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
print 'Warning: emulated memory values changed after compiler optimization'
# raise CompilerError('Compiler optimization caused incorrect memory write.')
if prog.main_thread_running:
prog.update_req(prog.curr_tape)
print 'Program requires:', repr(prog.req_num)
print 'Cost:', prog.req_num.cost()
print 'Memory size:', prog.allocated_mem
# finalize the memory
prog.finalize_memory()
return prog

58
Compiler/config.py Normal file
View File

@@ -0,0 +1,58 @@
# (C) 2016 University of Bristol. See License.txt
from collections import defaultdict
#INIT_REG_MAX = 655360
INIT_REG_MAX = 1310720
REG_MAX = 2 ** 32
USER_MEM = 8192
TMP_MEM = 8192
TMP_MEM_BASE = USER_MEM
TMP_REG = 3
TMP_REG_BASE = REG_MAX - TMP_REG
P_VALUES = { -1: 2147483713, \
32: 2147565569, \
64: 9223372036855103489, \
128: 172035116406933162231178957667602464769, \
256: 57896044624266469032429686755131815517604980759976795324963608525438406557697, \
512: 6703903964971298549787012499123814115273848577471136527425966013026501536706464354255445443244279389455058889493431223951165286470575994074291745908195329 }
BIT_LENGTHS = { -1: 24,
32: 24,
64: 32,
128: 64,
256: 64,
512: 64 }
STAT_SEC = { -1: 6,
32: 6,
64: 30,
128: 40,
256: 40,
512: 40 }
COST = { 'modp': defaultdict(lambda: 0,
{ 'triple': 0.00020652622883106154,
'square': 0.00020652622883106154,
'bit': 0.00020652622883106154,
'inverse': 0.00020652622883106154,
'PreMulC': 2 * 0.00020652622883106154,
}),
'gf2n': defaultdict(lambda: 0,
{ 'triple': 0.00020716801325875284,
'square': 0.00020716801325875284,
'inverse': 0.00020716801325875284,
'bit': 1.4492753623188405e-07,
'bittriple': 0.00004828818388140422,
'bitgf2ntriple': 0.00020716801325875284,
'PreMulC': 2 * 0.00020716801325875284,
})
}
try:
from config_mine import *
except:
pass

17
Compiler/exceptions.py Normal file
View File

@@ -0,0 +1,17 @@
# (C) 2016 University of Bristol. See License.txt
class CompilerError(Exception):
"""Base class for compiler exceptions."""
pass
class RegisterOverflowError(CompilerError):
pass
class MemoryOverflowError(CompilerError):
pass
class ArgumentError(CompilerError):
""" Exception raised for errors in instruction argument parsing. """
def __init__(self, arg, msg):
self.arg = arg
self.msg = msg

517
Compiler/floatingpoint.py Normal file
View File

@@ -0,0 +1,517 @@
# (C) 2016 University of Bristol. See License.txt
from math import log, floor, ceil
from Compiler.instructions import *
import types
import comparison
import program
##
## Helper functions for floating point arithmetic
##
def two_power(n):
if isinstance(n, int) and n < 31:
return 2**n
else:
max = types.cint(1) << 31
res = 2**(n%31)
for i in range(n / 31):
res *= max
return res
def EQZ(a, k, kappa):
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
d = [None]*k
r = [types.sint() for i in range(k)]
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
startopen(a + two_power(k) * r_dprime + r_prime)# + 2**(k-1))
stopopen(c)
for i,b in enumerate(bits(c, k)):
d[i] = b + r[i] - 2*b*r[i]
return 1 - KOR(d, kappa)
def bits(a,m):
""" Get the bits of an int """
if isinstance(a, int):
res = [None]*m
for i in range(m):
res[i] = a & 1
a >>= 1
else:
c = [[types.cint() for i in range(m)] for i in range(2)]
res = [types.cint() for i in range(m)]
modci(res[0], a, 2)
c[1][0] = a
for i in range(1,m):
subc(c[0][i], c[1][i-1], res[i-1])
divci(c[1][i], c[0][i], 2)
modci(res[i], c[1][i], 2)
return res
def carry(b, a, compute_p=True):
""" Carry propogation:
(p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
"""
if compute_p:
t1 = a[0]*b[0]
else:
t1 = None
t2 = a[1] + a[0]*b[1]
return (t1, t2)
def or_op(a, b, void=None):
return a + b - a*b
def mul_op(a, b, void=None):
return a * b
def PreORC(a, kappa=None, m=None, raw=False):
k = len(a)
if k == 1:
return [a[0]]
m = m or k
if isinstance(a[0], types.sgf2n):
max_k = program.Program.prog.galois_length - 1
else:
max_k = int(log(program.Program.prog.P) / log(2)) - kappa
if k <= max_k:
p = [None] * m
if m == k:
p[0] = a[0]
if isinstance(a[0], types.sgf2n):
b = comparison.PreMulC([3 - a[i] for i in range(k)])
for i in range(m):
tmp = b[k-1-i]
if not raw:
tmp = tmp.bit_decompose()[0]
p[m-1-i] = 1 - tmp
else:
t = [types.sint() for i in range(m)]
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
for i in range(m):
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
p[m-1-i] = 1 - t[i]
return p
else:
# not constant-round anymore
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])
def PreOpL(op, items):
"""
Uses algorithm from SecureSCM WP9 deliverable.
op must be a binary function that outputs a new register
"""
k = len(items)
logk = int(ceil(log(k,2)))
kmax = 2**logk
output = list(items)
for i in range(logk):
for j in range(kmax/(2**(i+1))):
y = two_power(i) + j*two_power(i+1) - 1
for z in range(1, 2**i+1):
if y+z < k:
output[y+z] = op(output[y], output[y+z], j != 0)
return output
def PreOpN(op, items):
""" Naive PreOp algorithm """
k = len(items)
output = [None]*k
output[0] = items[0]
for i in range(1, k):
output[i] = op(output[i-1], items[i])
return output
def PreOR(a, kappa=None, raw=False):
if comparison.const_rounds:
return PreORC(a, kappa, raw=raw)
else:
return PreOpL(or_op, a)
def KOpL(op, a):
k = len(a)
if k == 1:
return a[0]
else:
t1 = KOpL(op, a[:k/2])
t2 = KOpL(op, a[k/2:])
return op(t1, t2)
def KORL(a, kappa):
""" log rounds k-ary OR """
k = len(a)
if k == 1:
return a[0]
else:
t1 = KORL(a[:k/2], kappa)
t2 = KORL(a[k/2:], kappa)
return t1 + t2 - t1*t2
def KORC(a, kappa):
return PreORC(a, kappa, 1)[0]
def KOR(a, kappa):
if comparison.const_rounds:
return KORC(a, kappa)
else:
return KORL(a, None)
def KMul(a):
if comparison.const_rounds:
return comparison.KMulC(a)
else:
return KOpL(mul_op, a)
def Inv(a):
""" Invert a non-zero value """
t = [types.sint() for i in range(3)]
c = [types.cint() for i in range(2)]
one = types.cint()
ldi(one, 1)
inverse(t[0], t[1])
s = t[0]*a
asm_open(c[0], s)
# avoid division by zero for benchmarking
divc(c[1], one, c[0])
#divc(c[1], c[0], one)
return c[1]*t[0]
def BitAdd(a, b, bits_to_compute=None):
""" Add the bits a[k-1], ..., a[0] and b[k-1], ..., b[0], return k+1
bits s[0], ... , s[k] """
k = len(a)
if not bits_to_compute:
bits_to_compute = range(k)
d = [None] * k
for i in range(1,k):
#assert(a[i].value == 0 or a[i].value == 1)
#assert(b[i].value == 0 or b[i].value == 1)
t = a[i]*b[i]
d[i] = (a[i] + b[i] - 2*t, t)
#assert(d[i][0].value == 0 or d[i][0].value == 1)
d[0] = (None, a[0]*b[0])
pg = PreOpL(carry, d)
c = [pair[1] for pair in pg]
# (for testing)
def print_state():
print 'a: ',
for i in range(k):
print '%d ' % a[i].value,
print '\nb: ',
for i in range(k):
print '%d ' % b[i].value,
print '\nd: ',
for i in range(k):
print '%d ' % d[i][0].value,
print '\n ',
for i in range(k):
print '%d ' % d[i][1].value,
print '\n\npg:',
for i in range(k):
print '%d ' % pg[i][0].value,
print '\n ',
for i in range(k):
print '%d ' % pg[i][1].value,
print ''
for bit in c:
pass#assert(bit.value == 0 or bit.value == 1)
s = [None] * (k+1)
if 0 in bits_to_compute:
s[0] = a[0] + b[0] - 2*c[0]
bits_to_compute.remove(0)
#assert(c[0].value == a[0].value*b[0].value)
#assert(s[0].value == 0 or s[0].value == 1)
for i in bits_to_compute:
s[i] = a[i] + b[i] + c[i-1] - 2*c[i]
try:
pass#assert(s[i].value == 0 or s[i].value == 1)
except AssertionError:
print '#assertion failed in BitAdd for s[%d]' % i
print_state()
s[k] = c[k-1]
#print_state()
return s
def BitDec(a, k, m, kappa, bits_to_compute=None):
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
r = [types.sint() for i in range(m)]
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
#assert(r_prime.value == sum(r[i].value*2**i for i in range(m)) % comparison.program.P)
pow2 = two_power(k + kappa)
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
#rval = 2**m*r_dprime.value + r_prime.value
#assert(rval % 2**m == r_prime.value)
#assert(rval == (2**m*r_dprime.value + sum(r[i].value*2**i for i in range(m)) % comparison.program.P ))
try:
pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P)
except AssertionError:
print 'BitDec assertion failed'
print 'a =', a.value
print 'a mod 2^%d =' % k, (a.value % 2**k)
return BitAdd(list(bits(c,m)), r, bits_to_compute)[:-1]
def Pow2(a, l, kappa):
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
x = [types.sint() for i in range(m)]
pow2k = [types.cint() for i in range(m)]
for i in range(m):
pow2k[i] = two_power(2**i)
t[i] = t[i]*pow2k[i] + 1 - t[i]
return KMul(t)
def B2U(a, l, kappa):
pow2a = Pow2(a, l, kappa)
#assert(pow2a.value == 2**a.value)
r = [types.sint() for i in range(l)]
t = types.sint()
c = types.cint()
for i in range(l):
bit(r[i])
comparison.PRandInt(t, kappa)
asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l)))
comparison.program.curr_tape.require_bit_length(l + kappa)
c = list(bits(c, l))
x = [c[i] + r[i] - 2*c[i]*r[i] for i in range(l)]
#print ' '.join(str(b.value) for b in x)
y = PreOR(x, kappa)
#print ' '.join(str(b.value) for b in y)
return [1 - y[i] for i in range(l)], pow2a
def Trunc(a, l, m, kappa, compute_modulo=False):
""" Oblivious truncation by secret m """
if l == 1:
if compute_modulo:
return a * m, 1 + m
else:
return a * (1 - m)
r = [types.sint() for i in range(l)]
r_dprime = types.sint(0)
r_prime = types.sint(0)
rk = types.sint()
c = types.cint()
ci = [types.cint() for i in range(l)]
d = types.sint()
x, pow2m = B2U(m, l, kappa)
#assert(pow2m.value == 2**m.value)
#assert(sum(b.value for b in x) == m.value)
for i in range(l):
bit(r[i])
t1 = two_power(i) * r[i]
t2 = t1*x[i]
r_prime += t2
r_dprime += t1 - t2
#assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
comparison.PRandInt(rk, kappa)
r_dprime += two_power(l) * rk
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
asm_open(c, a + r_dprime + r_prime)
for i in range(1,l):
ci[i] = c % two_power(i)
#assert(ci[i].value == c.value % 2**i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
#assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
lts(d, c_dprime, r_prime, l, kappa)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
else:
pow2inv = Inv(pow2m)
#assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
b = (a - c_dprime + r_prime) * pow2inv - d
return b
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
s = (1 - overflow) * t + overflow * t / 2
return s, overflow
def Int2FL(a, gamma, l, kappa):
lam = gamma - 1
s = types.sint()
comparison.LTZ(s, a, gamma, kappa)
z = EQZ(a, gamma, kappa)
a = (1 - 2 * s) * a
a_bits = BitDec(a, lam, lam, kappa)
a_bits.reverse()
b = PreOR(a_bits, kappa)
t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b)))
p = - (lam - sum(b))
if gamma - 1 > l:
if types.sfloat.round_nearest:
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
p = p + overflow
else:
v = types.sint()
comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * (1 -z)
return v, p, z, s
def FLRound(x, mode):
""" Rounding with floating point output.
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
a = types.sint()
comparison.LTZ(a, p1, k, x.kappa)
b = p1.less_than(-l + 1, k, x.kappa)
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
c = EQZ(v2, l, x.kappa)
if mode == -1:
away_from_zero = 0
mode = x.s
else:
away_from_zero = mode + s1 - 2 * mode * s1
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
d = v.equal(two_power(l), l + 1, x.kappa)
v = d * two_power(l-1) + (1 - d) * v
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
s = (1 - b * mode) * s1
z = or_op(EQZ(v, l, x.kappa), z1)
v = v * (1 - z)
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
return v, p, z, s
def TruncPr(a, k, m, kappa=None):
""" Probabilistic truncation [a/2^m + u]
where Pr[u = 1] = (a % 2^m) / 2^m
"""
if kappa is None:
kappa = 40
b = two_power(k-1) + a
r_prime, r_dprime = types.sint(), types.sint()
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
k, m, kappa)
two_to_m = two_power(m)
r = two_to_m * r_dprime + r_prime
c = (b + r).reveal()
c_prime = c % two_to_m
a_prime = c_prime - r_prime
d = (a - a_prime) / two_to_m
return d
def SDiv(a, b, l, kappa):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
beta = 1 / types.cint(two_power(l))
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = TruncPr(y, 2 * l, l, kappa)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
x1 = (x - x2) * beta
for i in range(theta-1):
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa)
x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
x1 = (x - x2) * beta
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
y = TruncPr(y, 2 * l + 1, l - 1, kappa)
return y
def SDiv_mono(a, b, l, kappa):
theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l)
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w
y = a * w
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
for i in range(theta-1):
y = y * (alpha + x)
# keep y with l bits
y = TruncPr(y, 3 * l, 2 * l, kappa)
x = x**2
# keep x with 2l bits
x = TruncPr(x, 4 * l, 2 * l, kappa)
y = y * (alpha + x)
y = TruncPr(y, 3 * l, 2 * l, kappa)
return y
def FPDiv(a, b, k, f, kappa):
theta = int(ceil(log(k/3.5)))
alpha = types.cint(1 * two_power(2*f))
w = AppRcr(b, k, f, kappa)
x = alpha - b * w
y = a * w
y = TruncPr(y, 2*k, f, kappa)
for i in range(theta):
y = y * (alpha + x)
x = x * x
y = TruncPr(y, 2*k, 2*f, kappa)
x = TruncPr(x, 2*k, 2*f, kappa)
y = y * (alpha + x)
y = TruncPr(y, 2*k, 2*f, kappa)
return y
def AppRcr(b, k, f, kappa):
"""
Approximate reciprocal of [b]:
Given [b], compute [1/b]
"""
alpha = types.cint(int(2.9142 * (2**k)))
c, v = Norm(b, k, f, kappa)
d = alpha - 2 * c
w = d * v
w = TruncPr(w, 2 * k, 2 * (k - f))
return w
def Norm(b, k, f, kappa):
"""
Computes secret integer values [c] and [v_prime] st.
2^{k-1} <= c < 2^k and c = b*v_prime
"""
temp = types.sint()
comparison.LTZ(temp, b, k, kappa)
sign = 1 - 2 * temp # 1 - 2 * [b < 0]
x = sign * b
#x = |b|
bits = x.bit_decompose(k)
y = PreOR(bits)
z = [0] * k
for i in range(k - 1):
z[i] = y[i] - y[i + 1]
z[k - 1] = y[k - 1]
# z[i] = 0 for all i except when bits[i + 1] = first one
#now reverse bits of z[i]
v = types.sint()
for i in range(k):
v += two_power(k - i - 1) * z[i]
c = x * v
v_prime = sign * v
return c, v_prime

220
Compiler/graph.py Normal file
View File

@@ -0,0 +1,220 @@
# (C) 2016 University of Bristol. See License.txt
import heapq
from Compiler.exceptions import *
class GraphError(CompilerError):
pass
class SparseDiGraph(object):
""" Directed graph suitable when each node only has a small number of edges.
Edges are stored as a list instead of a dictionary to save memory, leading
to slower searching for dense graphs.
Node attributes must be specified in advance, as these are stored in the
same list as edges.
"""
def __init__(self, max_nodes, default_attributes=None):
""" max_nodes: maximum no of nodes
default_attributes: dict of node attributes and default values """
if default_attributes is None:
default_attributes = { 'merges': None, 'stop': -1, 'start': -1, 'is_source': True }
self.default_attributes = default_attributes
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
self.n = max_nodes
# each node contains list of default attributes, followed by outoing edges
self.nodes = [self.default_attributes.values() for i in range(self.n)]
self.pred = [[] for i in range(self.n)]
self.weights = {}
def __len__(self):
return self.n
def __getitem__(self, i):
""" Get list of the neighbours of node i """
return self.nodes[i][len(self.default_attributes):]
def __iter__(self):
pass #return iter(self.nodes)
def __contains__(self, i):
return i >= 0 and i < self.n
def add_node(self, i, **attr):
if i >= self.n:
raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n))
node = self.nodes[i]
for a,value in attr.items():
if a in self.default_attributes:
node[self.attribute_pos[a]] = value
else:
raise CompilerError('Invalid attribute %s for graph node' % a)
def set_attr(self, i, attr, value):
if attr in self.default_attributes:
self.nodes[i][self.attribute_pos[attr]] = value
else:
raise CompilerError('Invalid attribute %s for graph node' % attr)
def get_attr(self, i, attr):
return self.nodes[i][self.attribute_pos[attr]]
def remove_node(self, i):
""" Remove node i and all its edges """
succ = self[i]
pred = self.pred[i]
for v in succ:
self.pred[v].remove(i)
#del self.weights[(i,v)]
for v in pred:
# find index to ensure attribute isn't removed instead
index = self[v].index(i) + len(self.default_attributes)
del self.nodes[v][index]
#del self.weights[(v,i)]
#self.nodes[v].remove(i)
self.pred[i] = []
self.nodes[i] = self.default_attributes.values()
def add_edge(self, i, j, weight=1):
if j not in self[i]:
self.nodes[i].append(j)
self.pred[j].append(i)
self.weights[(i,j)] = weight
def add_edges_from(self, tuples):
for edge in tuples:
if len(edge) == 3:
# use weight
self.add_edge(edge[0], edge[1], edge[2])
else:
self.add_edge(edge[0], edge[1])
def remove_edge(self, i, j):
jindex = self[i].index(j) + len(self.default_attributes)
del self.nodes[i][jindex]
self.pred[j].remove(i)
del self.weights[(i,j)]
def remove_edges_from(self, pairs):
for i,j in pairs:
self.remove_edge(i, j)
def degree(self, i):
return len(self.nodes[i]) - len(self.default_attributes)
def topological_sort(G, nbunch=None, pref=None):
seen={}
order_explored=[] # provide order and
explored={} # fast search without more general priorityDictionary
if pref is None:
def get_children(node):
return G[node]
else:
def get_children(node):
if pref.has_key(node):
pref_set = set(pref[node])
for i in G[node]:
if i not in pref_set:
yield i
for i in reversed(pref[node]):
yield i
else:
for i in G[node]:
yield i
if nbunch is None:
nbunch = range(len(G))
for v in nbunch: # process all vertices in G
if v in explored:
continue
fringe=[v] # nodes yet to look at
while fringe:
w=fringe[-1] # depth first search
if w in explored: # already looked down this branch
fringe.pop()
continue
seen[w]=1 # mark as seen
# Check successors for cycles and for new nodes
new_nodes=[]
for n in get_children(w):
if n not in explored:
if n in seen: #CYCLE !!
raise GraphError("Graph contains a cycle at %d (%s,%s)." % \
(n, G[n], G.pred[n]))
new_nodes.append(n)
if new_nodes: # Add new_nodes to fringe
fringe.extend(new_nodes)
else: # No new nodes so w is fully explored
explored[w]=1
order_explored.append(w)
fringe.pop() # done considering this node
order_explored.reverse() # reverse order explored
return order_explored
def dag_shortest_paths(G, source):
top_order = topological_sort(G)
dist = [None] * len(G)
dist[source] = 0
for u in top_order:
if dist[u] is None:
continue
for v in G[u]:
if dist[v] is None or dist[v] > dist[u] + G.weights[(u,v)]:
dist[v] = dist[u] + G.weights[(u,v)]
return dist
def reverse_dag_shortest_paths(G, source):
top_order = reversed(topological_sort(G))
dist = [None] * len(G)
dist[source] = 0
for u in top_order:
if u ==68273:
print 'dist[68273]', dist[u]
print 'pred[u]', G.pred[u]
if dist[u] is None:
continue
for v in G.pred[u]:
if dist[v] is None or dist[v] > dist[u] + G.weights[(v,u)]:
dist[v] = dist[u] + G.weights[(v,u)]
return dist
def single_source_longest_paths(G, source, reverse=False):
# make weights negative, then do shortest paths
for edge in G.weights:
G.weights[edge] = -G.weights[edge]
if reverse:
dist = reverse_dag_shortest_paths(G, source)
else:
dist = dag_shortest_paths(G, source)
#dist = johnson(G, sources)
# reset weights
for edge in G.weights:
G.weights[edge] = -G.weights[edge]
for i,n in enumerate(dist):
if n is None:
dist[i] = 0
else:
dist[i] = -dist[i]
#for k, v in dist.iteritems():
# dist[k] = -v
return dist
def longest_paths(G, sources=None):
# make weights negative, then do shortest paths
for edge in G.weights:
G.weights[edge] = -G.weights[edge]
dist = {}
for source in sources:
print ('%s, ' % source),
dist[source] = dag_shortest_paths(G, source)
#dist = johnson(G, sources)
# reset weights
for edge in G.weights:
G.weights[edge] = -G.weights[edge]
return dist

1310
Compiler/instructions.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,742 @@
# (C) 2016 University of Bristol. See License.txt
import itertools
from random import randint
import time
import inspect
import functools
from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
###
### Opcode constants
###
### Whenever these are changed the corresponding enums in Processor/instruction.h
### MUST also be changed. (+ the documentation)
###
opcodes = dict(
# Load/store
LDI = 0x1,
LDSI = 0x2,
LDMC = 0x3,
LDMS = 0x4,
STMC = 0x5,
STMS = 0x6,
LDMCI = 0x7,
LDMSI = 0x8,
STMCI = 0x9,
STMSI = 0xA,
MOVC = 0xB,
MOVS = 0xC,
PROTECTMEMS = 0xD,
PROTECTMEMC = 0xE,
PROTECTMEMINT = 0xF,
LDMINT = 0xCA,
STMINT = 0xCB,
LDMINTI = 0xCC,
STMINTI = 0xCD,
PUSHINT = 0xCE,
POPINT = 0xCF,
MOVINT = 0xD0,
# Machine
LDTN = 0x10,
LDARG = 0x11,
REQBL = 0x12,
STARG = 0x13,
TIME = 0x14,
START = 0x15,
STOP = 0x16,
USE = 0x17,
USE_INP = 0x18,
RUN_TAPE = 0x19,
JOIN_TAPE = 0x1A,
CRASH = 0x1B,
USE_PREP = 0x1C,
# Addition
ADDC = 0x20,
ADDS = 0x21,
ADDM = 0x22,
ADDCI = 0x23,
ADDSI = 0x24,
SUBC = 0x25,
SUBS = 0x26,
SUBML = 0x27,
SUBMR = 0x28,
SUBCI = 0x29,
SUBSI = 0x2A,
SUBCFI = 0x2B,
SUBSFI = 0x2C,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
MULCI = 0x32,
MULSI = 0x33,
DIVC = 0x34,
DIVCI = 0x35,
MODC = 0x36,
MODCI = 0x37,
LEGENDREC = 0x38,
GMULBITC = 0x136,
GMULBITM = 0x137,
# Open
STARTOPEN = 0xA0,
STOPOPEN = 0xA1,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
SQUARE = 0x52,
INV = 0x53,
GBITTRIPLE = 0x154,
GBITGF2NTRIPLE = 0x155,
INPUTMASK = 0x56,
PREP = 0x57,
# Input
INPUT = 0x60,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
READSOCKETS = 0x64,
WRITESOCKETC = 0x65,
WRITESOCKETS = 0x66,
OPENSOCKET = 0x67,
CLOSESOCKET = 0x68,
# Bitwise logic
ANDC = 0x70,
XORC = 0x71,
ORC = 0x72,
ANDCI = 0x73,
XORCI = 0x74,
ORCI = 0x75,
NOTC = 0x76,
# Bitwise shifts
SHLC = 0x80,
SHRC = 0x81,
SHLCI = 0x82,
SHRCI = 0x83,
# Branching and comparison
JMP = 0x90,
JMPNZ = 0x91,
JMPEQZ = 0x92,
EQZC = 0x93,
LTZC = 0x94,
LTC = 0x95,
GTC = 0x96,
EQC = 0x97,
JMPI = 0x98,
# Integers
LDINT = 0x9A,
ADDINT = 0x9B,
SUBINT = 0x9C,
MULINT = 0x9D,
DIVINT = 0x9E,
# Conversion
CONVINT = 0xC0,
CONVMODP = 0xC1,
GCONVGF2N = 0x1C1,
# IO
PRINTMEM = 0xB0,
PRINTREG = 0XB1,
RAND = 0xB2,
PRINTREGPLAIN = 0xB3,
PRINTCHR = 0xB4,
PRINTSTR = 0xB5,
PUBINPUT = 0xB6,
RAWOUTPUT = 0xB7,
STARTPRIVATEOUTPUT = 0xB8,
STOPPRIVATEOUTPUT = 0xB9,
PRINTCHRINT = 0xBA,
PRINTSTRINT = 0xBB,
GBITDEC = 0x184,
GBITCOM = 0x185,
)
def int_to_bytes(x):
""" 32 bit int to big-endian 4 byte conversion. """
return [(x >> 8*i) % 256 for i in (3,2,1,0)]
global_vector_size = 1
global_vector_size_depth = 0
global_instruction_type_stack = ['modp']
def set_global_vector_size(size):
global global_vector_size, global_vector_size_depth
if size == 1:
return
if global_vector_size == 1 or global_vector_size == size:
global_vector_size = size
global_vector_size_depth += 1
else:
raise CompilerError('Cannot set global vector size when already set')
def set_global_instruction_type(t):
if t == 'modp' or t == 'gf2n':
global_instruction_type_stack.append(t)
else:
raise CompilerError('Invalid type %s for setting global instruction type')
def reset_global_vector_size():
global global_vector_size, global_vector_size_depth
if global_vector_size_depth > 0:
global_vector_size_depth -= 1
if global_vector_size_depth == 0:
global_vector_size = 1
def reset_global_instruction_type():
global_instruction_type_stack.pop()
def get_global_vector_size():
return global_vector_size
def get_global_instruction_type():
return global_instruction_type_stack[-1]
def vectorize(instruction, global_dict=None):
""" Decorator to vectorize instructions. """
if global_dict is None:
global_dict = inspect.getmodule(instruction).__dict__
class Vectorized_Instruction(instruction):
__slots__ = ['size']
def __init__(self, size, *args, **kwargs):
self.size = size
super(Vectorized_Instruction, self).__init__(*args, **kwargs)
for arg,f in zip(self.args, self.arg_format):
if issubclass(ArgFormats[f], RegisterArgFormat):
arg.set_size(size)
def get_code(self):
return (self.size << 9) + self.code
def get_pre_arg(self):
return "%d, " % self.size
def is_vec(self):
return self.size > 1
def get_size(self):
return self.size
def expand(self):
set_global_vector_size(self.size)
super(Vectorized_Instruction, self).expand()
reset_global_vector_size()
@functools.wraps(instruction)
def maybe_vectorized_instruction(*args, **kwargs):
if global_vector_size == 1:
return instruction(*args, **kwargs)
else:
return Vectorized_Instruction(global_vector_size, *args, **kwargs)
maybe_vectorized_instruction.vec_ins = Vectorized_Instruction
maybe_vectorized_instruction.std_ins = instruction
vectorized_name = 'v' + instruction.__name__
Vectorized_Instruction.__name__ = vectorized_name
global_dict[vectorized_name] = Vectorized_Instruction
global_dict[instruction.__name__ + '_class'] = instruction
return maybe_vectorized_instruction
def gf2n(instruction):
""" Decorator to create GF_2^n instruction corresponding to a given
modp instruction.
Adds the new GF_2^n instruction to the globals dictionary. Also adds a
vectorized GF_2^n instruction if a modp version exists. """
global_dict = inspect.getmodule(instruction).__dict__
if global_dict.has_key('v' + instruction.__name__):
vectorized = True
else:
vectorized = False
if isinstance(instruction, type) and issubclass(instruction, Instruction):
instruction_cls = instruction
else:
try:
instruction_cls = global_dict[instruction.__name__ + '_class']
except KeyError:
raise CompilerError('Cannot decorate instruction %s' % instruction)
class GF2N_Instruction(instruction_cls):
__doc__ = instruction_cls.__doc__.replace('c_', 'c^g_').replace('s_', 's^g_')
__slots__ = []
field_type = 'gf2n'
if isinstance(instruction_cls.code, int):
code = (1 << 8) + instruction_cls.code
# set modp registers in arg_format to GF2N registers
if 'gf2n_arg_format' in instruction_cls.__dict__:
arg_format = instruction_cls.gf2n_arg_format
elif isinstance(instruction_cls.arg_format, itertools.repeat):
__f = instruction_cls.arg_format.next()
if __f != 'int' and __f != 'p':
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
else:
__format = []
for __f in instruction_cls.arg_format:
if __f in ('int', 'p', 'ci', 'str'):
__format.append(__f)
else:
__format.append(__f[0] + 'g' + __f[1:])
arg_format = __format
def is_gf2n(self):
return True
def expand(self):
set_global_instruction_type('gf2n')
super(GF2N_Instruction, self).expand()
reset_global_instruction_type()
GF2N_Instruction.__name__ = 'g' + instruction_cls.__name__
if vectorized:
vec_GF2N = vectorize(GF2N_Instruction, global_dict)
@functools.wraps(instruction)
def maybe_gf2n_instruction(*args, **kwargs):
if get_global_instruction_type() == 'gf2n':
if vectorized:
return vec_GF2N(*args, **kwargs)
else:
return GF2N_Instruction(*args, **kwargs)
else:
return instruction(*args, **kwargs)
# If instruction is vectorized, new GF2N instruction must also be
if vectorized:
global_dict[GF2N_Instruction.__name__] = vec_GF2N
else:
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
global_dict[instruction.__name__ + '_class'] = instruction_cls
return maybe_gf2n_instruction
#return instruction
class RegType(object):
""" enum-like static class for Register types """
ClearModp = 'c'
SecretModp = 's'
ClearGF2N = 'cg'
SecretGF2N = 'sg'
ClearInt = 'ci'
Types = [ClearModp, SecretModp, ClearGF2N, SecretGF2N, ClearInt]
@staticmethod
def create_dict(init_value_fn):
""" Create a dictionary with all the RegTypes as keys """
return {
RegType.ClearModp : init_value_fn(),
RegType.SecretModp : init_value_fn(),
RegType.ClearGF2N : init_value_fn(),
RegType.SecretGF2N : init_value_fn(),
RegType.ClearInt : init_value_fn(),
}
class ArgFormat(object):
@classmethod
def check(cls, arg):
return NotImplemented
@classmethod
def encode(cls, arg):
return NotImplemented
class RegisterArgFormat(ArgFormat):
@classmethod
def check(cls, arg):
if not isinstance(arg, program.curr_tape.Register):
raise ArgumentError(arg, 'Invalid register argument')
if arg.i > REG_MAX:
raise ArgumentError(arg, 'Register index too large')
if arg.program != program.curr_tape:
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
util.format_trace(arg.caller))
if arg.reg_type != cls.reg_type:
raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \
(arg.reg_type, cls.reg_type))
@classmethod
def encode(cls, arg):
return int_to_bytes(arg.i)
class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp
class SecretModpAF(RegisterArgFormat):
reg_type = RegType.SecretModp
class ClearGF2NAF(RegisterArgFormat):
reg_type = RegType.ClearGF2N
class SecretGF2NAF(RegisterArgFormat):
reg_type = RegType.SecretGF2N
class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
class IntArgFormat(ArgFormat):
@classmethod
def check(cls, arg):
if not isinstance(arg, (int, long)):
raise ArgumentError(arg, 'Expected an integer-valued argument')
@classmethod
def encode(cls, arg):
return int_to_bytes(arg)
class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
super(ImmediateModpAF, cls).check(arg)
if arg >= 2**31 or arg < -2**31:
raise ArgumentError(arg, 'Immediate value outside of 32-bit range')
class ImmediateGF2NAF(IntArgFormat):
@classmethod
def check(cls, arg):
# bounds checking for GF(2^n)???
super(ImmediateGF2NAF, cls).check(arg)
class PlayerNoAF(IntArgFormat):
@classmethod
def check(cls, arg):
super(PlayerNoAF, cls).check(arg)
if arg > 256:
raise ArgumentError(arg, 'Player number > 256')
class String(ArgFormat):
length = 12
@classmethod
def check(cls, arg):
if not isinstance(arg, str):
raise ArgumentError(arg, 'Argument is not string')
if len(arg) > cls.length:
raise ArgumentError(arg, 'String longer than ' + cls.length)
if '\0' in arg:
raise ArgumentError(arg, 'String contains zero-byte')
@classmethod
def encode(cls, arg):
return arg + '\0' * (cls.length - len(arg))
ArgFormats = {
'c': ClearModpAF,
's': SecretModpAF,
'cw': ClearModpAF,
'sw': SecretModpAF,
'cg': ClearGF2NAF,
'sg': SecretGF2NAF,
'cgw': ClearGF2NAF,
'sgw': SecretGF2NAF,
'ci': ClearIntAF,
'ciw': ClearIntAF,
'i': ImmediateModpAF,
'ig': ImmediateGF2NAF,
'int': IntArgFormat,
'p': PlayerNoAF,
'str': String,
}
def format_str_is_reg(format_str):
return issubclass(ArgFormats[format_str], RegisterArgFormat)
def format_str_is_writeable(format_str):
return format_str_is_reg(format_str) and format_str[-1] == 'w'
class Instruction(object):
"""
Base class for a RISC-type instruction. Has methods for checking arguments,
getting byte encoding, emulating the instruction, etc.
"""
__slots__ = ['args', 'arg_format', 'code', 'caller']
count = 0
def __init__(self, *args, **kwargs):
""" Create an instruction and append it to the program list. """
self.args = list(args)
self.check_args()
if not program.FIRST_PASS:
if kwargs.get('add_to_prog', True):
program.curr_block.instructions.append(self)
if program.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
else:
self.caller = None
if program.EMULATE:
self.execute()
Instruction.count += 1
if Instruction.count % 100000 == 0:
print "Compiled %d lines at" % self.__class__.count, time.asctime()
def get_code(self):
return self.code
def get_encoding(self):
enc = int_to_bytes(self.get_code())
# add the number of registers to a start/stop open instruction
if self.has_var_args():
enc += int_to_bytes(len(self.args))
for arg,format in zip(self.args, self.arg_format):
enc += ArgFormats[format].encode(arg)
return enc
def get_bytes(self):
return bytearray(self.get_encoding())
def execute(self):
""" Emulate execution of this instruction """
raise NotImplementedError('execute method must be implemented')
def check_args(self):
""" Check the args match up with that specified in arg_format """
for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)):
if arg is None:
if not isinstance(self.arg_format, (list, tuple)):
break # end of optional arguments
else:
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
try:
ArgFormats[f].check(arg)
except ArgumentError as e:
raise CompilerError('Invalid argument "%s" to instruction: %s'
% (e.arg, self) + '\n' + e.msg)
def get_used(self):
""" Return the set of registers that are read in this instruction. """
return set(arg for arg,w in zip(self.args, self.arg_format) if \
format_str_is_reg(w) and not format_str_is_writeable(w))
def get_def(self):
""" Return the set of registers that are written to in this instruction. """
return set(arg for arg,w in zip(self.args, self.arg_format) if \
format_str_is_writeable(w))
def get_pre_arg(self):
return ""
def has_var_args(self):
return False
def is_vec(self):
return False
def is_gf2n(self):
return False
def get_size(self):
return 1
def add_usage(self, req_node):
pass
def __str__(self):
return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args)
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
###
### Basic arithmetic
###
class AddBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value + self.args[2].value) % program.P
class SubBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value - self.args[2].value) % program.P
class MulBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value * self.args[2].value) % program.P
###
### Basic arithmetic with immediate values
###
class ImmediateBase(Instruction):
__slots__ = ['op']
def execute(self):
exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op)
class SharedImmediate(ImmediateBase):
__slots__ = []
arg_format = ['sw', 's', 'i']
class ClearImmediate(ImmediateBase):
__slots__ = []
arg_format = ['cw', 'c', 'i']
###
### Memory access instructions
###
class DirectMemoryInstruction(Instruction):
__slots__ = []
def __init__(self, *args, **kwargs):
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
class ReadMemoryInstruction(Instruction):
__slots__ = []
class WriteMemoryInstruction(Instruction):
__slots__ = []
class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
WriteMemoryInstruction):
__slots__ = []
def __init__(self, *args, **kwargs):
if program.curr_tape.prevent_direct_memory_write:
raise CompilerError('Direct memory writing prevented')
super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs)
###
### I/O instructions
###
class DoNotEliminateInstruction(Instruction):
""" What do you think? """
__slots__ = []
class IOInstruction(DoNotEliminateInstruction):
""" Instruction that uses stdin/stdout during runtime. These are linked
to prevent instruction reordering during optimization. """
__slots__ = []
@classmethod
def str_to_int(cls, s):
""" Convert a 4 character string to an integer. """
if len(s) > 4:
raise CompilerError('String longer than 4 characters')
n = 0
for c in reversed(s.ljust(4)):
n <<= 8
n += ord(c)
return n
class AsymmetricCommunicationInstruction(DoNotEliminateInstruction):
""" Instructions involving sending from or to only one party. """
__slots__ = []
class RawInputInstruction(AsymmetricCommunicationInstruction):
""" Raw input instructions. """
__slots__ = []
class PublicFileIOInstruction(DoNotEliminateInstruction):
""" Instruction to reads/writes public information from/to files. """
__slots__ = []
###
### Data access instructions
###
class DataInstruction(Instruction):
__slots__ = []
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, self.data_type), self.get_size())
###
### Integer operations
###
class IntegerInstruction(Instruction):
""" Base class for integer operations. """
__slots__ = []
arg_format = ['ciw', 'ci', 'ci']
###
### Clear comparison instructions
###
class UnaryComparisonInstruction(Instruction):
""" Base class for unary comparisons. """
__slots__ = []
arg_format = ['ciw', 'ci']
###
### Clear shift instructions
###
class ClearShiftInstruction(ClearImmediate):
__slots__ = []
def check_args(self):
super(ClearShiftInstruction, self).check_args()
if program.galois_length > 64:
bits = 127
else:
# assume 64-bit machine
bits = 63
if self.args[2] > bits:
raise CompilerError('Shifting by more than %d bits '
'not implemented' % bits)
###
### Jumps etc
###
class dummywrite(Instruction):
""" Dummy instruction to create source node in the dependency graph,
preventing read-before-write warnings. """
__slots__ = []
def __init__(self, *args, **kwargs):
self.arg_format = [arg.reg_type + 'w' for arg in args]
super(dummywrite, self).__init__(*args, **kwargs)
def execute(self):
pass
def get_encoding(self):
return []
class JumpInstruction(Instruction):
__slots__ = ['jump_arg']
def set_relative_jump(self, value):
if value == -1:
raise CompilerException('Jump by -1 would cause infinite loop')
self.args[self.jump_arg] = value
def get_relative_jump(self):
return self.args[self.jump_arg]
class CISC(Instruction):
"""
Base class for a CISC instruction.
Children must implement expand(self) to process the instruction.
"""
__slots__ = []
code = None
def __init__(self, *args):
self.args = args
self.check_args()
#if EMULATE:
# self.expand()
if not program.FIRST_PASS:
self.expand()
def expand(self):
""" Expand this into a sequence of RISC instructions. """
raise NotImplementedError('expand method must be implemented')

1115
Compiler/library.py Normal file

File diff suppressed because it is too large Load Diff

902
Compiler/program.py Normal file
View File

@@ -0,0 +1,902 @@
# (C) 2016 University of Bristol. See License.txt
from Compiler.config import *
from Compiler.exceptions import *
from Compiler.instructions_base import RegType
import Compiler.instructions
import Compiler.instructions_base
import compilerLib
import allocator as al
import random
import time
import sys, os, errno
import inspect
from collections import defaultdict
import itertools
import math
data_types = dict(
triple = 0,
square = 1,
bit = 2,
inverse = 3,
bittriple = 4,
bitgf2ntriple = 5
)
field_types = dict(
modp = 0,
gf2n = 1,
)
class Program(object):
""" A program consists of a list of tapes and a scheduled order
of execution for these tapes.
These are created by executing a file containing appropriate instructions
and threads. """
def __init__(self, name, options, param=-1, assemblymode=False):
self.options = options
self.init_names(name, assemblymode)
self.P = P_VALUES[param]
self.param = param
self.bit_length = BIT_LENGTHS[param]
print 'Default bit length:', self.bit_length
self.security = STAT_SEC[param]
print 'Default security parameter:', self.security
self.galois_length = int(options.galois)
print 'Galois length:', self.galois_length
self.schedule = [('start', [])]
self.main_ctr = 0
self.tapes = []
self._curr_tape = None
self.EMULATE = True # defaults
self.FIRST_PASS = False
self.DEBUG = False
self.main_thread_running = False
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
self.free_mem_blocks = defaultdict(set)
self.allocated_mem_blocks = {}
self.req_num = None
self.tape_stack = []
self.n_threads = 1
self.free_threads = set()
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % name, 'w')
Program.prog = self
self.reset_values()
def max_par_tapes(self):
""" Upper bound on number of tapes that will be run in parallel.
(Excludes empty tapes) """
if self.n_threads > 1:
if len(self.schedule) > 1:
raise CompilerError('Static and dynamic parallelism not compatible')
return self.n_threads
res = 1
running = defaultdict(lambda: 0)
for action,tapes in self.schedule:
tapes = [t[0] for t in tapes if not t[0].is_empty()]
if action == 'start':
for tape in tapes:
running[tape] += 1
elif action == 'stop':
for tape in tapes:
running[tape] -= 1
else:
raise CompilerError('Invalid schedule action')
res = max(res, sum(running.itervalues()))
return res
def init_names(self, name, assemblymode):
# ignore path to file - source must be in Programs/Source
if 'Programs' in os.listdir(os.getcwd()):
# compile prog in ./Programs/Source directory
self.programs_dir = os.getcwd() + '/Programs'
else:
# assume source is in main SPDZ directory
self.programs_dir = sys.path[0] + '/Programs'
print 'Compiling program in', self.programs_dir
# create extra directories if needed
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
if not os.path.exists(self.programs_dir + '/' + dirname):
os.mkdir(self.programs_dir + '/' + dirname)
name = name.split('/')[-1]
if name.endswith('.mpc'):
self.name = name[:-4]
else:
self.name = name
if assemblymode:
self.infile = self.programs_dir + '/Source/' + self.name + '.asm'
else:
self.infile = self.programs_dir + '/Source/' + self.name + '.mpc'
def new_tape(self, function, args=[], name=None):
if name is None:
name = function.__name__
name = "%s-%s" % (self.name, name)
# make sure there is a current tape
self.curr_tape
tape_index = len(self.tapes)
name += "-%d" % tape_index
self.tape_stack.append(self.curr_tape)
self.curr_tape = Tape(name, self)
self.curr_tape.prevent_direct_memory_write = True
self.tapes.append(self.curr_tape)
function(*args)
self.finalize_tape(self.curr_tape)
if self.tape_stack:
self.curr_tape = self.tape_stack.pop()
return tape_index
def run_tape(self, tape_index, arg):
if self.curr_tape is not self.tapes[0]:
raise CompilerError('Compiler does not support ' \
'recursive spawning of threads')
if self.free_threads:
thread_number = self.free_threads.pop()
else:
thread_number = self.n_threads
self.n_threads += 1
self.curr_tape.start_new_basicblock(name='pre-run_tape')
Compiler.instructions.run_tape(thread_number, arg, tape_index)
self.curr_tape.start_new_basicblock(name='post-run_tape')
self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree)
return thread_number
def join_tape(self, thread_number):
self.curr_tape.start_new_basicblock(name='pre-join_tape')
Compiler.instructions.join_tape(thread_number)
self.curr_tape.start_new_basicblock(name='post-join_tape')
self.free_threads.add(thread_number)
def start_thread(self, thread, arg):
if self.main_thread_running:
# wait for main thread to finish
self.schedule_wait(self.curr_tape)
self.main_thread_running = False
# compile thread if not been used already
if thread.tape not in self.tapes:
self.curr_tape = thread.tape
self.tapes.append(thread.tape)
thread.target(*thread.args)
# add thread to schedule
self.schedule_start(thread.tape, arg)
self.curr_tape = None
def stop_thread(self, thread):
tape = thread.tape
self.schedule_wait(tape)
def update_req(self, tape):
if self.req_num is None:
self.req_num = tape.req_num
else:
self.req_num += tape.req_num
def read_memory(self, filename):
""" Read the clear and shared memory from a file """
f = open(filename)
n = int(f.next())
self.mem_c = [0]*n
self.mem_s = [0]*n
mem = self.mem_c
done_c = False
for line in f:
line = line.split(' ')
a = int(line[0])
b = int(line[1])
if a != -1:
mem[a] = b
elif done_c:
break
else:
mem = self.mem_s
done_c = True
def get_memory(self, mem_type, i):
if mem_type == 'c':
return self.mem_c[i]
elif mem_type == 's':
return self.mem_s[i]
raise CompilerError('Invalid memory type')
def reset_values(self):
""" Reset register and memory values. """
for tape in self.tapes:
tape.reset_registers()
self.mem_c = range(USER_MEM + TMP_MEM)
self.mem_s = range(USER_MEM + TMP_MEM)
random.seed(0)
def write_bytes(self, outfile=None):
""" Write all non-empty threads and schedule to files. """
# runtime doesn't support 'new-style' parallelism yet
old_style = True
nonempty_tapes = [t for t in self.tapes if not t.is_empty()]
sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name
sch_file = open(sch_filename, 'w')
print 'Writing to', sch_filename
sch_file.write(str(self.max_par_tapes()) + '\n')
sch_file.write(str(len(nonempty_tapes)) + '\n')
sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n')
# assign tapes indices (needed for scheduler)
for i,tape in enumerate(nonempty_tapes):
tape.index = i
for sch in self.schedule:
# schedule may still contain empty tapes: ignore these
tapes = filter(lambda x: not x[0].is_empty(), sch[1])
# no empty line
if not tapes:
continue
line = ' '.join(str(t[0].index) +
(':' + str(t[1]) if t[1] is not None else '') for t in tapes)
if old_style:
if sch[0] == 'start':
sch_file.write('%d %s\n' % (len(tapes), line))
else:
sch_file.write('%s %d %s\n' % (tapes[0], len(tapes), line))
sch_file.write('0\n')
sch_file.write(' '.join(sys.argv) + '\n')
for tape in self.tapes:
tape.write_bytes()
def schedule_start(self, tape, arg=None):
""" Schedule the start of a thread. """
if self.schedule[-1][0] == 'start':
self.schedule[-1][1].append((tape, arg))
else:
self.schedule.append(('start', [(tape, arg)]))
def schedule_wait(self, tape):
""" Schedule the end of a thread. """
if self.schedule[-1][0] == 'stop':
self.schedule[-1][1].append((tape, None))
else:
self.schedule.append(('stop', [(tape, None)]))
self.finalize_tape(tape)
self.update_req(tape)
def finalize_tape(self, tape):
if not tape.purged:
tape.optimize(self.options)
tape.write_bytes()
if self.options.asmoutfile:
tape.write_str(self.options.asmoutfile + '-' + tape.name)
tape.purge()
def emulate(self):
""" Emulate execution of entire program. """
self.reset_values()
for sch in self.schedule:
if sch[0] == 'start':
for tape in sch[1]:
self._curr_tape = tape
for block in tape.basicblocks:
for line in block.instructions:
line.execute()
def restart_main_thread(self):
if self.main_thread_running:
# wait for main thread to finish
self.schedule_wait(self._curr_tape)
self.main_thread_running = False
name = '%s-%d' % (self.name, self.main_ctr)
self._curr_tape = Tape(name, self)
self.tapes.append(self._curr_tape)
self.main_ctr += 1
# add to schedule
self.schedule_start(self._curr_tape)
self.main_thread_running = True
@property
def curr_tape(self):
""" The tape that is currently running."""
if self._curr_tape is None:
# Create a new main thread if necessary
self.restart_main_thread()
return self._curr_tape
@curr_tape.setter
def curr_tape(self, value):
self._curr_tape = value
@property
def curr_block(self):
""" The basic block that is currently being created. """
return self.curr_tape.active_basicblock
def malloc(self, size, mem_type):
""" Allocate memory from the top """
if size == 0:
return
if isinstance(mem_type, type):
mem_type = mem_type.reg_type
key = size, mem_type
if self.free_mem_blocks[key]:
addr = self.free_mem_blocks[key].pop()
else:
addr = self.allocated_mem[mem_type]
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)):
print "Memory of type '%s' now of size %d" % (mem_type, addr + size)
self.allocated_mem_blocks[addr,mem_type] = size
return addr
def free(self, addr, mem_type):
""" Free memory """
if self.curr_block.persistent_allocation:
raise CompilerError('Cannot free memory within function block')
size = self.allocated_mem_blocks.pop((addr,mem_type))
self.free_mem_blocks[size,mem_type].add(addr)
def finalize_memory(self):
import library
self.curr_tape.start_new_basicblock(None, 'memory-usage')
for mem_type,size in self.allocated_mem.items():
if size:
#print "Memory of type '%s' of size %d" % (mem_type, size)
library.load_mem(size - 1, mem_type)
def public_input(self, x):
self.public_input_file.write('%s\n' % str(x))
def set_bit_length(self, bit_length):
self.bit_length = bit_length
print 'Changed bit length for comparisons etc. to', bit_length
def set_security(self, security):
self.security = security
print 'Changed statistical security for comparison etc. to', security
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
def __init__(self, name, program, param=-1):
""" Set prime p and the initial instructions and registers. """
self.program = program
self.init_names(name)
self.P = P_VALUES[param]
self.init_registers()
self.req_tree = self.ReqNode(name)
self.req_node = self.req_tree
self.basicblocks = []
self.purged = False
self.active_basicblock = None
self.start_new_basicblock()
self._is_empty = False
self.merge_opens = True
self.if_states = []
self.req_bit_length = defaultdict(lambda: 0)
self.function_basicblocks = {}
self.functions = []
self.prevent_direct_memory_write = False
class BasicBlock(object):
def __init__(self, parent, name, scope, exit_condition=None):
self.parent = parent
self.P = parent.P
self.instructions = []
self.name = name
self.index = len(parent.basicblocks)
self.open_queue = []
self.exit_condition = exit_condition
self.exit_block = None
self.previous_block = None
self.scope = scope
self.children = []
if scope is not None:
scope.children.append(self)
self.persistent_allocation = scope.persistent_allocation
else:
self.persistent_allocation = False
def new_reg(self, reg_type, size=None):
return self.parent.new_reg(reg_type, size=size)
def set_return(self, previous_block, sub_block):
self.previous_block = previous_block
self.sub_block = sub_block
def adjust_return(self):
offset = self.sub_block.get_offset(self)
self.previous_block.return_address_store.args[1] = offset
def set_exit(self, condition, exit_true=None):
""" Sets the block which we start from next, depending on the condition.
(Default is to go to next block in the list)
"""
self.exit_condition = condition
self.exit_block = exit_true
for reg in condition.get_used():
reg.can_eliminate = False
def add_jump(self):
""" Add the jump for this block's exit condition to list of
instructions (must be done after merging) """
self.instructions.append(self.exit_condition)
def get_offset(self, next_block):
return next_block.offset - (self.offset + len(self.instructions))
def adjust_jump(self):
""" Set the correct relative jump offset """
offset = self.get_offset(self.exit_block)
self.exit_condition.set_relative_jump(offset)
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
def __str__(self):
return self.name
def is_empty(self):
""" Returns True if the list of basic blocks is empty.
Note: False is returned even when tape only contains basic
blocks with no instructions. However, these are removed when
optimize is called. """
if not self.purged:
self._is_empty = (len(self.basicblocks) == 0)
return self._is_empty
def start_new_basicblock(self, scope=False, name=''):
# use False because None means no scope
if scope is False:
scope = self.active_basicblock
suffix = '%s-%d' % (name, len(self.basicblocks))
sub = self.BasicBlock(self, self.name + '-' + suffix, scope)
self.basicblocks.append(sub)
self.active_basicblock = sub
self.req_node.add_block(sub)
print 'Compiling basic block', sub.name
def init_registers(self):
self.reset_registers()
self.reg_counter = RegType.create_dict(lambda: 0)
def init_names(self, name):
# ignore path to file - source must be in Programs/Source
name = name.split('/')[-1]
if name.endswith('.asm'):
self.name = name[:-4]
else:
self.name = name
self.infile = self.program.programs_dir + '/Source/' + self.name + '.asm'
self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc'
def purge(self):
self._is_empty = (len(self.basicblocks) == 0)
del self.reg_values
del self.basicblocks
del self.active_basicblock
self.purged = True
def unpurged(function):
def wrapper(self, *args, **kwargs):
if self.purged:
print '%s called on purged block %s, ignoring' % \
(function.__name__, self.name)
return
return function(self, *args, **kwargs)
return wrapper
@unpurged
def optimize(self, options):
if len(self.basicblocks) == 0:
print 'Tape %s is empty' % self.name
return
if self.if_states:
raise CompilerError('Unclosed if/else blocks')
print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)
for block in self.basicblocks:
al.determine_scope(block)
# merge open instructions
# need to do this if there are several blocks
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
for i,block in enumerate(self.basicblocks):
if len(block.instructions) > 0:
print 'Processing basic block %s, %d/%d, %d instructions' % \
(block.name, i, len(self.basicblocks), \
len(block.instructions))
# the next call is necessary for allocation later even without merging
merger = al.Merger(block, options)
if options.dead_code_elimination:
if len(block.instructions) > 10000:
print 'Eliminate dead code...'
merger.eliminate_dead_code()
if options.merge_opens and self.merge_opens:
if len(block.instructions) == 0:
block.used_from_scope = set()
block.defined_registers = set()
continue
if len(block.instructions) > 10000:
print 'Merging open instructions...'
numrounds = merger.longest_paths_merge()
if numrounds > 0:
print 'Program requires %d rounds of communication' % numrounds
numinv = sum(len(i.args) for i in block.instructions if isinstance(i, Compiler.instructions.startopen_class))
if numinv > 0:
print 'Program requires %d invocations' % numinv
if options.dead_code_elimination:
block.instructions = filter(lambda x: x is not None, block.instructions)
if not (options.merge_opens and self.merge_opens):
print 'Not merging open instructions in tape %s' % self.name
# add jumps
offset = 0
for block in self.basicblocks:
if block.exit_condition is not None:
block.add_jump()
block.offset = offset
offset += len(block.instructions)
for block in self.basicblocks:
if block.exit_block is not None:
block.adjust_jump()
if block.previous_block is not None:
block.adjust_return()
# now remove any empty blocks (must be done after setting jumps)
self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks)
# allocate registers
reg_counts = self.count_regs()
if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate:
print 'Tape register usage:'
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
print 'Re-allocating...'
allocator = al.StraightlineAllocator(REG_MAX)
def alloc_loop(block):
for reg in block.used_from_scope:
allocator.alloc_reg(reg, block.persistent_allocation)
for child in block.children:
if child.instructions:
alloc_loop(child)
for i,block in enumerate(reversed(self.basicblocks)):
if len(block.instructions) > 10000:
print 'Allocating %s, %d/%d' % \
(block.name, i, len(self.basicblocks))
if block.exit_condition is not None:
jump = block.exit_condition.get_relative_jump()
if isinstance(jump, (int,long)) and jump < 0 and \
block.exit_block.scope is not None:
alloc_loop(block.exit_block.scope)
allocator.process(block.instructions, block.persistent_allocation)
# offline data requirements
print 'Compile offline data requirements...'
self.req_num = self.req_tree.aggregate()
print 'Tape requires', self.req_num
for req,num in self.req_num.items():
if num == float('inf'):
num = -1
if req[1] in data_types:
self.basicblocks[-1].instructions.append(
Compiler.instructions.use(field_types[req[0]], \
data_types[req[1]], num, \
add_to_prog=False))
elif req[1] == 'input':
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_inp(field_types[req[0]], \
req[2], num, \
add_to_prog=False))
elif req[0] == 'modp':
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_prep(req[1], num, \
add_to_prog=False))
elif req[0] == 'gf2n':
self.basicblocks[-1].instructions.append(
Compiler.instructions.guse_prep(req[1], num, \
add_to_prog=False))
if not self.is_empty():
# bit length requirement
self.basicblocks[-1].instructions.append(
Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False))
self.basicblocks[-1].instructions.append(
Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False))
print 'Tape requires prime bit length', self.req_bit_length['p']
print 'Tape requires galois bit length', self.req_bit_length['2']
@unpurged
def _get_instructions(self):
return itertools.chain.\
from_iterable(b.instructions for b in self.basicblocks)
@unpurged
def get_encoding(self):
""" Get the encoding of the program, in human-readable format. """
return [i.get_encoding() for i in self._get_instructions() if i is not None]
@unpurged
def get_bytes(self):
""" Get the byte encoding of the program as an actual string of bytes. """
return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None)
@unpurged
def write_encoding(self, filename):
""" Write the readable encoding to a file. """
print 'Writing to', filename
f = open(filename, 'w')
for line in self.get_encoding():
f.write(str(line) + '\n')
f.close()
@unpurged
def write_str(self, filename):
""" Write the sequence of instructions to a file. """
print 'Writing to', filename
f = open(filename, 'w')
n = 0
for block in self.basicblocks:
if block.instructions:
f.write('# %s\n' % block.name)
for line in block.instructions:
f.write('%s # %d\n' % (line, n))
n += 1
f.close()
@unpurged
def write_bytes(self, filename=None):
""" Write the program's byte encoding to a file. """
if filename is None:
filename = self.outfile
if not filename.endswith('.bc'):
filename += '.bc'
if not 'Bytecode' in filename:
filename = self.program.programs_dir + '/Bytecode/' + filename
print 'Writing to', filename
f = open(filename, 'w')
f.write(self.get_bytes())
f.close()
def new_reg(self, reg_type, size=None):
return self.Register(reg_type, self, size=size)
def count_regs(self, reg_type=None):
if reg_type is None:
return self.reg_counter
else:
return self.reg_counter[reg_type]
def reset_registers(self):
""" Reset register values to zero. """
self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX)
def get_value(self, reg_type, i):
return self.reg_values[reg_type][i]
def __str__(self):
return self.name
class ReqNum(defaultdict):
def __init__(self, init={}):
super(Tape.ReqNum, self).__init__(lambda: 0, init)
def __add__(self, other):
res = Tape.ReqNum()
for i,count in self.items():
res[i] += count
for i,count in other.items():
res[i] += count
return res
def __mul__(self, other):
res = Tape.ReqNum()
for i in self:
res[i] = other * self[i]
return res
__rmul__ = __mul__
def set_all(self, value):
res = Tape.ReqNum()
for i in self:
res[i] = value
return res
def max(self, other):
res = Tape.ReqNum()
for i in self:
res[i] = max(self[i], other[i])
for i in other:
res[i] = max(self[i], other[i])
return res
def cost(self):
return sum(num * COST[req[0]][req[1]] for req,num in self.items() \
if req[1] != 'input')
def __str__(self):
return ", ".join('%s inputs in %s from player %d' \
% (num, req[0], req[2]) \
if req[1] == 'input' \
else '%s %ss in %s' % (num, req[1], req[0]) \
for req,num in self.items())
def __repr__(self):
return repr(dict(self))
class ReqNode(object):
__slots__ = ['num', 'children', 'name', 'blocks']
def __init__(self, name):
self.children = []
self.name = name
self.blocks = []
def aggregate(self, *args):
self.num = Tape.ReqNum()
for block in self.blocks:
for inst in block.instructions:
inst.add_usage(self)
res = reduce(lambda x,y: x + y.aggregate(self.name),
self.children, self.num)
return res
def increment(self, data_type, num=1):
self.num[data_type] += num
def add_block(self, block):
self.blocks.append(block)
class ReqChild(object):
__slots__ = ['aggregator', 'nodes', 'parent']
def __init__(self, aggregator, parent):
self.aggregator = aggregator
self.nodes = []
self.parent = parent
def aggregate(self, name):
res = self.aggregator([node.aggregate() for node in self.nodes])
return res
def add_node(self, tape, name):
new_node = Tape.ReqNode(name)
self.nodes.append(new_node)
tape.req_node = new_node
def open_scope(self, aggregator, scope=False, name=''):
child = self.ReqChild(aggregator, self.req_node)
self.req_node.children.append(child)
child.add_node(self, '%s-%d' % (name, len(self.basicblocks)))
self.start_new_basicblock(name=name)
return child
def close_scope(self, outer_scope, parent_req_node, name):
self.req_node = parent_req_node
self.start_new_basicblock(outer_scope, name)
def require_bit_length(self, bit_length, t='p'):
if t == 'p':
self.req_bit_length[t] = max(bit_length + 1, \
self.req_bit_length[t])
if self.program.param != -1 and bit_length >= self.program.param:
raise CompilerError('Inadequate bit length %d for prime, ' \
'program requires %d bits' % \
(self.program.param, self.req_bit_length['p']))
else:
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
class Register(object):
"""
Class for creating new registers. The register's index is automatically assigned
based on the block's reg_counter dictionary.
The 'value' property is for emulation.
"""
__slots__ = ["reg_type", "program", "i", "value", "_is_active", \
"size", "vector", "vectorbase", "caller", \
"can_eliminate"]
def __init__(self, reg_type, program, value=None, size=None, i=None):
""" Creates a new register.
reg_type must be one of those defined in RegType. """
if Compiler.instructions_base.get_global_instruction_type() == 'gf2n':
if reg_type == RegType.ClearModp:
reg_type = RegType.ClearGF2N
elif reg_type == RegType.SecretModp:
reg_type = RegType.SecretGF2N
self.reg_type = reg_type
self.program = program
if size is None:
size = Compiler.instructions_base.get_global_vector_size()
self.size = size
if i:
self.i = i
else:
self.i = program.reg_counter[reg_type]
program.reg_counter[reg_type] += size
self.vector = []
self.vectorbase = self
if value is not None:
self.value = value
self._is_active = False
self.can_eliminate = True
if Program.prog.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
else:
self.caller = None
if self.i % 1000000 == 0 and self.i > 0:
print "Initialized %d registers at" % self.i, time.asctime()
def set_size(self, size):
if self.size == size:
return
elif self.size == 1 and self.vectorbase is self:
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
# create vector register in assembly mode
self.size = size
self.vector = [self]
for i in range(1,size):
reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)]
reg.set_vectorbase(self)
self.vector.append(reg)
else:
raise CompilerError('Cannot find %s in VARS' % str(self))
else:
raise CompilerError('Cannot reset size of vector register')
def set_vectorbase(self, vectorbase):
if self.vectorbase != self:
raise CompilerError('Cannot assign one register' \
'to several vectors')
self.vectorbase = vectorbase
def create_vector_elements(self):
if self.vector:
return
elif self.size == 1:
self.vector = [self]
return
self.vector = [self]
for i in range(1,self.size):
reg = Tape.Register(self.reg_type, self.program, size=1, i=self.i+i)
reg.set_vectorbase(self)
self.vector.append(reg)
def __getitem__(self, index):
if not self.vector:
self.create_vector_elements()
return self.vector[index]
def __len__(self):
return self.size
def activate(self):
""" Activating a register signals that it will at some point be used
in the program.
Inactive registers are reserved for temporaries for CISC instructions. """
if not self._is_active:
self._is_active = True
@property
def value(self):
return self.program.reg_values[self.reg_type][self.i]
@value.setter
def value(self, val):
while (len(self.program.reg_values[self.reg_type]) <= self.i):
self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX
self.program.reg_values[self.reg_type][self.i] = val
@property
def is_active(self):
return self._is_active
@property
def is_gf2n(self):
return self.reg_type == RegType.ClearGF2N or \
self.reg_type == RegType.SecretGF2N
@property
def is_clear(self):
return self.reg_type == RegType.ClearModp or \
self.reg_type == RegType.ClearGF2N or \
self.reg_type == RegType.ClearInt
def __str__(self):
return self.reg_type + str(self.i)
__repr__ = __str__

9
Compiler/tools.py Normal file
View File

@@ -0,0 +1,9 @@
# (C) 2016 University of Bristol. See License.txt
import itertools
class chain(object):
def __init__(self, *args):
self.args = args
def __iter__(self):
return itertools.chain(*self.args)

2013
Compiler/types.py Normal file

File diff suppressed because it is too large Load Diff

105
Compiler/util.py Normal file
View File

@@ -0,0 +1,105 @@
# (C) 2016 University of Bristol. See License.txt
import math
import operator
def format_trace(trace, prefix=' '):
if trace is None:
return '<omitted>'
else:
return ''.join('\n%sFile "%s", line %s, in %s\n%s %s' %
(prefix,i[0],i[1],i[2],prefix,i[3][0].strip()) \
for i in reversed(trace))
def tuplify(x):
if isinstance(x, (list, tuple)):
return tuple(x)
else:
return (x,)
def untuplify(x):
if len(x) == 1:
return x[0]
else:
return x
def greater_than(a, b, bits):
if isinstance(a, int) and isinstance(b, int):
return a > b
else:
return a.greater_than(b, bits)
def pow2(a, bits):
if isinstance(a, int):
return 2**a
else:
return a.pow2(bits)
def mod2m(a, b, bits, signed):
if isinstance(a, int):
return a % 2**b
else:
return a.mod2m(b, bits, signed=signed)
def right_shift(a, b, bits):
if isinstance(a, int):
return a >> b
else:
return a.right_shift(b, bits)
def bit_decompose(a, bits):
if isinstance(a, (int,long)):
return [int((a >> i) & 1) for i in range(bits)]
else:
return a.bit_decompose(bits)
def bit_compose(bits):
return sum(b << i for i,b in enumerate(bits))
def series(a):
sum = 0
for i in a:
yield sum
sum += i
yield sum
def if_else(cond, a, b):
try:
if isinstance(cond, (bool, int)):
if cond:
return a
else:
return b
return cond.if_else(a, b)
except:
print cond, a, b
raise
def cond_swap(cond, a, b):
if isinstance(cond, (bool, int)):
if cond:
return a, b
else:
return b, a
return cond.cond_swap(a, b)
def log2(x):
#print 'Compute log2 of', x
return int(math.ceil(math.log(x, 2)))
def tree_reduce(function, sequence):
n = len(sequence)
if n == 1:
return sequence[0]
else:
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)]
return tree_reduce(function, reduced + sequence[n/2*2:])
def or_op(a, b):
return a + b - a * b
OR = or_op
def pow2(bits):
powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)]
return tree_reduce(operator.mul, powers)

162
Exceptions/Exceptions.h Normal file
View File

@@ -0,0 +1,162 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Exceptions
#define _Exceptions
#include <exception>
#include <string>
#include <sstream>
using namespace std;
class not_implemented: public exception
{ virtual const char* what() const throw()
{ return "Case not implemented"; }
};
class division_by_zero: public exception
{ virtual const char* what() const throw()
{ return "Division by zero"; }
};
class invalid_plaintext: public exception
{ virtual const char* what() const throw()
{ return "Inconsistent plaintext space"; }
};
class rep_mismatch: public exception
{ virtual const char* what() const throw()
{ return "Representation mismatch"; }
};
class pr_mismatch: public exception
{ virtual const char* what() const throw()
{ return "Prime mismatch"; }
};
class params_mismatch: public exception
{ virtual const char* what() const throw()
{ return "FHE params mismatch"; }
};
class field_mismatch: public exception
{ virtual const char* what() const throw()
{ return "Plaintext Field mismatch"; }
};
class level_mismatch: public exception
{ virtual const char* what() const throw()
{ return "Level mismatch"; }
};
class invalid_length: public exception
{ virtual const char* what() const throw()
{ return "Invalid length"; }
};
class invalid_commitment: public exception
{ virtual const char* what() const throw()
{ return "Invalid Commitment"; }
};
class IO_Error: public exception
{ string msg, ans;
public:
IO_Error(string m) : msg(m)
{ ans="IO-Error : ";
ans+=msg;
}
~IO_Error()throw() { }
virtual const char* what() const throw()
{
return ans.c_str();
}
};
class broadcast_invalid: public exception
{ virtual const char* what() const throw()
{ return "Inconsistent broadcast at some point"; }
};
class bad_keygen: public exception
{ string msg;
public:
bad_keygen(string m) : msg(m) {}
~bad_keygen()throw() { }
virtual const char* what() const throw()
{ string ans="KeyGen has gone wrong: "+msg;
return ans.c_str();
}
};
class bad_enccommit: public exception
{ virtual const char* what() const throw()
{ return "Error in EncCommit"; }
};
class invalid_params: public exception
{ virtual const char* what() const throw()
{ return "Invalid Params"; }
};
class bad_value: public exception
{ virtual const char* what() const throw()
{ return "Some value is wrong somewhere"; }
};
class Offline_Check_Error: public exception
{ string msg;
public:
Offline_Check_Error(string m) : msg(m) {}
~Offline_Check_Error()throw() { }
virtual const char* what() const throw()
{ string ans="Offline-Check-Error : ";
ans+=msg;
return ans.c_str();
}
};
class mac_fail: public exception
{ virtual const char* what() const throw()
{ return "MacCheck Failure"; }
};
class invalid_program: public exception
{ virtual const char* what() const throw()
{ return "Invalid Program"; }
};
class file_error: public exception
{ string filename, ans;
public:
file_error(string m="") : filename(m)
{
ans="File Error : ";
ans+=filename;
}
~file_error()throw() { }
virtual const char* what() const throw()
{
return ans.c_str();
}
};
class end_of_file: public exception
{ virtual const char* what() const throw()
{ return "End of file reached"; }
};
class Processor_Error: public exception
{ string msg;
public:
Processor_Error(string m)
{
msg = "Processor-Error : " + m;
}
~Processor_Error()throw() { }
virtual const char* what() const throw()
{
return msg.c_str();
}
};
class max_mod_sz_too_small : public exception
{ int len;
public:
max_mod_sz_too_small(int len) : len(len) {}
~max_mod_sz_too_small() throw() {}
virtual const char* what() const throw()
{ stringstream out;
out << "MAX_MOD_SZ too small for desired bit length of p, "
<< "must be at least ceil(len(p)/len(word))+1, "
<< "in this case: " << len;
return out.str().c_str();
}
};
class crash_requested: public exception
{ virtual const char* what() const throw()
{ return "Crash requested by program"; }
};
class memory_exception : public exception {};
class how_would_that_work : public exception {};
#endif

554
Fake-Offline.cpp Normal file
View File

@@ -0,0 +1,554 @@
// (C) 2016 University of Bristol. See License.txt
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include "Math/Setup.h"
#include "Auth/fake-stuff.h"
#include "Exceptions/Exceptions.h"
#include "Math/Setup.h"
#include "Processor/Data_Files.h"
#include "Tools/mkpath.h"
#include "Tools/ezOptionParser.h"
#include <sstream>
#include <fstream>
using namespace std;
string prep_data_prefix;
/* N = Number players
* ntrip = Number triples needed
* str = "2" or "p"
*/
template<class T>
void make_mult_triples(const T& key,int N,int ntrip,const string& str,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
T a,b,c;
vector<Share<T> > Sa(N),Sb(N),Sc(N);
/* Generate Triples */
for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << "Triples-" << str << "-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{
if (!zero)
a.randomize(G);
make_share(Sa,a,N,key,G);
if (!zero)
b.randomize(G);
make_share(Sb,b,N,key,G);
c.mul(a,b);
make_share(Sc,c,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false);
Sb[j].output(outf[j],false);
Sc[j].output(outf[j],false);
}
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
delete[] outf;
}
void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
gf2n a,b,c, one;
one.assign_one();
vector<Share<gf2n> > Sa(N),Sb(N),Sc(N);
/* Generate Triples */
for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << Data_Files::dtype_names[dtype] << "-2-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{
if (!zero)
a.randomize(G);
a.AND(a, one);
make_share(Sa,a,N,key,G);
if (!zero)
b.randomize(G);
if (dtype == DATA_BITTRIPLE)
b.AND(b, one);
make_share(Sb,b,N,key,G);
c.mul(a,b);
make_share(Sc,c,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false);
Sb[j].output(outf[j],false);
Sc[j].output(outf[j],false);
}
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
delete[] outf;
}
/* N = Number players
* ntrip = Number tuples needed
* str = "2" or "p"
*/
template<class T>
void make_square_tuples(const T& key,int N,int ntrip,const string& str,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
T a,c;
vector<Share<T> > Sa(N),Sc(N);
/* Generate Squares */
for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << "Squares-" << str << "-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{
if (!zero)
a.randomize(G);
make_share(Sa,a,N,key,G);
c.mul(a,a);
make_share(Sc,c,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false);
Sc[j].output(outf[j],false);
}
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
delete[] outf;
}
/* N = Number players
* ntrip = Number bits needed
* str = "2" or "p"
*/
template<class T>
void make_bits(const T& key,int N,int ntrip,const string& str,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
T a;
vector<Share<T> > Sa(N);
/* Generate Bits */
for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << "Bits-" << str << "-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{ if ((G.get_uchar()&1)==0 || zero) { a.assign_zero(); }
else { a.assign_one(); }
make_share(Sa,a,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false); }
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
delete[] outf;
}
/* N = Number players
* ntrip = Number inputs needed
* str = "2" or "p"
*
*/
template<class T>
void make_inputs(const T& key,int N,int ntrip,const string& str,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
T a;
vector<Share<T> > Sa(N);
/* Generate Inputs */
for (int player=0; player<N; player++)
{ for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << "Inputs-" << str << "-P" << i << "-" << player;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{
if (!zero)
a.randomize(G);
make_share(Sa,a,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false);
if (j==player)
{ a.output(outf[j],false); }
}
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
}
delete[] outf;
}
/* N = Number players
* ntrip = Number inverses needed
* str = "2" or "p"
*/
template<class T>
void make_inverse(const T& key,int N,int ntrip,bool zero)
{
PRNG G;
G.ReSeed();
ofstream* outf=new ofstream[N];
T a,b;
vector<Share<T> > Sa(N),Sb(N);
/* Generate Triples */
for (int i=0; i<N; i++)
{ stringstream filename;
filename << prep_data_prefix << "Inverses-" << T::type_char() << "-P" << i;
cout << "Opening " << filename.str() << endl;
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
}
for (int i=0; i<ntrip; i++)
{
if (zero)
// ironic?
a.assign_one();
else
do
a.randomize(G);
while (a.is_zero());
make_share(Sa,a,N,key,G);
b=a; b.invert();
make_share(Sb,b,N,key,G);
for (int j=0; j<N; j++)
{ Sa[j].output(outf[j],false);
Sb[j].output(outf[j],false);
}
}
for (int i=0; i<N; i++)
{ outf[i].close(); }
delete[] outf;
}
template<class T>
void make_PreMulC(const T& key, int N, int ntrip, bool zero)
{
stringstream ss;
ss << prep_data_prefix << "PreMulC-" << T::type_char();
Files<T> files(N, key, ss.str());
PRNG G;
G.ReSeed();
T a, b, c;
c = 1;
for (int i=0; i<ntrip; i++)
{
// close the circle
if (i == ntrip - 1 || zero)
a.assign_one();
else
do
a.randomize(G);
while (a.is_zero());
files.output_shares(a);
b = a;
b.invert();
files.output_shares(b);
files.output_shares(a * c);
c = b;
}
}
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
opt.syntax = "./Fake-Offline.x <nplayers> [OPTIONS]\n\nOptions with 2 arguments take the form '-X <#gf2n tuples>,<#modp tuples>'";
opt.example = "./Fake-Offline.x 2 -lgp 128 -lg2 128 --default 10000\n./Fake-Offline.x 3 -trip 50000,10000 -btrip 100000\n";
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(p) field (default: 128)", // Help description.
"-lgp", // Flag token.
"--lgp" // Flag token.
);
opt.add(
"40", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(2^n) field (default: 40)", // Help description.
"-lg2", // Flag token.
"--lg2" // Flag token.
);
opt.add(
"1000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Default number of tuples to generate for ALL data types (default: 1000)", // Help description.
"-d", // Flag token.
"--default" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
2, // Number of args expected.
',', // Delimiter if expecting multiple args.
"Number of triples, for gf2n / modp types", // Help description.
"-trip", // Flag token.
"--ntriples" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
2, // Number of args expected.
',', // Delimiter if expecting multiple args.
"Number of random bits, for gf2n / modp types", // Help description.
"-bit", // Flag token.
"--nbits" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
2, // Number of args expected.
',', // Delimiter if expecting multiple args.
"Number of input tuples, for gf2n / modp types", // Help description.
"-inp", // Flag token.
"--ninputs" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
2, // Number of args expected.
',', // Delimiter if expecting multiple args.
"Number of square tuples, for gf2n / modp types", // Help description.
"-sq", // Flag token.
"--nsquares" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of inverse tuples (modp only)", // Help description.
"-inv", // Flag token.
"--ninverses" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of GF(2) triples", // Help description.
"-btrip", // Flag token.
"--nbittriples" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of GF(2) x GF(2^n) triples", // Help description.
"-mixed", // Flag token.
"--nbitgf2ntriples" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Set all values to zero, but not the shares", // Help description.
"-z", // Flag token.
"--zero" // Flag token.
);
opt.parse(argc, argv);
vector<string> badOptions;
string usage;
unsigned int i;
if(!opt.gotRequired(badOptions))
{
for (i=0; i < badOptions.size(); ++i)
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
return 1;
}
if(!opt.gotExpected(badOptions))
{
for(i=0; i < badOptions.size(); ++i)
cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
return 1;
}
int nplayers;
if (opt.firstArgs.size() == 2)
{
nplayers = atoi(opt.firstArgs[1]->c_str());
}
else if (opt.lastArgs.size() == 1)
{
nplayers = atoi(opt.lastArgs[0]->c_str());
}
else
{
cerr << "ERROR: invalid number of arguments\n";
opt.getUsage(usage);
cout << usage;
return 1;
}
int default_num = 0;
int ntrip2=0, ntripp=0, nbits2=0,nbitsp=0,nsqr2=0,nsqrp=0,ninp2=0,ninpp=0,ninv=0, nbittrip=0, nbitgf2ntrip=0;
vector<int> list_options;
int lg2, lgp;
opt.get("--lgp")->getInt(lgp);
opt.get("--lg2")->getInt(lg2);
opt.get("--default")->getInt(default_num);
ntrip2 = ntripp = nbits2 = nbitsp = nsqr2 = nsqrp = ninp2 = ninpp = ninv =
nbittrip = nbitgf2ntrip = default_num;
if (opt.isSet("--ntriples"))
{
opt.get("--ntriples")->getInts(list_options);
ntrip2 = list_options[0];
ntripp = list_options[1];
}
if (opt.isSet("--nbits"))
{
opt.get("--nbits")->getInts(list_options);
nbits2 = list_options[0];
nbitsp = list_options[1];
}
if (opt.isSet("--ninputs"))
{
opt.get("--ninputs")->getInts(list_options);
ninp2 = list_options[0];
ninpp = list_options[1];
}
if (opt.isSet("--nsquares"))
{
opt.get("--nsquares")->getInts(list_options);
nsqr2 = list_options[0];
nsqrp = list_options[1];
}
if (opt.isSet("--ninverses"))
opt.get("--ninverses")->getInt(ninv);
if (opt.isSet("--nbittriples"))
opt.get("--nbittriples")->getInt(nbittrip);
if (opt.isSet("--nbitgf2ntriples"))
opt.get("--nbitgf2ntriples")->getInt(nbitgf2ntrip);
bool zero = opt.isSet("--zero");
if (zero)
cout << "Set all values to zero" << endl;
PRNG G;
G.ReSeed();
prep_data_prefix = get_prep_dir(nplayers, lgp, lg2);
// Set up the fields
ofstream outf;
bigint p;
generate_online_setup(outf, prep_data_prefix, p, lgp, lg2);
generate_keys(prep_data_prefix, nplayers);
/* Find number players and MAC keys etc*/
gfp keyp,pp; keyp.assign_zero();
gf2n key2,p2; key2.assign_zero();
int tmpN = 0;
ifstream inpf;
// create Player-Data if not there
if (mkdir_p("Player-Data") == -1)
{
cerr << "mkdir_p(Player-Data) failed\n";
throw file_error();
}
for (i = 0; i < (unsigned int)nplayers; i++)
{
stringstream filename;
filename << prep_data_prefix << "Player-MAC-Keys-P" << i;
inpf.open(filename.str().c_str());
if (inpf.fail())
{
inpf.close();
cout << "No MAC key share for player " << i << ", generating a fresh one\n";
pp.randomize(G);
p2.randomize(G);
ofstream outf(filename.str().c_str());
if (outf.fail())
throw file_error(filename.str().c_str());
outf << nplayers << " " << pp << " " << p2;
outf.close();
cout << "Written new MAC key share to " << filename.str() << endl;
}
else
{
inpf >> tmpN; // not needed here
pp.input(inpf,true);
p2.input(inpf,true);
inpf.close();
}
cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl;
keyp.add(pp);
key2.add(p2);
}
cout << "--------------\n";
cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl;
make_mult_triples(key2,nplayers,ntrip2,"2",zero);
make_mult_triples(keyp,nplayers,ntripp,"p",zero);
make_bits(key2,nplayers,nbits2,"2",zero);
make_bits(keyp,nplayers,nbitsp,"p",zero);
make_square_tuples(key2,nplayers,nsqr2,"2",zero);
make_square_tuples(keyp,nplayers,nsqrp,"p",zero);
make_inputs(key2,nplayers,ninp2,"2",zero);
make_inputs(keyp,nplayers,ninpp,"p",zero);
make_inverse(key2,nplayers,ninv,zero);
make_inverse(keyp,nplayers,ninv,zero);
make_bit_triples(key2,nplayers,nbittrip,DATA_BITTRIPLE,zero);
make_bit_triples(key2,nplayers,nbitgf2ntrip,DATA_BITGF2NTRIPLE,zero);
make_PreMulC(key2,nplayers,ninv,zero);
make_PreMulC(keyp,nplayers,ninv,zero);
}

5
HOSTS.example Normal file
View File

@@ -0,0 +1,5 @@
192.168.0.1
192.168.0.2
192.168.0.3
192.168.0.4
192.168.0.5

19
License.txt Normal file
View File

@@ -0,0 +1,19 @@
University of Bristol : Open Access Software Licence
Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
All rights reserved
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Any use of the software for scientific publications or commercial purposes should be reported to the University of Bristol (OSI-notifications@bristol.ac.uk and quote reference 1914). This is for impact and usage monitoring purposes only.
Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk

75
Makefile Normal file
View File

@@ -0,0 +1,75 @@
# (C) 2016 University of Bristol. See License.txt
include CONFIG
MATH = $(patsubst %.cpp,%.o,$(wildcard Math/*.cpp))
TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp))
NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
AUTH = $(patsubst %.cpp,%.o,$(wildcard Auth/*.cpp))
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp))
# OT stuff needs GF2N_LONG, so only compile if this is enabled
ifeq ($(USE_GF2N_LONG),1)
OT = $(patsubst %.cpp,%.o,$(filter-out OT/OText_main.cpp,$(wildcard OT/*.cpp)))
OT_EXE = ot.x ot-offline.x
endif
COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH)
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE)
LIB = libSPDZ.a
LIBSIMPLEOT = SimpleOT/libsimpleot.a
all: gen_input online offline
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x
offline: $(OT_EXE) Check-Offline.x
gen_input: gen_input_f2n.x gen_input_fp.x
Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS)
Server.x: Server.cpp $(COMMON)
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)
Player-Online.x: Player-Online.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) Player-Online.cpp -o Player-Online.x $(COMMON) $(PROCESSOR) $(LDLIBS)
ifeq ($(USE_GF2N_LONG),1)
ot.x: $(OT) $(COMMON) OT/OText_main.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT)
ot-check.x: $(OT) $(COMMON)
$(CXX) $(CFLAGS) -o ot-check.x OT/BitVector.o OT/OutputCheck.cpp $(COMMON) $(LDLIBS)
ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp
$(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o OT/BitVector.o $(COMMON) $(LDLIBS)
ot-offline.x: $(OT) $(COMMON) ot-offline.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT)
endif
check-passive.x: $(COMMON) check-passive.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON)
$(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS)
gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON)
$(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS)
clean:
-rm */*.o *.o *.x core.* *.a gmon.out

24
Math/Integer.cpp Normal file
View File

@@ -0,0 +1,24 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Integer.cpp
*
*/
#include "Integer.h"
void Integer::output(ostream& s,bool human) const
{
if (human)
s << a;
else
s.write((char*)&a, sizeof(a));
}
void Integer::input(istream& s,bool human)
{
if (human)
s >> a;
else
s.read((char*)&a, sizeof(a));
}

34
Math/Integer.h Normal file
View File

@@ -0,0 +1,34 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Integer.h
*
*/
#ifndef INTEGER_H_
#define INTEGER_H_
#include <iostream>
using namespace std;
// Wrapper class for integer, used for Memory
class Integer
{
long a;
public:
Integer() { a = 0; }
Integer(long a) : a(a) {}
long get() const { return a; }
void assign_zero() { a = 0; }
void output(ostream& s,bool human) const;
void input(istream& s,bool human);
};
#endif /* INTEGER_H_ */

148
Math/Setup.cpp Normal file
View File

@@ -0,0 +1,148 @@
// (C) 2016 University of Bristol. See License.txt
#include "Math/Setup.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Tools/mkpath.h"
#include <fstream>
/*
* Just setup the primes, doesn't need NTL.
* Sets idx and m to be used by SHE setup if necessary
*/
void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m)
{
cout << "Setting up parameters" << endl;
switch (lgp)
{ case -1:
m=16;
idx=1; // Any old figures will do, but need to be for lgp at last
lgp=32; // Switch to bigger prime to get parameters
break;
case 32:
m=8192;
idx=0;
break;
case 64:
m=16384;
idx=1;
break;
case 128:
m=32768;
idx=2;
break;
case 256:
m=32768;
idx=3;
break;
case 512:
m=65536;
idx=4;
break;
default:
throw invalid_params();
break;
}
cout << "m = " << m << endl;
// Here we choose a prime which is the order of a BN curve
// - Reason is that there are some applications where this
// would be a good idea. So I have hard coded it in here
// - This is pointless/impossible for lgp=32, 64 so for
// these do something naive
// - Have not tested 256 and 512
bigint u;
int ex;
if (lgp!=32 && lgp!=64)
{ u=1; u=u<<(lgp-1); u=sqrt(sqrt(u/36))/m;
u=u*m;
bigint q;
// cout << ex << " " << u << " " << numBits(u) << endl;
p=(((36*u+36)*u+18)*u+6)*u+1; // The group order of a BN curve
q=(((36*u+36)*u+24)*u+6)*u+1; // The base field size of a BN curve
while (!probPrime(p) || !probPrime(q) || numBits(p)<lgp)
{ u=u+m;
p=(((36*u+36)*u+18)*u+6)*u+1;
q=(((36*u+36)*u+24)*u+6)*u+1;
}
}
else
{ ex=lgp-numBits(m);
u=1; u=(u<<ex)*m; p=u+1;
while (!probPrime(p) || numBits(p)<lgp)
{ u=u+m; p=u+1; }
}
cout << "\t p = " << p << " u = " << u << " : ";
cout << lgp << " <= " << numBits(p) << endl;
}
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2)
{
int idx, m;
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
stringstream ss;
ss << dirname;
cout << "Writing to file in " << ss.str() << endl;
// create preprocessing dir. if necessary
if (mkdir_p(ss.str().c_str()) == -1)
{
cerr << "mkdir_p(" << ss.str() << ") failed\n";
throw file_error();
}
// Output the data
ss << "/Params-Data";
outf.open(ss.str().c_str());
// Only need p and lg2 for online phase
outf << p << endl;
// Fix as a negative lg2 is a ``signal'' to choose slightly weaker
// LWE parameters
outf << abs(lg2) << endl;
gfp::init_field(p, true);
gf2n::init_field(lg2);
}
string get_prep_dir(int nparties, int lg2p, int gf2ndegree)
{
stringstream ss;
ss << PREP_DIR << nparties << "-" << lg2p << "-" << gf2ndegree << "/";
return ss.str();
}
// Only read enough to initialize the fields (i.e. for OT offline or online phase only)
void read_setup(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cerr << "loading params from: " << filename << endl;
// backwards compatibility hack
if (dir_prefix.compare("") == 0)
filename = string("Player-Data/Params-Data");
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
void read_setup(int nparties, int lg2p, int gf2ndegree)
{
string dir = get_prep_dir(nparties, lg2p, gf2ndegree);
read_setup(dir);
}

35
Math/Setup.h Normal file
View File

@@ -0,0 +1,35 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Setup.h
*
*/
#ifndef MATH_SETUP_H_
#define MATH_SETUP_H_
#include "Math/bigint.h"
#include <iostream>
using namespace std;
/*
* Routines to create and read setup files for the finite fields
*/
// Create setup file for gfp and gf2n
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2);
// Setup primes only
// Chooses a p of at least lgp bits
void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m);
// get main directory for prep. data
string get_prep_dir(int nparties, int lg2p, int gf2ndegree);
// Read online setup file for gfp and gf2n
void read_setup(const string& dir_prefix);
void read_setup(int nparties, int lg2p, int gf2ndegree);
#endif /* MATH_SETUP_H_ */

126
Math/Share.cpp Normal file
View File

@@ -0,0 +1,126 @@
// (C) 2016 University of Bristol. See License.txt
#include "Share.h"
//#include "Tools/random.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Math/operators.h"
template<class T>
Share<T>::Share(const T& aa, int my_num, const T& alphai)
{
if (my_num == 0)
a = aa;
else
a.assign_zero();
mac = aa * alphai;
}
template<class T>
void Share<T>::mul_by_bit(const Share<T>& S,const T& aa)
{
a.mul(S.a,aa);
mac.mul(S.mac,aa);
}
template<>
void Share<gf2n>::mul_by_bit(const Share<gf2n>& S, const gf2n& aa)
{
a.mul_by_bit(S.a,aa);
mac.mul_by_bit(S.mac,aa);
}
template<class T>
void Share<T>::add(const Share<T>& S,const T& aa,bool playerone,const T& alphai)
{
if (playerone)
{ a.add(S.a,aa); }
else
{ a=S.a; }
T tmp;
tmp.mul(alphai,aa);
mac.add(S.mac,tmp);
}
template<class T>
void Share<T>::sub(const Share<T>& S,const T& aa,bool playerone,const T& alphai)
{
if (playerone)
{ a.sub(S.a,aa); }
else
{ a=S.a; }
T tmp;
tmp.mul(alphai,aa);
mac.sub(S.mac,tmp);
}
template<class T>
void Share<T>::sub(const T& aa,const Share<T>& S,bool playerone,const T& alphai)
{
if (playerone)
{ a.sub(aa,S.a); }
else
{ a=S.a;
a.negate();
}
T tmp;
tmp.mul(alphai,aa);
mac.sub(tmp,S.mac);
}
template<class T>
void Share<T>::sub(const Share<T>& S1,const Share<T>& S2)
{
a.sub(S1.a,S2.a);
mac.sub(S1.mac,S2.mac);
}
template<class T>
T combine(const vector< Share<T> >& S)
{
T ans=S[0].a;
for (unsigned int i=1; i<S.size(); i++)
{ ans.add(ans,S[i].a); }
return ans;
}
template<class T>
bool check_macs(const vector< Share<T> >& S,const T& key)
{
T val=combine(S);
// Now check the MAC is valid
val.mul(val,key);
for (unsigned i=0; i<S.size(); i++)
{ val.sub(val,S[i].mac); }
if (!val.is_zero()) { return false; }
return true;
}
template class Share<gf2n>;
template class Share<gfp>;
template gf2n combine(const vector< Share<gf2n> >& S);
template gfp combine(const vector< Share<gfp> >& S);
template bool check_macs(const vector< Share<gf2n> >& S,const gf2n& key);
template bool check_macs(const vector< Share<gfp> >& S,const gfp& key);
#ifdef USE_GF2N_LONG
template class Share<gf2n_short>;
template gf2n_short combine(const vector< Share<gf2n_short> >& S);
template bool check_macs(const vector< Share<gf2n_short> >& S,const gf2n_short& key);
#endif

117
Math/Share.h Normal file
View File

@@ -0,0 +1,117 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Share
#define _Share
/* Class for holding a share of either a T or gfp element */
#include <vector>
#include <iostream>
using namespace std;
#include "Math/gfp.h"
#include "Math/gf2n.h"
// Forward declaration as apparently this is needed for friends in templates
template<class T> class Share;
template<class T> T combine(const vector< Share<T> >& S);
template<class T> bool check_macs(const vector< Share<T> >& S,const T& key);
template<class T>
class Share
{
T a; // The share
T mac; // Shares of the mac
public:
typedef T value_type;
static int size()
{ return 2 * T::size(); }
static string type_string()
{ return T::type_string(); }
void assign(const Share<T>& S)
{ a=S.a; mac=S.mac; }
void assign(const char* buffer)
{ a.assign(buffer); mac.assign(buffer + T::size()); }
void assign_zero()
{ a.assign_zero();
mac.assign_zero();
}
Share() { assign_zero(); }
Share(const Share<T>& S) { assign(S); }
Share(const T& aa, int my_num, const T& alphai);
~Share() { ; }
Share& operator=(const Share<T>& S)
{ if (this!=&S) { assign(S); }
return *this;
}
const T& get_share() const { return a; }
const T& get_mac() const { return mac; }
void set_share(const T& aa) { a=aa; }
void set_mac(const T& aa) { mac=aa; }
/* Arithmetic Routines */
void mul(const Share<T>& S,const T& aa);
void mul_by_bit(const Share<T>& S,const T& aa);
void add(const Share<T>& S,const T& aa,bool playerone,const T& alphai);
void negate() { a.negate(); mac.negate(); }
void sub(const Share<T>& S,const T& aa,bool playerone,const T& alphai);
void sub(const T& aa,const Share<T>& S,bool playerone,const T& alphai);
void add(const Share<T>& S1,const Share<T>& S2);
void sub(const Share<T>& S1,const Share<T>& S2);
void add(const Share<T>& S1) { add(*this,S1); }
// Input and output from a stream
// - Can do in human or machine only format (later should be faster)
void output(ostream& s,bool human) const
{ a.output(s,human); if (human) { s << " "; }
mac.output(s,human);
}
void input(istream& s,bool human)
{ a.input(s,human);
mac.input(s,human);
}
/* Takes a vector of shares, one from each player and
* determines the shared value
* - i.e. Partially open the shares
*/
friend T combine<T>(const vector< Share<T> >& S);
/* Given a set of shares, one from each player and
* the global key, determines if the sharing is valid
* - Mainly for test purposes
*/
friend bool check_macs<T>(const vector< Share<T> >& S,const T& key);
};
// specialized mul by bit for gf2n
template <>
void Share<gf2n>::mul_by_bit(const Share<gf2n>& S,const gf2n& aa);
template <class T>
Share<T> operator*(const T& y, const Share<T>& x) { Share<T> res; res.mul(x, y); return res; }
template<class T>
inline void Share<T>::add(const Share<T>& S1,const Share<T>& S2)
{
a.add(S1.a,S2.a);
mac.add(S1.mac,S2.mac);
}
template<class T>
inline void Share<T>::mul(const Share<T>& S,const T& aa)
{
a.mul(S.a,aa);
mac.mul(S.mac,aa);
}
#endif

138
Math/Zp_Data.cpp Normal file
View File

@@ -0,0 +1,138 @@
// (C) 2016 University of Bristol. See License.txt
#include "Zp_Data.h"
void Zp_Data::init(const bigint& p,bool mont)
{ pr=p;
mask=(1<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1;
montgomery=mont;
t=mpz_size(pr.get_mpz_t());
if (t>=MAX_MOD_SZ)
throw max_mod_sz_too_small(t+1);
if (montgomery)
{ mpn_zero(R,MAX_MOD_SZ);
mpn_zero(R2,MAX_MOD_SZ);
mpn_zero(R3,MAX_MOD_SZ);
bigint r=2,pp=pr;
mpz_pow_ui(r.get_mpz_t(),r.get_mpz_t(),t*8*sizeof(mp_limb_t));
mpz_invert(pp.get_mpz_t(),pr.get_mpz_t(),r.get_mpz_t());
pp=r-pp; // pi=-1/p mod R
pi=(pp.get_mpz_t()->_mp_d)[0];
r=r%pr;
mpn_copyi(R,r.get_mpz_t()->_mp_d,mpz_size(r.get_mpz_t()));
bigint r2=(r*r)%pr;
mpn_copyi(R2,r2.get_mpz_t()->_mp_d,mpz_size(r2.get_mpz_t()));
bigint r3=(r2*r)%pr;
mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t()));
if (sizeof(unsigned long)!=sizeof(mp_limb_t))
{ cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl;
throw not_implemented();
}
}
mpn_zero(prA,MAX_MOD_SZ);
mpn_copyi(prA,pr.get_mpz_t()->_mp_d,t);
}
void Zp_Data::assign(const Zp_Data& Zp)
{ pr=Zp.pr;
mask=Zp.mask;
montgomery=Zp.montgomery;
t=Zp.t;
mpn_copyi(R,Zp.R,t+1);
mpn_copyi(R2,Zp.R2,t+1);
mpn_copyi(R3,Zp.R3,t+1);
pi=Zp.pi;
mpn_copyi(prA,Zp.prA,t+1);
}
void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
{
mp_limb_t borrow = mpn_sub_n(ans,x,y,t);
if (borrow!=0)
mpn_add_n(ans,ans,prA,t);
}
__m128i Zp_Data::get_random128(PRNG& G)
{
while (true)
{
__m128i res = G.get_doubleword();
if (mpn_cmp((mp_limb_t*)&res, prA, t) < 0)
return res;
}
}
#include <stdlib.h>
void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
{
if (x[t]!=0 || y[t]!=0) { cout << "Mont_Mult Bug" << endl; abort(); }
mp_limb_t ans[2*MAX_MOD_SZ],u;
// First loop
u=x[0]*y[0]*pi;
ans[t] = mpn_mul_1(ans,y,t,x[0]);
ans[t+1] = mpn_addmul_1(ans,prA,t+1,u);
for (int i=1; i<t; i++)
{ // u=(ans0+xi*y0)*pd
u=(ans[i]+x[i]*y[0])*pi;
// ans=ans+xi*y+u*pr
ans[t+i+1]=mpn_addmul_1(ans+i,y,t+1,x[i]);
ans[t+i+1]+=mpn_addmul_1(ans+i,prA,t+1,u);
}
// if (ans>=pr) { ans=z-pr; }
// else { z=ans; }
if (mpn_cmp(ans+t,prA,t+1)>=0)
{ mpn_sub_n(z,ans+t,prA,t); }
else
{ mpn_copyi(z,ans+t,t); }
}
ostream& operator<<(ostream& s,const Zp_Data& ZpD)
{
s << ZpD.pr << " " << ZpD.montgomery << endl;
if (ZpD.montgomery)
{ s << ZpD.t << " " << ZpD.pi << endl;
for (int i=0; i<ZpD.t; i++) { s << ZpD.R[i] << " "; }
s << endl;
for (int i=0; i<ZpD.t; i++) { s << ZpD.R2[i] << " "; }
s << endl;
for (int i=0; i<ZpD.t; i++) { s << ZpD.R3[i] << " "; }
s << endl;
for (int i=0; i<ZpD.t; i++) { s << ZpD.prA[i] << " "; }
s << endl;
}
return s;
}
istream& operator>>(istream& s,Zp_Data& ZpD)
{
s >> ZpD.pr >> ZpD.montgomery;
if (ZpD.montgomery)
{ s >> ZpD.t >> ZpD.pi;
if (ZpD.t>=MAX_MOD_SZ)
throw max_mod_sz_too_small(ZpD.t+1);
mpn_zero(ZpD.R,MAX_MOD_SZ);
mpn_zero(ZpD.R2,MAX_MOD_SZ);
mpn_zero(ZpD.R3,MAX_MOD_SZ);
mpn_zero(ZpD.prA,MAX_MOD_SZ);
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R[i]; }
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R2[i]; }
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R3[i]; }
for (int i=0; i<ZpD.t; i++) { s >> ZpD.prA[i]; }
}
return s;
}

129
Math/Zp_Data.h Normal file
View File

@@ -0,0 +1,129 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Zp_Data
#define _Zp_Data
/* Class to define helper information for a Zp element
*
* Basically the data needed for Montgomery operations
*
* Almost all data is public as this is basically a container class
*
*/
#include "Math/bigint.h"
#include "Tools/random.h"
#include <smmintrin.h>
#include <iostream>
using namespace std;
#ifndef MAX_MOD_SZ
#ifdef LargeM
#define MAX_MOD_SZ 20
#else
#define MAX_MOD_SZ 3
#endif
#endif
class modp;
class Zp_Data
{
bool montgomery; // True if we are using Montgomery arithmetic
mp_limb_t R[MAX_MOD_SZ],R2[MAX_MOD_SZ],R3[MAX_MOD_SZ],pi;
mp_limb_t prA[MAX_MOD_SZ];
int t; // More Montgomery data
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
public:
bigint pr;
mp_limb_t mask;
void assign(const Zp_Data& Zp);
void init(const bigint& p,bool mont=true);
int get_t() const { return t; }
const mp_limb_t* get_prA() const { return prA; }
// This one does nothing, needed so as to make vectors of Zp_Data
Zp_Data() : montgomery(0), pi(0), mask(0) { t=1; }
// The main init funciton
Zp_Data(const bigint& p,bool mont=true)
{ init(p,mont); }
Zp_Data(const Zp_Data& Zp) { assign(Zp); }
Zp_Data& operator=(const Zp_Data& Zp)
{ if (this!=&Zp) { assign(Zp); }
return *this;
}
~Zp_Data() { ; }
template <int T>
void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
void Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
__m128i get_random128(PRNG& G);
friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce);
friend void to_modp(modp& ans,int x,const Zp_Data& ZpD);
friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD);
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD);
friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD);
friend void assignOne(modp& x,const Zp_Data& ZpD);
friend void assignZero(modp& x,const Zp_Data& ZpD);
friend bool isZero(const modp& x,const Zp_Data& ZpD);
friend bool isOne(const modp& x,const Zp_Data& ZpD);
friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD);
friend class modp;
friend ostream& operator<<(ostream& s,const Zp_Data& ZpD);
friend istream& operator>>(istream& s,Zp_Data& ZpD);
};
template<>
inline void Zp_Data::Add<2>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
{
__uint128_t a, b, p;
memcpy(&a, x, sizeof(__uint128_t));
memcpy(&b, y, sizeof(__uint128_t));
memcpy(&p, prA, sizeof(__uint128_t));
__uint128_t c = a + b;
asm goto ("jc %l[sub]" :::: sub);
if (c >= p)
sub:
c -= p;
memcpy(ans, &c, sizeof(__uint128_t));
}
template<>
inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
{
mp_limb_t carry = mpn_add_n(ans,x,y,t);
if (carry!=0 || mpn_cmp(ans,prA,t)>=0)
{ mpn_sub_n(ans,ans,prA,t); }
}
inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
{
if (t == 2)
return Add<2>(ans, x, y);
else
return Add<0>(ans, x, y);
}
#endif

95
Math/bigint.cpp Normal file
View File

@@ -0,0 +1,95 @@
// (C) 2016 University of Bristol. See License.txt
#include "bigint.h"
#include "Exceptions/Exceptions.h"
bigint sqrRootMod(const bigint& a,const bigint& p)
{
bigint ans;
if (a==0) { ans=0; return ans; }
if (mpz_tstbit(p.get_mpz_t(),1)==1)
{ // First do case with p=3 mod 4
bigint exp=(p+1)/4;
mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t());
}
else
{ // Shanks algorithm
gmp_randclass Gen(gmp_randinit_default);
Gen.seed(0);
bigint x,y,n,q,t,b,temp;
// Find n such that (n/p)=-1
int leg=1;
while (leg!=-1)
{ n=Gen.get_z_range(p);
leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t());
}
// Split p-1 = 2^e q
q=p-1;
int e=0;
while (mpz_even_p(q.get_mpz_t()))
{ e++; q=q/2; }
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
int r=e;
mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t());
temp=(q-1)/2;
mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t());
// b=a*x^2 mod p, x=a*x mod p
b=(a*x*x)%p;
x=(a*x)%p;
// While b!=1 do
while (b!=1)
{ // Find smallest m such that b^(2^m)=1 mod p
int m=1;
temp=(b*b)%p;
while (temp!=1)
{ temp=(temp*temp)%p; m++; }
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
t=y;
for (int i=0; i<r-m-1; i++)
{ t=(t*t)%p; }
y=(t*t)%p;
r=m;
// x=x*t mod p, b=b*y mod p
x=(x*t)%p;
b=(b*y)%p;
}
ans=x;
}
return ans;
}
bigint powerMod(const bigint& x,const bigint& e,const bigint& p)
{
bigint ans;
if (e>=0)
{ mpz_powm(ans.get_mpz_t(),x.get_mpz_t(),e.get_mpz_t(),p.get_mpz_t()); }
else
{ bigint xi,ei=-e;
invMod(xi,x,p);
mpz_powm(ans.get_mpz_t(),xi.get_mpz_t(),ei.get_mpz_t(),p.get_mpz_t());
}
return ans;
}
int powerMod(int x,int e,int p)
{
if (e==1) { return x; }
if (e==0) { return 1; }
if (e<0)
{ throw not_implemented(); }
int t=x,ans=1;
while (e!=0)
{ if ((e&1)==1) { ans=(ans*t)%p; }
e>>=1;
t=(t*t)%p;
}
return ans;
}

116
Math/bigint.h Normal file
View File

@@ -0,0 +1,116 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _bigint
#define _bigint
#include <iostream>
using namespace std;
#include <stddef.h>
#include <mpirxx.h>
typedef mpz_class bigint;
#include "Exceptions/Exceptions.h"
#include "Tools/int.h"
/**********************************
* Utility Functions *
**********************************/
inline int gcd(const int x,const int y)
{
bigint xx=x;
return mpz_gcd_ui(NULL,xx.get_mpz_t(),y);
}
inline bigint gcd(const bigint& x,const bigint& y)
{
bigint g;
mpz_gcd(g.get_mpz_t(),x.get_mpz_t(),y.get_mpz_t());
return g;
}
inline void invMod(bigint& ans,const bigint& x,const bigint& p)
{
mpz_invert(ans.get_mpz_t(),x.get_mpz_t(),p.get_mpz_t());
}
inline int numBits(const bigint& m)
{
return mpz_sizeinbase(m.get_mpz_t(),2);
}
inline int numBits(int m)
{
bigint te=m;
return mpz_sizeinbase(te.get_mpz_t(),2);
}
inline int numBytes(const bigint& m)
{
return mpz_sizeinbase(m.get_mpz_t(),256);
}
inline int probPrime(const bigint& x)
{
gmp_randstate_t rand_state;
gmp_randinit_default(rand_state);
int ans=mpz_probable_prime_p(x.get_mpz_t(),rand_state,40,0);
gmp_randclear(rand_state);
return ans;
}
inline void bigintFromBytes(bigint& x,octet* bytes,int len)
{
mpz_import(x.get_mpz_t(),len,1,sizeof(octet),0,0,bytes);
}
inline void bytesFromBigint(octet* bytes,const bigint& x,unsigned int len)
{
size_t ll;
mpz_export(bytes,&ll,1,sizeof(octet),0,0,x.get_mpz_t());
if (ll>len)
{ throw invalid_length(); }
for (unsigned int i=ll; i<len; i++)
{ bytes[i]=0; }
}
inline int isOdd(const bigint& x)
{
return mpz_odd_p(x.get_mpz_t());
}
bigint sqrRootMod(const bigint& x,const bigint& p);
bigint powerMod(const bigint& x,const bigint& e,const bigint& p);
// Assume e>=0
int powerMod(int x,int e,int p);
inline int Hwt(int N)
{
int result=0;
while(N)
{ result++;
N&=(N-1);
}
return result;
}
#endif

15
Math/field_types.h Normal file
View File

@@ -0,0 +1,15 @@
// (C) 2016 University of Bristol. See License.txt
/*
* types.h
*
*/
#ifndef MATH_FIELD_TYPES_H_
#define MATH_FIELD_TYPES_H_
enum DataFieldType { DATA_MODP, DATA_GF2N, N_DATA_FIELD_TYPE };
#endif /* MATH_FIELD_TYPES_H_ */

345
Math/gf2n.cpp Normal file
View File

@@ -0,0 +1,345 @@
// (C) 2016 University of Bristol. See License.txt
#include "Math/gf2n.h"
#include "Exceptions/Exceptions.h"
#include <stdint.h>
#include <wmmintrin.h>
#include <xmmintrin.h>
#include <emmintrin.h>
int gf2n_short::n;
int gf2n_short::t1;
int gf2n_short::t2;
int gf2n_short::t3;
int gf2n_short::l0;
int gf2n_short::l1;
int gf2n_short::l2;
int gf2n_short::l3;
int gf2n_short::nterms;
word gf2n_short::mask;
bool gf2n_short::useC;
bool gf2n_short::rewind = false;
word gf2n_short_table[256][256];
#define num_2_fields 4
/* Require
* 2*(n-1)-64+t1<64
*/
int fields_2[num_2_fields][4] = {
{4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10}
};
void gf2n_short::init_tables()
{
if (sizeof(word)!=8)
{ cout << "Word size is wrong" << endl;
throw not_implemented();
}
int i,j;
for (i=0; i<256; i++)
{ for (j=0; j<256; j++)
{ word ii=i,jj=j;
gf2n_short_table[i][j]=0;
while (ii!=0)
{ if ((ii&1)==1) { gf2n_short_table[i][j]^=jj; }
jj<<=1;
ii>>=1;
}
}
}
}
void gf2n_short::init_field(int nn)
{
gf2n_short::init_tables();
int i,j=-1;
for (i=0; i<num_2_fields && j==-1; i++)
{ if (nn==fields_2[i][0]) { j=i; } }
if (j==-1) { throw invalid_params(); }
n=nn;
nterms=1;
l0=64-n;
t1=fields_2[j][1];
l1=64+t1-n;
if (fields_2[j][2]!=0)
{ nterms=3;
t2=fields_2[j][2];
l2=64+t2-n;
t3=fields_2[j][3];
l3=64+t3-n;
}
if (2*(n-1)-64+t1>=64) { throw invalid_params(); }
mask=(1ULL<<n)-1;
useC=(Check_CPU_support_AES()==0);
}
/* Takes 16bit x and y and returns the 32 bit product in c1 and c0
ans = (c1<<16)^c0
where c1 and c0 are 16 bit
*/
inline void mul16(word x,word y,word& c0,word& c1)
{
word a1=x&(0xFF), b1=y&(0xFF);
word a2=x>>8, b2=y>>8;
c0=gf2n_short_table[a1][b1];
c1=gf2n_short_table[a2][b2];
word te=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1];
c0^=(te&0xFF)<<8;
c1^=te>>8;
}
/* Takes 16 bit x and y and returns the 32 bit product */
inline word mul16(word x,word y)
{
word a1=x&(0xFF), b1=y&(0xFF);
word a2=x>>8, b2=y>>8;
word ans=gf2n_short_table[a2][b2]<<8;
ans^=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1];
ans<<=8;
ans^=gf2n_short_table[a1][b1];
return ans;
}
/* Takes 16 bit x the 32 bit square */
inline word sqr16(word x)
{
word a1=x&(0xFF),a2=x>>8;
word ans=gf2n_short_table[a2][a2]<<16;
ans^=gf2n_short_table[a1][a1];
return ans;
}
void gf2n_short::reduce_trinomial(word xh,word xl)
{
// Deal with xh first
a=xl;
a^=(xh<<l0);
a^=(xh<<l1);
// Now deal with last word
word hi=a>>n;
while (hi!=0)
{ a&=mask;
a^=hi;
a^=(hi<<t1);
hi=a>>n;
}
}
void gf2n_short::reduce_pentanomial(word xh,word xl)
{
// Deal with xh first
a=xl;
a^=(xh<<l0);
a^=(xh<<l1);
a^=(xh<<l2);
a^=(xh<<l3);
// Now deal with last word
word hi=a>>n;
while (hi!=0)
{ a&=mask;
a^=hi;
a^=(hi<<t1);
a^=(hi<<t2);
a^=(hi<<t3);
hi=a>>n;
}
}
void mul32(word x,word y,word& ans)
{
word a1=x&(0xFFFF),b1=y&(0xFFFF);
word a2=x>>16, b2=y>>16;
word c0,c1;
ans=mul16(a1,b1);
word upp=mul16(a2,b2);
mul16(a1,b2,c0,c1);
ans^=c0<<16; upp^=c1;
mul16(a2,b1,c0,c1);
ans^=c0<<16; upp^=c1;
ans^=(upp<<32);
}
void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y)
{
word hi,lo;
if (gf2n_short::useC)
{ /* Uses Karatsuba */
word c,d,e,t;
word xl=x.a&0xFFFFFFFF,yl=y.a&0xFFFFFFFF;
word xh=x.a>>32,yh=y.a>>32;
mul32(xl,yl,c);
mul32(xh,yh,d);
mul32((xl^xh),(yl^yh),e);
t=c^e^d;
lo=c^(t<<32);
hi=d^(t>>32);
}
else
{ /* Use Intel Instructions */
__m128i xx,yy,zz;
uint64_t c[] __attribute__((aligned (16))) = { 0,0 };
xx=_mm_set1_epi64x(x.a);
yy=_mm_set1_epi64x(y.a);
zz=_mm_clmulepi64_si128(xx,yy,0);
_mm_store_si128((__m128i*)c,zz);
lo=c[0];
hi=c[1];
}
reduce(hi,lo);
}
inline void sqr32(word x,word& ans)
{
word a1=x&(0xFFFF),a2=x>>16;
ans=sqr16(a1)^(sqr16(a2)<<32);
}
void gf2n_short::square()
{
word xh,xl;
sqr32(a&0xFFFFFFFF,xl);
sqr32(a>>32,xh);
reduce(xh,xl);
}
void gf2n_short::square(const gf2n_short& bb)
{
word xh,xl;
sqr32(bb.a&0xFFFFFFFF,xl);
sqr32(bb.a>>32,xh);
reduce(xh,xl);
}
void gf2n_short::invert()
{
if (is_one()) { return; }
if (is_zero()) { throw division_by_zero(); }
word u,v=a,B=0,D=1,mod=1;
mod^=(1ULL<<n);
mod^=(1ULL<<t1);
if (nterms==3)
{ mod^=(1ULL<<t2);
mod^=(1ULL<<t3);
}
u=mod; v=a;
while (u!=0)
{ while ((u&1)==0)
{ u>>=1;
if ((B&1)!=0) { B^=mod; }
B>>=1;
}
while ((v&1)==0 && v!=0)
{ v>>=1;
if ((D&1)!=0) { D^=mod; }
D>>=1;
}
if (u>=v) { u=u^v; B=B^D; }
else { v=v^u; D=D^B; }
}
a=D;
}
void gf2n_short::power(long i)
{
long n=i;
if (n<0) { invert(); n=-n; }
gf2n_short T=*this;
assign_one();
while (n!=0)
{ if ((n&1)!=0) { mul(*this,T); }
n>>=1;
T.square();
}
}
void gf2n_short::randomize(PRNG& G)
{
a=G.get_uint();
a=(a<<32)^G.get_uint();
a&=mask;
}
void gf2n_short::output(ostream& s,bool human) const
{
if (human)
{ s << hex << a << dec << " "; }
else
{ s.write((char*) &a,sizeof(word)); }
}
void gf2n_short::input(istream& s,bool human)
{
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
}
//throw end_of_file();
s.clear(); // unset EOF flag
s.seekg(0);
if (!rewind)
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
rewind = true;
}
if (human)
{ s >> hex >> a >> dec; }
else
{ s.read((char*) &a,sizeof(word)); }
a &= mask;
}

191
Math/gf2n.h Normal file
View File

@@ -0,0 +1,191 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _gf2n
#define _gf2n
#include <stdlib.h>
#include <string.h>
#include <iostream>
using namespace std;
#include "Tools/random.h"
#include "Math/gf2nlong.h"
#include "Math/field_types.h"
/* This interface compatible with the gfp interface
* which then allows us to template the Share
* data type.
*/
/*
Arithmetic in Gf_{2^n} with n<64
*/
class gf2n_short
{
word a;
static int n,t1,t2,t3,nterms;
static int l0,l1,l2,l3;
static word mask;
static bool useC;
static bool rewind;
/* Assign x[0..2*nwords] to a and reduce it... */
void reduce_trinomial(word xh,word xl);
void reduce_pentanomial(word xh,word xl);
void reduce(word xh,word xl)
{ if (nterms==3)
{ reduce_pentanomial(xh,xl); }
else
{ reduce_trinomial(xh,xl); }
}
static void init_tables();
public:
typedef gf2n_short value_type;
typedef word internal_type;
static void init_field(int nn);
static int degree() { return n; }
static int get_nterms() { return nterms; }
static int get_t(int i)
{ if (i==0) { return t1; }
else if (i==1) { return t2; }
else if (i==2) { return t3; }
return -1;
}
static DataFieldType field_type() { return DATA_GF2N; }
static char type_char() { return '2'; }
static string type_string() { return "gf2n"; }
static int size() { return sizeof(a); }
static int t() { return 0; }
word get() const { return a; }
word get_word() const { return a; }
void assign(const gf2n_short& g) { a=g.a; }
void assign_zero() { a=0; }
void assign_one() { a=1; }
void assign_x() { a=2; }
void assign(word aa) { a=aa&mask; }
void assign(int aa) { a=static_cast<unsigned int>(aa)&mask; }
void assign(const char* buffer) { a = *(word*)buffer; }
int get_bit(int i) const
{ return (a>>i)&1; }
void set_bit(int i,unsigned int b)
{ if (b==1)
{ a |= (1UL<<i); }
else
{ a &= ~(1UL<<i); }
}
gf2n_short() { a=0; }
gf2n_short(const gf2n_short& g) { assign(g); }
gf2n_short(int g) { assign(g); }
~gf2n_short() { ; }
gf2n_short& operator=(const gf2n_short& g)
{ assign(g);
return *this;
}
int is_zero() const { return (a==0); }
int is_one() const { return (a==1); }
int equal(const gf2n_short& y) const { return (a==y.a); }
bool operator==(const gf2n_short& y) const { return a==y.a; }
bool operator!=(const gf2n_short& y) const { return a!=y.a; }
// x+y
void add(const gf2n_short& x,const gf2n_short& y)
{ a=x.a^y.a; }
void add(const gf2n_short& x)
{ a^=x.a; }
template<int T>
void add(octet* x)
{ a^=*(word*)(x); }
void add(octet* x)
{ add<0>(x); }
void sub(const gf2n_short& x,const gf2n_short& y)
{ a=x.a^y.a; }
void sub(const gf2n_short& x)
{ a^=x.a; }
// = x * y
void mul(const gf2n_short& x,const gf2n_short& y);
void mul(const gf2n_short& x) { mul(*this,x); }
// x * y when one of x,y is a bit
void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; }
gf2n_short operator+(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; }
gf2n_short operator*(const gf2n_short& x) { gf2n_short res; res.mul(*this, x); return res; }
gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; }
gf2n_short& operator*=(const gf2n_short& x) { mul(x); return *this; }
gf2n_short operator-(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; }
gf2n_short& operator-=(const gf2n_short& x) { sub(x); return *this; }
void square();
void square(const gf2n_short& aa);
void invert();
void invert(const gf2n_short& aa)
{ *this=aa; invert(); }
void negate() { return; }
void power(long i);
/* Bitwise Ops */
void AND(const gf2n_short& x,const gf2n_short& y) { a=x.a&y.a; }
void XOR(const gf2n_short& x,const gf2n_short& y) { a=x.a^y.a; }
void OR(const gf2n_short& x,const gf2n_short& y) { a=x.a|y.a; }
void NOT(const gf2n_short& x) { a=(~x.a)&mask; }
void SHL(const gf2n_short& x,int n) { a=(x.a<<n)&mask; }
void SHR(const gf2n_short& x,int n) { a=x.a>>n; }
gf2n_short operator&(const gf2n_short& x) { gf2n_short res; res.AND(*this, x); return res; }
gf2n_short operator^(const gf2n_short& x) { gf2n_short res; res.XOR(*this, x); return res; }
gf2n_short operator|(const gf2n_short& x) { gf2n_short res; res.OR(*this, x); return res; }
gf2n_short operator!() { gf2n_short res; res.NOT(*this); return res; }
gf2n_short operator<<(int i) { gf2n_short res; res.SHL(*this, i); return res; }
gf2n_short operator>>(int i) { gf2n_short res; res.SHR(*this, i); return res; }
/* Crap RNG */
void randomize(PRNG& G);
// compatibility with gfp
void almost_randomize(PRNG& G) { randomize(G); }
void output(ostream& s,bool human) const;
void input(istream& s,bool human);
friend ostream& operator<<(ostream& s,const gf2n_short& x)
{ s << hex << "0x" << x.a << dec;
return s;
}
friend istream& operator>>(istream& s,gf2n_short& x)
{ s >> hex >> x.a >> dec;
return s;
}
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const
{ o.append((octet*) &a,sizeof(word)); }
void unpack(octetStream& o)
{ o.consume((octet*) &a,sizeof(word)); }
};
#ifdef USE_GF2N_LONG
typedef gf2n_long gf2n;
#else
typedef gf2n_short gf2n;
#endif
#endif

277
Math/gf2nlong.cpp Normal file
View File

@@ -0,0 +1,277 @@
// (C) 2016 University of Bristol. See License.txt
/*
* gf2n_longlong.cpp
*
*/
#include "gf2nlong.h"
#include "Exceptions/Exceptions.h"
#include <stdint.h>
#include <wmmintrin.h>
#include <xmmintrin.h>
#include <emmintrin.h>
bool is_ge(__m128i a, __m128i b)
{
word aa[2], bb[2];
_mm_storeu_si128((__m128i*)aa, a);
_mm_storeu_si128((__m128i*)bb, b);
// cout << hex << "is_ge " << aa[1] << " " << bb[1] << " " << (aa[1] > bb[1]) << " ";
// cout << aa[0] << " " << bb[0] << " " << (aa[0] >= bb[0]) << endl;
return aa[1] == bb[1] ? aa[0] >= bb[0] : aa[1] > bb[1];
}
ostream& operator<<(ostream& s, const int128& a)
{
word* tmp = (word*)&a.a;
s << hex;
s.width(16);
s.fill('0');
s << tmp[1];
s.width(16);
s << tmp[0] << dec;
return s;
}
int gf2n_long::n;
int gf2n_long::t1;
int gf2n_long::t2;
int gf2n_long::t3;
int gf2n_long::l0;
int gf2n_long::l1;
int gf2n_long::l2;
int gf2n_long::l3;
int gf2n_long::nterms;
int128 gf2n_long::mask;
int128 gf2n_long::lowermask;
int128 gf2n_long::uppermask;
bool gf2n_long::rewind = false;
#define num_2_fields 1
/* Require
* 2*(n-1)-64+t1<64
*/
int long_fields_2[num_2_fields][4] = {
{128,7,2,1},
};
void gf2n_long::init_field(int nn)
{
if (nn!=128) {
cout << "Compiled for GF(2^128) only. Change parameters or compile "
"without USE_GF2N_LONG" << endl;
throw not_implemented();
}
int i,j=-1;
for (i=0; i<num_2_fields && j==-1; i++)
{ if (nn==long_fields_2[i][0]) { j=i; } }
if (j==-1) { throw invalid_params(); }
n=nn;
nterms=1;
l0=128-n;
t1=long_fields_2[j][1];
l1=128+t1-n;
if (long_fields_2[j][2]!=0)
{ nterms=3;
t2=long_fields_2[j][2];
l2=128+t2-n;
t3=long_fields_2[j][3];
l3=128+t3-n;
}
// 2^128 has a pentanomial
// if (nterms==1 && 2*(n-1)-128+t1>=128) { throw not_implemented(); }
// if (nterms==3 && n!=128) { throw not_implemented(); }
mask=_mm_set_epi64x(-1,-1);
lowermask=_mm_set_epi64x((1LL<<(64-7))-1,-1);
uppermask=_mm_set_epi64x(((word)-1)<<(64-7),0);
}
void gf2n_long::reduce_trinomial(int128 xh,int128 xl)
{
// Deal with xh first
a=xl;
a^=(xh<<l0);
a^=(xh<<l1);
// Now deal with last int128
int128 hi=a>>n;
while (hi==0)
{ a&=mask;
a^=hi;
a^=(hi<<t1);
hi=a>>n;
}
}
void gf2n_long::reduce_pentanomial(int128 xh, int128 xl)
{
// Deal with xh first
a=xl;
int128 upper, lower;
upper=xh&uppermask;
lower=xh&lowermask;
// Upper part
int128 tmp = 0;
tmp^=(upper>>(n-t1-l0));
tmp^=(upper>>(n-t1-l1));
tmp^=(upper>>(n-t1-l2));
tmp^=(upper>>(n-t1-l3));
lower^=(tmp>>(l1));
a^=(tmp<<(n-l1));
// Lower part
a^=(lower<<l0);
a^=(lower<<l1);
a^=(lower<<l2);
a^=(lower<<l3);
/*
// Now deal with last int128
int128 hi=a>>n;
while (hi!=0)
{ a&=mask;
a^=hi;
a^=(hi<<t1);
a^=(hi<<t2);
a^=(hi<<t3);
hi=a>>n;
}
*/
}
gf2n_long& gf2n_long::mul(const gf2n_long& x,const gf2n_long& y)
{
__m128i res[2];
memset(res,0,sizeof(res));
mul128(x.a.a,y.a.a,res,res+1);
reduce(res[1],res[0]);
return *this;
}
class int129
{
int128 lower;
bool msb;
public:
int129() : lower(_mm_setzero_si128()), msb(false) { }
int129(int128 lower, bool msb) : lower(lower), msb(msb) { }
int129(int128 a) : lower(a), msb(false) { }
int129(word a)
{ *this = a; }
int128 get_lower() { return lower; }
int129& operator=(const __m128i& other)
{ lower = other; msb = false; return *this; }
int129& operator=(const word& other)
{ lower = _mm_set_epi64x(0, other); msb = false; return *this; }
bool operator==(const int129& other)
{ return (lower == other.lower) && (msb == other.msb); }
bool operator!=(const int129& other)
{ return !(*this == other); }
bool operator>=(const int129& other)
{ //cout << ">= " << msb << other.msb << (msb > other.msb) << is_ge(lower.a, other.lower.a) << endl;
return msb == other.msb ? is_ge(lower.a, other.lower.a) : msb > other.msb; }
int129 operator<<(int other)
{ return int129(lower << other, _mm_cvtsi128_si32(((lower >> (128-other)) & 1).a)); }
int129& operator>>=(int other)
{ lower >>= other; lower |= (int128(msb) << (128-other)); msb = !other; return *this; }
int129 operator^(const int129& other)
{ return int129(lower ^ other.lower, msb ^ other.msb); }
int129& operator^=(const int129& other)
{ lower ^= other.lower; msb ^= other.msb; return *this; }
int129 operator&(const word& other)
{ return int129(lower & other, false); }
friend ostream& operator<<(ostream& s, const int129& a)
{ s << a.msb << a.lower; return s; }
};
void gf2n_long::invert()
{
if (is_one()) { return; }
if (is_zero()) { throw division_by_zero(); }
int129 u,v=a,B=0,D=1,mod=1;
mod^=(int129(1)<<n);
mod^=(int129(1)<<t1);
if (nterms==3)
{ mod^=(int129(1)<<t2);
mod^=(int129(1)<<t3);
}
u=mod; v=a;
while (u!=0)
{ while ((u&1)==0)
{ u>>=1;
if ((B&1)!=0) { B^=mod; }
B>>=1;
}
while ((v&1)==0 && v!=0)
{ v>>=1;
if ((D&1)!=0) { D^=mod; }
D>>=1;
}
if (u>=v) { u=u^v; B=B^D; }
else { v=v^u; D=D^B; }
}
a=D.get_lower();
}
void gf2n_long::randomize(PRNG& G)
{
a=G.get_doubleword();
a&=mask;
}
void gf2n_long::output(ostream& s,bool human) const
{
if (human)
{ s << *this; }
else
{ s.write((char*) &a,sizeof(__m128i)); }
}
void gf2n_long::input(istream& s,bool human)
{
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
}
//throw end_of_file();
s.clear(); // unset EOF flag
s.seekg(0);
if (!rewind)
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
rewind = true;
}
if (human)
{ s >> *this; }
else
{ s.read((char*) &a,sizeof(__m128i)); }
}

274
Math/gf2nlong.h Normal file
View File

@@ -0,0 +1,274 @@
// (C) 2016 University of Bristol. See License.txt
/*
* gf2nlong.h
*
*/
#ifndef MATH_GF2NLONG_H_
#define MATH_GF2NLONG_H_
#include <stdlib.h>
#include <string.h>
#include <iostream>
using namespace std;
#include <smmintrin.h>
#include "Tools/random.h"
#include "Math/field_types.h"
class int128
{
public:
__m128i a;
int128() : a(_mm_setzero_si128()) { }
int128(const int128& a) : a(a.a) { }
int128(const __m128i& a) : a(a) { }
int128(const word& a) : a(_mm_cvtsi64_si128(a)) { }
int128(const word& upper, const word& lower) : a(_mm_set_epi64x(upper, lower)) { }
word get_lower() { return (word)_mm_cvtsi128_si64(a); }
bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); }
bool operator!=(const int128& other) const { return !(*this == other); }
int128 operator<<(const int& other) const;
int128 operator>>(const int& other) const;
int128 operator^(const int128& other) const { return a ^ other.a; }
int128 operator|(const int128& other) const { return a | other.a; }
int128 operator&(const int128& other) const { return a & other.a; }
int128 operator~() const { return ~a; }
int128& operator<<=(const int& other) { return *this = *this << other; }
int128& operator>>=(const int& other) { return *this = *this >> other; }
int128& operator^=(const int128& other) { a ^= other.a; return *this; }
int128& operator|=(const int128& other) { a |= other.a; return *this; }
int128& operator&=(const int128& other) { a &= other.a; return *this; }
friend ostream& operator<<(ostream& s, const int128& a);
};
/* This interface compatible with the gfp interface
* which then allows us to template the Share
* data type.
*/
/*
Arithmetic in Gf_{2^n} with n<=128
*/
class gf2n_long
{
int128 a;
static int n,t1,t2,t3,nterms;
static int l0,l1,l2,l3;
static int128 mask,lowermask,uppermask;
static bool rewind;
/* Assign x[0..2*nwords] to a and reduce it... */
void reduce_trinomial(int128 xh,int128 xl);
void reduce_pentanomial(int128 xh,int128 xl);
public:
typedef gf2n_long value_type;
typedef int128 internal_type;
void reduce(int128 xh,int128 xl)
{
if (nterms==3)
{ reduce_pentanomial(xh,xl); }
else
{ reduce_trinomial(xh,xl); }
}
static void init_field(int nn);
static int degree() { return n; }
static int get_nterms() { return nterms; }
static int get_t(int i)
{ if (i==0) { return t1; }
else if (i==1) { return t2; }
else if (i==2) { return t3; }
return -1;
}
static DataFieldType field_type() { return DATA_GF2N; }
static char type_char() { return '2'; }
static string type_string() { return "gf2n_long"; }
static int size() { return sizeof(a); }
static int t() { return 0; }
int128 get() const { return a; }
__m128i to_m128i() const { return a.a; }
word get_word() const { return _mm_cvtsi128_si64x(a.a); }
void assign(const gf2n_long& g) { a=g.a; }
void assign_zero() { a=_mm_setzero_si128(); }
void assign_one() { a=int128(0,1); }
void assign_x() { a=int128(0,2); }
void assign(int128 aa) { a=aa&mask; }
void assign(int aa) { a=int128(static_cast<unsigned int>(aa))&mask; }
void assign(const char* buffer) { a = _mm_loadu_si128((__m128i*)buffer); }
int get_bit(int i) const
{ return ((a>>i)&1).get_lower(); }
void set_bit(int i,unsigned int b)
{ if (b==1)
{ a |= (1UL<<i); }
else
{ a &= ~(1UL<<i); }
}
gf2n_long() { assign_zero(); }
gf2n_long(const gf2n_long& g) { assign(g); }
gf2n_long(const int128& g) { assign(g); }
gf2n_long(int g) { assign(g); }
~gf2n_long() { ; }
gf2n_long& operator=(const gf2n_long& g)
{ if (&g!=this) { assign(g); }
return *this;
}
int is_zero() const { return a==int128(0); }
int is_one() const { return a==int128(1); }
int equal(const gf2n_long& y) const { return (a==y.a); }
bool operator==(const gf2n_long& y) const { return a==y.a; }
bool operator!=(const gf2n_long& y) const { return a!=y.a; }
// x+y
void add(const gf2n_long& x,const gf2n_long& y)
{ a=x.a^y.a; }
void add(const gf2n_long& x)
{ a^=x.a; }
template<int T>
void add(octet* x)
{ a^=int128(_mm_loadu_si128((__m128i*)x)); }
void add(octet* x)
{ add<0>(x); }
void sub(const gf2n_long& x,const gf2n_long& y)
{ a=x.a^y.a; }
void sub(const gf2n_long& x)
{ a^=x.a; }
// = x * y
gf2n_long& mul(const gf2n_long& x,const gf2n_long& y);
void mul(const gf2n_long& x) { mul(*this,x); }
// x * y when one of x,y is a bit
void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; }
gf2n_long operator+(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; }
gf2n_long operator*(const gf2n_long& x) { gf2n_long res; res.mul(*this, x); return res; }
gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; }
gf2n_long& operator*=(const gf2n_long& x) { mul(x); return *this; }
gf2n_long operator-(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; }
gf2n_long& operator-=(const gf2n_long& x) { sub(x); return *this; }
void square();
void square(const gf2n_long& aa);
void invert();
void invert(const gf2n_long& aa)
{ *this=aa; invert(); }
void negate() { return; }
void power(long i);
/* Bitwise Ops */
void AND(const gf2n_long& x,const gf2n_long& y) { a=x.a&y.a; }
void XOR(const gf2n_long& x,const gf2n_long& y) { a=x.a^y.a; }
void OR(const gf2n_long& x,const gf2n_long& y) { a=x.a|y.a; }
void NOT(const gf2n_long& x) { a=(~x.a)&mask; }
void SHL(const gf2n_long& x,int n) { a=(x.a<<n)&mask; }
void SHR(const gf2n_long& x,int n) { a=x.a>>n; }
gf2n_long operator&(const gf2n_long& x) { gf2n_long res; res.AND(*this, x); return res; }
gf2n_long operator^(const gf2n_long& x) { gf2n_long res; res.XOR(*this, x); return res; }
gf2n_long operator|(const gf2n_long& x) { gf2n_long res; res.OR(*this, x); return res; }
gf2n_long operator!() { gf2n_long res; res.NOT(*this); return res; }
gf2n_long operator<<(int i) { gf2n_long res; res.SHL(*this, i); return res; }
gf2n_long operator>>(int i) { gf2n_long res; res.SHR(*this, i); return res; }
/* Crap RNG */
void randomize(PRNG& G);
// compatibility with gfp
void almost_randomize(PRNG& G) { randomize(G); }
void output(ostream& s,bool human) const;
void input(istream& s,bool human);
friend ostream& operator<<(ostream& s,const gf2n_long& x)
{ s << hex << x.a << dec;
return s;
}
friend istream& operator>>(istream& s,gf2n_long& x)
{ bigint tmp;
s >> hex >> tmp >> dec;
x.a = 0;
mpn_copyi((word*)&x.a.a, tmp.get_mpz_t()->_mp_d, tmp.get_mpz_t()->_mp_size);
return s;
}
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const
{ o.append((octet*) &a,sizeof(__m128i)); }
void unpack(octetStream& o)
{ o.consume((octet*) &a,sizeof(__m128i)); }
};
inline int128 int128::operator<<(const int& other) const
{
int128 res(_mm_slli_epi64(a, other));
__m128i mask;
if (other < 64)
mask = _mm_srli_epi64(a, 64 - other);
else
mask = _mm_slli_epi64(a, other - 64);
res.a ^= _mm_slli_si128(mask, 8);
return res;
}
inline int128 int128::operator>>(const int& other) const
{
int128 res(_mm_srli_epi64(a, other));
__m128i mask;
if (other < 64)
mask = _mm_slli_epi64(a, 64 - other);
else
mask = _mm_srli_epi64(a, other - 64);
res.a ^= _mm_srli_si128(mask, 8);
return res;
}
inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2)
{
__m128i tmp3, tmp4, tmp5, tmp6;
tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
tmp4 = _mm_clmulepi64_si128(a, b, 0x10);
tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
tmp6 = _mm_clmulepi64_si128(a, b, 0x11);
tmp4 = _mm_xor_si128(tmp4, tmp5);
tmp5 = _mm_slli_si128(tmp4, 8);
tmp4 = _mm_srli_si128(tmp4, 8);
tmp3 = _mm_xor_si128(tmp3, tmp5);
tmp6 = _mm_xor_si128(tmp6, tmp4);
// initial mul now in tmp3, tmp6
*res1 = tmp3;
*res2 = tmp6;
}
#endif /* MATH_GF2NLONG_H_ */

125
Math/gfp.cpp Normal file
View File

@@ -0,0 +1,125 @@
// (C) 2016 University of Bristol. See License.txt
#include "Math/gfp.h"
#include "Exceptions/Exceptions.h"
Zp_Data gfp::ZpD;
void gfp::almost_randomize(PRNG& G)
{
G.get_octets((octet*)a.x,t()*sizeof(mp_limb_t));
a.x[t()-1]&=ZpD.mask;
}
void gfp::AND(const gfp& x,const gfp& y)
{
bigint bi1,bi2;
to_bigint(bi1,x);
to_bigint(bi2,y);
mpz_and(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
to_gfp(*this, bi1);
}
void gfp::OR(const gfp& x,const gfp& y)
{
bigint bi1,bi2;
to_bigint(bi1,x);
to_bigint(bi2,y);
mpz_ior(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
to_gfp(*this, bi1);
}
void gfp::XOR(const gfp& x,const gfp& y)
{
bigint bi1,bi2;
to_bigint(bi1,x);
to_bigint(bi2,y);
mpz_xor(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
to_gfp(*this, bi1);
}
void gfp::AND(const gfp& x,const bigint& y)
{
bigint bi;
to_bigint(bi,x);
mpz_and(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
to_gfp(*this, bi);
}
void gfp::OR(const gfp& x,const bigint& y)
{
bigint bi;
to_bigint(bi,x);
mpz_ior(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
to_gfp(*this, bi);
}
void gfp::XOR(const gfp& x,const bigint& y)
{
bigint bi;
to_bigint(bi,x);
mpz_xor(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
to_gfp(*this, bi);
}
void gfp::SHL(const gfp& x,int n)
{
if (!x.is_zero())
{
bigint bi;
to_bigint(bi,x,false);
mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
}
else
{
assign_zero();
}
}
void gfp::SHR(const gfp& x,int n)
{
if (!x.is_zero())
{
bigint bi;
to_bigint(bi,x);
mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
}
else
{
assign_zero();
}
}
void gfp::SHL(const gfp& x,const bigint& n)
{
SHL(x,mpz_get_si(n.get_mpz_t()));
}
void gfp::SHR(const gfp& x,const bigint& n)
{
SHR(x,mpz_get_si(n.get_mpz_t()));
}
gfp gfp::sqrRoot()
{
// Temp move to bigint so as to call sqrRootMod
bigint ti;
to_bigint(ti, *this);
ti = sqrRootMod(ti, ZpD.pr);
if (!isOdd(ti))
ti = ZpD.pr - ti;
gfp temp;
to_gfp(temp, ti);
return temp;
}

205
Math/gfp.h Normal file
View File

@@ -0,0 +1,205 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _gfp
#define _gfp
#include <iostream>
using namespace std;
#include "Math/gf2n.h"
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/field_types.h"
#include "Tools/random.h"
/* This is a wrapper class for the modp data type
* It is used to be interface compatible with the gfp
* type, which then allows us to template the Share
* data type.
*
* So gfp is used ONLY for the stuff in the finite fields
* we are going to be doing MPC over, not the modp stuff
* for the FHE scheme
*/
class gfp
{
modp a;
static Zp_Data ZpD;
public:
typedef gfp value_type;
static void init_field(const bigint& p,bool mont=true)
{ ZpD.init(p,mont); }
static bigint pr()
{ return ZpD.pr; }
static int t()
{ return ZpD.get_t(); }
static Zp_Data& get_ZpD()
{ return ZpD; }
static DataFieldType field_type() { return DATA_MODP; }
static char type_char() { return 'p'; }
static string type_string() { return "gfp"; }
static int size() { return t() * sizeof(mp_limb_t); }
void assign(const gfp& g) { a=g.a; }
void assign_zero() { assignZero(a,ZpD); }
void assign_one() { assignOne(a,ZpD); }
void assign(word aa) { bigint b=aa; to_gfp(*this,b); }
void assign(long aa) { bigint b=aa; to_gfp(*this,b); }
void assign(int aa) { bigint b=aa; to_gfp(*this,b); }
void assign(const char* buffer) { a.assign(buffer, ZpD.get_t()); }
modp get() const { return a; }
// Assumes prD behind x is equal to ZpD
void assign(modp& x) { a=x; }
gfp() { assignZero(a,ZpD); }
gfp(const gfp& g) { a=g.a; }
gfp(const modp& g) { a=g; }
gfp(const __m128i& x) { *this=x; }
gfp(const int128& x) { *this=x.a; }
gfp(const bigint& x) { to_modp(a, x, ZpD); }
gfp(int x) { assign(x); }
~gfp() { ; }
gfp& operator=(const gfp& g)
{ if (&g!=this) { a=g.a; }
return *this;
}
gfp& operator=(const __m128i other)
{
memcpy(a.x, &other, sizeof(other));
a.x[2] = 0;
return *this;
}
void to_m128i(__m128i& ans)
{
memcpy(&ans, a.x, sizeof(ans));
}
__m128i to_m128i()
{
return _mm_loadu_si128((__m128i*)a.x);
}
bool is_zero() const { return isZero(a,ZpD); }
bool is_one() const { return isOne(a,ZpD); }
bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); }
bool operator==(const gfp& y) const { return equal(y); }
bool operator!=(const gfp& y) const { return !equal(y); }
// x+y
template <int T>
void add(const gfp& x,const gfp& y)
{ Add<T>(a,x.a,y.a,ZpD); }
template <int T>
void add(const gfp& x)
{ Add<T>(a,a,x.a,ZpD); }
template <int T>
void add(void* x)
{ ZpD.Add<T>(a.x,a.x,(mp_limb_t*)x); }
void add(const gfp& x,const gfp& y)
{ Add(a,x.a,y.a,ZpD); }
void add(const gfp& x)
{ Add(a,a,x.a,ZpD); }
void add(void* x)
{ ZpD.Add(a.x,a.x,(mp_limb_t*)x); }
void sub(const gfp& x,const gfp& y)
{ Sub(a,x.a,y.a,ZpD); }
void sub(const gfp& x)
{ Sub(a,a,x.a,ZpD); }
// = x * y
void mul(const gfp& x,const gfp& y)
{ Mul(a,x.a,y.a,ZpD); }
void mul(const gfp& x)
{ Mul(a,a,x.a,ZpD); }
gfp operator+(const gfp& x) { gfp res; res.add(*this, x); return res; }
gfp operator-(const gfp& x) { gfp res; res.sub(*this, x); return res; }
gfp operator*(const gfp& x) { gfp res; res.mul(*this, x); return res; }
gfp& operator+=(const gfp& x) { add(x); return *this; }
gfp& operator-=(const gfp& x) { sub(x); return *this; }
gfp& operator*=(const gfp& x) { mul(x); return *this; }
void square(const gfp& aa)
{ Sqr(a,aa.a,ZpD); }
void square()
{ Sqr(a,a,ZpD); }
void invert()
{ Inv(a,a,ZpD); }
void invert(const gfp& aa)
{ Inv(a,aa.a,ZpD); }
void negate()
{ Negate(a,a,ZpD); }
void power(long i)
{ Power(a,a,i,ZpD); }
// deterministic square root
gfp sqrRoot();
void randomize(PRNG& G)
{ a.randomize(G,ZpD); }
// faster randomization, see implementation for explanation
void almost_randomize(PRNG& G);
void output(ostream& s,bool human) const
{ a.output(s,ZpD,human); }
void input(istream& s,bool human)
{ a.input(s,ZpD,human); }
friend ostream& operator<<(ostream& s,const gfp& x)
{ x.output(s,true);
return s;
}
friend istream& operator>>(istream& s,gfp& x)
{ x.input(s,true);
return s;
}
/* Bitwise Ops
* - Converts gfp args to bigints and then converts answer back to gfp
*/
void AND(const gfp& x,const gfp& y);
void XOR(const gfp& x,const gfp& y);
void OR(const gfp& x,const gfp& y);
void AND(const gfp& x,const bigint& y);
void XOR(const gfp& x,const bigint& y);
void OR(const gfp& x,const bigint& y);
void SHL(const gfp& x,int n);
void SHR(const gfp& x,int n);
void SHL(const gfp& x,const bigint& n);
void SHR(const gfp& x,const bigint& n);
gfp operator&(const gfp& x) { gfp res; res.AND(*this, x); return res; }
gfp operator^(const gfp& x) { gfp res; res.XOR(*this, x); return res; }
gfp operator|(const gfp& x) { gfp res; res.OR(*this, x); return res; }
gfp operator<<(int i) { gfp res; res.SHL(*this, i); return res; }
gfp operator>>(int i) { gfp res; res.SHR(*this, i); return res; }
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const
{ a.pack(o,ZpD); }
void unpack(octetStream& o)
{ a.unpack(o,ZpD); }
// Convert representation to and from a bigint number
friend void to_bigint(bigint& ans,const gfp& x,bool reduce=true)
{ to_bigint(ans,x.a,x.ZpD,reduce); }
friend void to_gfp(gfp& ans,const bigint& x)
{ to_modp(ans.a,x,ans.ZpD); }
};
#endif

263
Math/modp.cpp Normal file
View File

@@ -0,0 +1,263 @@
// (C) 2016 University of Bristol. See License.txt
#include "Zp_Data.h"
#include "modp.h"
#include "Exceptions/Exceptions.h"
bool modp::rewind = false;
/***********************************************************************
* The following functions remain the same in Real and Montgomery rep *
***********************************************************************/
void modp::randomize(PRNG& G, const Zp_Data& ZpD)
{
bigint x=G.randomBnd(ZpD.pr);
to_modp(*this,x,ZpD);
}
void modp::pack(octetStream& o,const Zp_Data& ZpD) const
{
o.append((octet*) x,ZpD.t*sizeof(mp_limb_t));
}
void modp::unpack(octetStream& o,const Zp_Data& ZpD)
{
o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t));
}
void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
{
ZpD.Sub(ans.x, x.x, y.x);
}
void Negate(modp& ans,const modp& x,const Zp_Data& ZpD)
{
if (isZero(x,ZpD)) { ans=x; return; }
mpn_sub_n(ans.x,ZpD.prA,x.x,ZpD.t);
}
bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD)
{ if (mpn_cmp(x.x,y.x,ZpD.t)!=0)
{ return false; }
return true;
}
bool isZero(const modp& ans,const Zp_Data& ZpD)
{
for (int i=0; i<ZpD.t; i++)
{ if (ans.x[i]!=0) { return false; } }
return true;
}
/***********************************************************************
* All the remaining functions have Montgomery variants which we need *
* to deal with *
***********************************************************************/
void assignOne(modp& x,const Zp_Data& ZpD)
{ if (ZpD.montgomery)
{ mpn_copyi(x.x,ZpD.R,ZpD.t+1); }
else
{ assignZero(x,ZpD);
x.x[0]=1;
}
}
bool isOne(const modp& x,const Zp_Data& ZpD)
{ if (ZpD.montgomery)
{ if (mpn_cmp(x.x,ZpD.R,ZpD.t)!=0)
{ return false; }
}
else
{ if (x.x[0]!=1) { return false; }
for (int i=1; i<ZpD.t; i++)
{ if (x.x[i]!=0) { return false; } }
}
return true;
}
void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce)
{
mpz_t a;
mpz_init2(a,MAX_MOD_SZ*sizeof(mp_limb_t)*8);
if (ZpD.montgomery)
{ mp_limb_t one[MAX_MOD_SZ];
mpn_zero(one,ZpD.t+1);
one[0]=1;
ZpD.Mont_Mult(a->_mp_d,x.x,one);
}
else
{ mpn_copyi(a->_mp_d,x.x,ZpD.t+1); }
a->_mp_size=ZpD.t;
if (reduce)
while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0)
{ a->_mp_size--; }
ans=bigint(a);
mpz_clear(a);
}
void to_modp(modp& ans,int x,const Zp_Data& ZpD)
{
mpn_zero(ans.x,ZpD.t+1);
if (x>=0)
{ ans.x[0]=x;
if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; }
}
else
{ if (ZpD.t==1)
{ ans.x[0]=(ZpD.prA[0]+x)%ZpD.prA[0]; }
else
{ bigint xx=ZpD.pr+x;
to_modp(ans,xx,ZpD);
return;
}
}
if (ZpD.montgomery)
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); }
}
void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD)
{
bigint xx=x%ZpD.pr;
if (xx<0) { xx+=ZpD.pr; }
//mpz_mod(xx.get_mpz_t(),x.get_mpz_t(),ZpD.pr.get_mpz_t());
mpn_zero(ans.x,ZpD.t+1);
mpn_copyi(ans.x,xx.get_mpz_t()->_mp_d,xx.get_mpz_t()->_mp_size);
if (ZpD.montgomery)
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); }
}
void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
{
if (ZpD.montgomery)
{ ZpD.Mont_Mult(ans.x,x.x,y.x); }
else
{ //ans.x=(x.x*y.x)%ZpD.pr;
mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ];
mpn_mul_n(aa,x.x,y.x,ZpD.t);
mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t);
}
}
void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD)
{
if (ZpD.montgomery)
{ ZpD.Mont_Mult(ans.x,x.x,x.x); }
else
{ //ans.x=(x.x*x.x)%ZpD.pr;
mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ];
mpn_sqr(aa,x.x,ZpD.t);
mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t);
}
}
void Inv(modp& ans,const modp& x,const Zp_Data& ZpD)
{
mp_limb_t g[MAX_MOD_SZ],xx[MAX_MOD_SZ+1],yy[MAX_MOD_SZ+1];
mp_size_t sz;
mpn_copyi(xx,x.x,ZpD.t);
mpn_copyi(yy,ZpD.prA,ZpD.t);
mpn_gcdext(g,ans.x,&sz,xx,ZpD.t,yy,ZpD.t);
if (sz<0)
{ mpn_sub(ans.x,ZpD.prA,ZpD.t,ans.x,-sz);
sz=-sz;
}
else
{ for (int i=sz; i<ZpD.t; i++) { ans.x[i]=0; } }
if (ZpD.montgomery)
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R3); }
}
// XXXX This is a crap version. Hopefully this is not time critical
void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD)
{
if (exp==1) { ans=x; return; }
if (exp==0) { assignOne(ans,ZpD); return; }
if (exp<0) { throw not_implemented(); }
modp t=x;
assignOne(ans,ZpD);
while (exp!=0)
{ if ((exp&1)==1) { Mul(ans,ans,t,ZpD); }
exp>>=1;
Sqr(t,t,ZpD);
}
}
// XXXX This is a crap version. Hopefully this is not time critical
void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD)
{
if (exp==1) { ans=x; return; }
if (exp==0) { assignOne(ans,ZpD); return; }
if (exp<0) { throw not_implemented(); }
modp t=x;
assignOne(ans,ZpD);
bigint e=exp;
while (e!=0)
{ if ((e&1)==1) { Mul(ans,ans,t,ZpD); }
e>>=1;
Sqr(t,t,ZpD);
}
}
void modp::output(ostream& s,const Zp_Data& ZpD,bool human) const
{
if (human)
{ bigint te;
to_bigint(te,*this,ZpD);
if (te < ZpD.pr / 2)
s << te;
else
s << (te - ZpD.pr);
}
else
{ s.write((char*) x,ZpD.t*sizeof(mp_limb_t)); }
}
void modp::input(istream& s,const Zp_Data& ZpD,bool human)
{
if (s.peek() == EOF)
{ if (s.tellg() == 0)
{ cout << "IO problem. Empty file?" << endl;
throw file_error();
}
//throw end_of_file();
s.clear(); // unset EOF flag
s.seekg(0);
if (!rewind)
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
rewind = true;
}
if (human)
{ bigint te;
s >> te;
to_modp(*this,te,ZpD);
}
else
{ s.read((char*) x,ZpD.t*sizeof(mp_limb_t)); }
}

116
Math/modp.h Normal file
View File

@@ -0,0 +1,116 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Modp
#define _Modp
/*
* Currently we only support an MPIR based implementation.
*
* What ever is type-def'd to bigint is assumed to have
* operator overloading for all standard operators, has
* comparison operations and istream/ostream operators >>/<<.
*
* All "integer" operations will be done using operator notation
* all "modp" operations should be done using the function calls
* below (interchange with Montgomery arithmetic).
*
*/
#include "Tools/octetStream.h"
#include "Tools/random.h"
#include "Math/bigint.h"
#include "Math/Zp_Data.h"
class modp
{
static bool rewind;
mp_limb_t x[MAX_MOD_SZ];
public:
// NEXT FUNCTION IS FOR DEBUG PURPOSES ONLY
mp_limb_t get_limb(int i) { return x[i]; }
// use mem* functions instead of mpn_*, so the compiler can optimize
modp()
{ memset(x, 0, sizeof(x)); }
modp(const modp& y)
{ memcpy(x, y.x, sizeof(x)); }
modp& operator=(const modp& y)
{ if (this!=&y) { memcpy(x, y.x, sizeof(x)); }
return *this;
}
void assign(const char* buffer, int t) { memcpy(x, buffer, t * sizeof(mp_limb_t)); }
void randomize(PRNG& G, const Zp_Data& ZpD);
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
// i.e. When we do montgomery we dont care about decoding
void pack(octetStream& o,const Zp_Data& ZpD) const;
void unpack(octetStream& o,const Zp_Data& ZpD);
/**********************************
* Modp Operations *
**********************************/
// Convert representation to and from a modp number
friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce=true);
friend void to_modp(modp& ans,int x,const Zp_Data& ZpD);
friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD);
template <int T>
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
{ ZpD.Add(ans.x, x.x, y.x); }
friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD);
friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD);
friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD);
friend void assignOne(modp& x,const Zp_Data& ZpD);
friend void assignZero(modp& x,const Zp_Data& ZpD);
friend bool isZero(const modp& x,const Zp_Data& ZpD);
friend bool isOne(const modp& x,const Zp_Data& ZpD);
friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD);
// Input and output from a stream
// - Can do in human or machine only format (later should be faster)
// - If human output appends a space to help with reading
// and also convert back/forth from Montgomery if needed
void output(ostream& s,const Zp_Data& ZpD,bool human) const;
void input(istream& s,const Zp_Data& ZpD,bool human);
friend class gfp;
};
inline void assignZero(modp& x,const Zp_Data& ZpD)
{
if (sizeof(x.x) <= 3 * 16)
// use memset to allow the compiler to optimize
// if x.x is at most 3*128 bits
memset(x.x, 0, sizeof(x.x));
else
mpn_zero(x.x, ZpD.t + 1);
}
template <int T>
inline void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
{
ZpD.Add<T>(ans.x, x.x, y.x);
}
#endif

37
Math/operators.h Normal file
View File

@@ -0,0 +1,37 @@
// (C) 2016 University of Bristol. See License.txt
/*
* operations.h
*
*/
#ifndef MATH_OPERATORS_H_
#define MATH_OPERATORS_H_
template <class T>
T operator*(const bool& x, const T& y) { return x ? y : T(); }
template <class T>
T operator*(const T& y, const bool& x) { return x ? y : T(); }
template <class T>
T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; }
template <class T, class U>
T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; }
template <class T, class U>
T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; }
template <class T, class U>
T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; }
template <class T, class U>
T& operator+=(T& x, const U& y) { x.add(y); return x; }
template <class T, class U>
T& operator*=(T& x, const U& y) { x.mul(y); return x; }
template <class T, class U>
T& operator-=(T& x, const U& y) { x.sub(y); return x; }
template <class T, class U>
T operator/(const T& x, const U& y) { U inv = y; inv.invert(); return x * inv; }
template <class T, class U>
T& operator/=(const T& x, const U& y) { U inv = y; inv.invert(); return x *= inv; }
#endif /* MATH_OPERATORS_H_ */

411
Networking/Player.cpp Normal file
View File

@@ -0,0 +1,411 @@
// (C) 2016 University of Bristol. See License.txt
#include "Player.h"
#include "Exceptions/Exceptions.h"
#include <sys/select.h>
// Use printf rather than cout so valgrind can detect thread issues
void Names::init(int player,int pnb,const char* servername)
{
player_no=player;
portnum_base=pnb;
setup_names(servername);
setup_server();
}
void Names::init(int player,int pnb,vector<octet*> Nms)
{
player_no=player;
portnum_base=pnb;
nplayers=Nms.size();
names.resize(nplayers);
for (int i=0; i<nplayers; i++)
{ names[i]=(char*)Nms[i]; }
setup_server();
}
void Names::init(int player,int pnb,vector<string> Nms)
{
player_no=player;
portnum_base=pnb;
nplayers=Nms.size();
names=Nms;
setup_server();
}
// initialize hostnames from file
void Names::init(int player, int _nplayers, int pnb, const string& filename)
{
ifstream hostsfile(filename.c_str());
if (hostsfile.fail())
{
stringstream ss;
ss << "Error opening " << filename << ". See HOSTS.example for an example.";
throw file_error(ss.str().c_str());
}
player_no = player;
nplayers = _nplayers;
portnum_base = pnb;
string line;
while (getline(hostsfile, line))
{
if (line.length() > 0 && line.at(0) != '#')
names.push_back(line);
}
if ((int)names.size() < nplayers)
throw invalid_params();
names.resize(nplayers);
for (unsigned int i = 0; i < names.size(); i++)
cerr << "name: " << names[i] << endl;
setup_server();
}
void Names::setup_names(const char *servername)
{
int socket_num;
int pn = portnum_base - 1;
set_up_client_socket(socket_num, servername, pn);
send(socket_num, (octet*)&player_no, sizeof(player_no));
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
int inst=-1; // wait until instruction to start.
while (inst != GO) { receive(socket_num, inst); }
// Send my name
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
fprintf(stderr, "My Name = %s\n",my_name);
send(socket_num,my_name,512);
cerr << "My number = " << player_no << endl;
// Now get the set of names
int i;
receive(socket_num,nplayers);
cerr << nplayers << " players\n";
names.resize(nplayers);
for (i=0; i<nplayers; i++)
{ octet tmp[512];
receive(socket_num,tmp,512);
names[i]=(char*)tmp;
cerr << "Player " << i << " is running on machine " << names[i] << endl;
}
close_client_socket(socket_num);
}
void Names::setup_server()
{
server = new ServerSocket(portnum_base + player_no);
}
Names::Names(const Names& other)
{
if (other.server != 0)
throw runtime_error("Can copy Names only when uninitialized");
player_no = other.player_no;
nplayers = other.nplayers;
portnum_base = other.portnum_base;
names = other.names;
server = 0;
}
Names::~Names()
{
if (server != 0)
delete server;
}
Player::Player(const Names& Nms, int id) : PlayerBase(Nms), send_to_self_socket(-1)
{
nplayers=Nms.nplayers;
player_no=Nms.player_no;
setup_sockets(Nms.names, Nms.portnum_base, id, *Nms.server);
blk_SHA1_Init(&ctx);
}
Player::~Player()
{
/* Close down the sockets */
for (int i=0; i<nplayers; i++)
close_client_socket(sockets[i]);
}
// Set up nmachines client and server sockets to send data back and fro
// A machine is a server between it and player i if i<=my_number
// Can also communicate with myself, but only with send_to and receive_from
void Player::setup_sockets(const vector<string>& names,int portnum_base,int id_base,ServerSocket& server)
{
sockets.resize(nplayers);
// Set up the client side
for (int i=player_no; i<nplayers; i++)
{ int pn=id_base+i*nplayers+player_no;
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),portnum_base+i,pn);
set_up_client_socket(sockets[i],names[i].c_str(),portnum_base+i);
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
}
send_to_self_socket = sockets[player_no];
// Setting up the server side
for (int i=0; i<=player_no; i++)
{ int id=id_base+player_no*nplayers+i;
fprintf(stderr, "Setting up server with id 0x%x\n",id);
sockets[i] = server.get_connection_socket(id);
}
for (int i = 0; i < nplayers; i++)
{
// timeout of 5 minutes
struct timeval tv;
tv.tv_sec = 300;
tv.tv_usec = 0;
int fl = setsockopt(sockets[i], SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(struct timeval));
if (fl<0) { error("set_up_socket:setsockopt"); }
socket_players[sockets[i]] = i;
}
}
void Player::send_to(int player,const octetStream& o,bool donthash) const
{
int socket = socket_to_send(player);
o.Send(socket);
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
}
void Player::send_all(const octetStream& o,bool donthash) const
{ for (int i=0; i<nplayers; i++)
{ if (i!=player_no)
{ o.Send(sockets[i]); }
}
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
}
void Player::receive_player(int i,octetStream& o,bool donthash) const
{
o.reset_write_head();
o.Receive(sockets[i]);
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
}
/* This is deliberately weird to avoid problems with OS max buffer
* size getting in the way
*/
void Player::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
{ for (int i=0; i<nplayers; i++)
{ if (i>player_no)
{ o[player_no].Send(sockets[i]); }
else if (i<player_no)
{ o[i].reset_write_head();
o[i].Receive(sockets[i]);
}
}
for (int i=0; i<nplayers; i++)
{ if (i<player_no)
{ o[player_no].Send(sockets[i]); }
else if (i>player_no)
{ o[i].reset_write_head();
o[i].Receive(sockets[i]);
}
}
if (!donthash)
{ for (int i=0; i<nplayers; i++)
{ blk_SHA1_Update(&ctx,o[i].get_data(),o[i].get_length()); }
}
}
void Player::Check_Broadcast() const
{
octet hashVal[HASH_SIZE];
vector<octetStream> h(nplayers);
blk_SHA1_Final(hashVal,&ctx);
h[player_no].append(hashVal,HASH_SIZE);
Broadcast_Receive(h,true);
for (int i=0; i<nplayers; i++)
{ if (i!=player_no)
{ if (!h[i].equals(h[player_no]))
{ throw broadcast_invalid(); }
}
}
blk_SHA1_Init(&ctx);
}
void Player::wait_for_available(vector<int>& players, vector<int>& result) const
{
fd_set rfds;
FD_ZERO(&rfds);
int highest = 0;
vector<int>::iterator it;
for (it = players.begin(); it != players.end(); it++)
{
if (*it >= 0)
{
FD_SET(sockets[*it], &rfds);
highest = max(highest, sockets[*it]);
}
}
int res = select(highest + 1, &rfds, 0, 0, 0);
if (res < 0)
error("select()");
result.clear();
result.reserve(res);
for (it = players.begin(); it != players.end(); it++)
{
if (res == 0)
break;
if (*it >= 0 && FD_ISSET(sockets[*it], &rfds))
{
res--;
result.push_back(*it);
}
}
}
ThreadPlayer::ThreadPlayer(const Names& Nms, int id_base) : Player(Nms, id_base)
{
for (int i = 0; i < Nms.num_players(); i++)
{
receivers.push_back(new Receiver(sockets[i]));
receivers[i]->start();
senders.push_back(new Sender(socket_to_send(i)));
senders[i]->start();
}
}
ThreadPlayer::~ThreadPlayer()
{
for (unsigned int i = 0; i < receivers.size(); i++)
{
receivers[i]->stop();
if (receivers[i]->timer.elapsed() > 0)
cerr << "Waiting for receiving from " << i << ": " << receivers[i]->timer.elapsed() << endl;
delete receivers[i];
}
for (unsigned int i = 0; i < senders.size(); i++)
{
senders[i]->stop();
if (senders[i]->timer.elapsed() > 0)
cerr << "Waiting for sending to " << i << ": " << senders[i]->timer.elapsed() << endl;
delete senders[i];
}
}
void ThreadPlayer::request_receive(int i, octetStream& o) const
{
receivers[i]->request(o);
}
void ThreadPlayer::wait_receive(int i, octetStream& o, bool donthash) const
{
receivers[i]->wait(o);
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
}
void ThreadPlayer::receive_player(int i, octetStream& o, bool donthash) const
{
request_receive(i, o);
wait_receive(i, o, donthash);
}
void ThreadPlayer::send_all(const octetStream& o,bool donthash) const
{
for (int i=0; i<nplayers; i++)
{ if (i!=player_no)
senders[i]->request(o);
}
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
for (int i = 0; i < nplayers; i++)
if (i != player_no)
senders[i]->wait(o);
}
TwoPartyPlayer::TwoPartyPlayer(const Names& Nms, int other_player, int id) : PlayerBase(Nms), other_player(other_player)
{
is_server = Nms.my_num() > other_player;
setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id);
}
TwoPartyPlayer::~TwoPartyPlayer()
{
close_client_socket(socket);
}
void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id)
{
if (is_server)
{
fprintf(stderr, "Setting up server with id %d\n",id);
socket = server.get_connection_socket(id);
}
else
{
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id);
set_up_client_socket(socket, hostname, pn);
::send(socket, (unsigned char*)&id, sizeof(id));
}
}
int TwoPartyPlayer::other_player_num() const
{
return other_player;
}
void TwoPartyPlayer::send(octetStream& o) const
{
o.Send(socket);
}
void TwoPartyPlayer::receive(octetStream& o) const
{
o.reset_write_head();
o.Receive(socket);
}
void TwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
{
if (is_server)
{
o[0].Send(socket);
o[1].reset_write_head();
o[1].Receive(socket);
}
else
{
o[1].reset_write_head();
o[1].Receive(socket);
o[0].Send(socket);
}
}
}

185
Networking/Player.h Normal file
View File

@@ -0,0 +1,185 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Player
#define _Player
/* Class to create a player, for KeyGen, Offline and Online phases.
*
* Basically handles connection to the server to obtain the names
* of the other players. Plus sending and receiving of data
*
*/
#include <vector>
#include <set>
#include <iostream>
#include <fstream>
using namespace std;
#include "Tools/octetStream.h"
#include "Networking/sockets.h"
#include "Networking/ServerSocket.h"
#include "Tools/sha1.h"
#include "Networking/Receiver.h"
#include "Networking/Sender.h"
/* Class to get the names off the server */
class Names
{
vector<string> names;
int nplayers;
int portnum_base;
int player_no;
void setup_names(const char *servername);
void setup_server();
public:
mutable ServerSocket* server;
// Usual setup names
void init(int player,int pnb,const char* servername);
Names(int player,int pnb,const char* servername)
{ init(player,pnb,servername); }
// Set up names when we KNOW who we are going to be using before hand
void init(int player,int pnb,vector<octet*> Nms);
Names(int player,int pnb,vector<octet*> Nms)
{ init(player,pnb,Nms); }
void init(int player,int pnb,vector<string> Nms);
Names(int player,int pnb,vector<string> Nms)
{ init(player,pnb,Nms); }
// Set up names from file -- reads the first nplayers names in the file
void init(int player, int nplayers, int pnb, const string& hostsfile);
Names(int player, int nplayers, int pnb, const string& hostsfile)
{ init(player, nplayers, pnb, hostsfile); }
Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; }
Names(const Names& other);
~Names();
int num_players() const { return nplayers; }
int my_num() const { return player_no; }
const string get_name(int i) const { return names[i]; }
int get_portnum_base() const { return portnum_base; }
friend class PlayerBase;
friend class Player;
friend class TwoPartyPlayer;
};
class PlayerBase
{
protected:
int player_no;
public:
PlayerBase(const Names& Nms) : player_no(Nms.my_num()) {}
int my_num() const { return player_no; }
};
class Player : public PlayerBase
{
protected:
vector<int> sockets;
int send_to_self_socket;
void setup_sockets(const vector<string>& names,int portnum_base,int id_base,ServerSocket& server);
int nplayers;
mutable blk_SHA_CTX ctx;
map<int,int> socket_players;
int socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
public:
// The offset is used for the multi-threaded call, to ensure different
// portnum bases in each thread
Player(const Names& Nms,int id_base=0);
virtual ~Player();
int num_players() const { return nplayers; }
int socket(int i) const { return sockets[i]; }
// Send/Receive data to/from player i
// 8-bit ints only (mainly for testing)
void send_int(int i,int a) const { send(sockets[i],a); }
void receive_int(int i,int& a) const { receive(sockets[i],a); }
// Send an octetStream to all other players
// -- And corresponding receive
virtual void send_all(const octetStream& o,bool donthash=false) const;
void send_to(int player,const octetStream& o,bool donthash=false) const;
virtual void receive_player(int i,octetStream& o,bool donthash=false) const;
// Receive one from player i
/* Broadcast and Receive data to/from all players
* - Assumes o[player_no] contains the thing broadcast by me
*/
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
/* Run Protocol To Verify Broadcast Is Correct
* - Resets the blk_SHA_CTX at the same time
*/
void Check_Broadcast() const;
// wait for available inputs
void wait_for_available(vector<int>& players, vector<int>& result) const;
// dummy functions for compatibility
virtual void request_receive(int i, octetStream& o) const { sockets[i]; o.get_length(); }
virtual void wait_receive(int i, octetStream& o, bool donthash=false) const { receive_player(i, o, donthash); }
};
class ThreadPlayer : public Player
{
public:
mutable vector<Receiver*> receivers;
mutable vector<Sender*> senders;
ThreadPlayer(const Names& Nms,int id_base=0);
virtual ~ThreadPlayer();
void request_receive(int i, octetStream& o) const;
void wait_receive(int i, octetStream& o, bool donthash=false) const;
void receive_player(int i,octetStream& o,bool donthash=false) const;
void send_all(const octetStream& o,bool donthash=false) const;
};
class TwoPartyPlayer : public PlayerBase
{
private:
// setup sockets for comm. with only one other player
void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id);
int socket;
bool is_server;
int other_player;
public:
TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
~TwoPartyPlayer();
void send(octetStream& o) const;
void receive(octetStream& o) const;
int other_player_num() const;
/* Send and receive to/from the other player
* - o[0] contains my data, received data put in o[1]
*/
void send_receive_player(vector<octetStream>& o) const;
};
#endif

58
Networking/Receiver.cpp Normal file
View File

@@ -0,0 +1,58 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Receiver.cpp
*
*/
#include "Receiver.h"
#include <iostream>
using namespace std;
void* run_receiver_thread(void* receiver)
{
((Receiver*)receiver)->run();
return 0;
}
Receiver::Receiver(int socket) : socket(socket), thread(0)
{
}
void Receiver::start()
{
pthread_create(&thread, 0, run_receiver_thread, this);
}
void Receiver::stop()
{
in.stop();
pthread_join(thread, 0);
}
void Receiver::run()
{
octetStream* os = 0;
while (in.pop(os))
{
os->reset_write_head();
timer.start();
os->Receive(socket);
timer.stop();
out.push(os);
}
}
void Receiver::request(octetStream& os)
{
in.push(&os);
}
void Receiver::wait(octetStream& os)
{
octetStream* queued = 0;
out.pop(queued);
if (queued != &os)
throw not_implemented();
}

40
Networking/Receiver.h Normal file
View File

@@ -0,0 +1,40 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Receiver.h
*
*/
#ifndef NETWORKING_RECEIVER_H_
#define NETWORKING_RECEIVER_H_
#include <pthread.h>
#include "Tools/octetStream.h"
#include "Tools/WaitQueue.h"
#include "Tools/time-func.h"
class Receiver
{
int socket;
WaitQueue<octetStream*> in;
WaitQueue<octetStream*> out;
pthread_t thread;
// prevent copying
Receiver(const Receiver& other);
public:
Timer timer;
Receiver(int socket);
void start();
void stop();
void run();
void request(octetStream& os);
void wait(octetStream& os);
};
#endif /* NETWORKING_RECEIVER_H_ */

54
Networking/Sender.cpp Normal file
View File

@@ -0,0 +1,54 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Sender.cpp
*
*/
#include "Sender.h"
void* run_sender_thread(void* sender)
{
((Sender*)sender)->run();
return 0;
}
Sender::Sender(int socket) : socket(socket), thread(0)
{
}
void Sender::start()
{
pthread_create(&thread, 0, run_sender_thread, this);
}
void Sender::stop()
{
in.stop();
pthread_join(thread, 0);
}
void Sender::run()
{
const octetStream* os = 0;
while (in.pop(os))
{
// timer.start();
os->Send(socket);
// timer.stop();
out.push(os);
}
}
void Sender::request(const octetStream& os)
{
in.push(&os);
}
void Sender::wait(const octetStream& os)
{
const octetStream* queued = 0;
out.pop(queued);
if (queued != &os)
throw not_implemented();
}

40
Networking/Sender.h Normal file
View File

@@ -0,0 +1,40 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Sender.h
*
*/
#ifndef NETWORKING_SENDER_H_
#define NETWORKING_SENDER_H_
#include <pthread.h>
#include "Tools/octetStream.h"
#include "Tools/WaitQueue.h"
#include "Tools/time-func.h"
class Sender
{
int socket;
WaitQueue<const octetStream*> in;
WaitQueue<const octetStream*> out;
pthread_t thread;
// prevent copying
Sender(const Sender& other);
public:
Timer timer;
Sender(int socket);
void start();
void stop();
void run();
void request(const octetStream& os);
void wait(const octetStream& os);
};
#endif /* NETWORKING_SENDER_H_ */

115
Networking/ServerSocket.cpp Normal file
View File

@@ -0,0 +1,115 @@
// (C) 2016 University of Bristol. See License.txt
/*
* ServerSocket.cpp
*
*/
#include <Networking/ServerSocket.h>
#include <Networking/sockets.h>
#include "Exceptions/Exceptions.h"
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <iostream>
#include <sstream>
using namespace std;
void* accept_thread(void* server_socket)
{
((ServerSocket*)server_socket)->accept_clients();
return 0;
}
ServerSocket::ServerSocket(int Portnum) : portnum(Portnum)
{
struct sockaddr_in serv; /* socket info about our server */
memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */
serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */
serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */
serv.sin_port = htons(Portnum); /* set the server port number */
main_socket = socket(AF_INET, SOCK_STREAM, 0);
if (main_socket<0) { error("set_up_socket:socket"); }
int one=1;
int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
/* disable Nagle's algorithm */
fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
/* bind serv information to mysocket
* - Just assume it will eventually wake up
*/
fl=1;
while (fl!=0)
{ fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
if (fl != 0)
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
sleep(1);
}
else
{ cerr << "Bound on port " << Portnum << endl; }
}
if (fl<0) { error("set_up_socket:bind"); }
/* start listening, allowing a queue of up to 1000 pending connection */
fl=listen(main_socket, 1000);
if (fl<0) { error("set_up_socket:listen"); }
pthread_create(&thread, 0, accept_thread, this);
}
ServerSocket::~ServerSocket()
{
pthread_cancel(thread);
pthread_join(thread, 0);
if (close(main_socket)) { error("close(main_socket"); };
}
void ServerSocket::accept_clients()
{
while (true)
{
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
int client_id;
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
data_signal.lock();
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
}
}
int ServerSocket::get_connection_socket(int id)
{
data_signal.lock();
if (used.find(id) != used.end())
{
stringstream ss;
ss << "Connection id " << hex << id << " already used";
throw IO_Error(ss.str());
}
while (clients.find(id) == clients.end())
data_signal.wait();
int client = clients[id];
used.insert(id);
data_signal.unlock();
return client;
}

44
Networking/ServerSocket.h Normal file
View File

@@ -0,0 +1,44 @@
// (C) 2016 University of Bristol. See License.txt
/*
* ServerSocket.h
*
*/
#ifndef NETWORKING_SERVERSOCKET_H_
#define NETWORKING_SERVERSOCKET_H_
#include <map>
#include <set>
using namespace std;
#include <pthread.h>
#include "Tools/WaitQueue.h"
#include "Tools/Signal.h"
class ServerSocket
{
int main_socket, portnum;
map<int,int> clients;
set<int> used;
Signal data_signal;
pthread_t thread;
// disable copying
ServerSocket(const ServerSocket& other);
public:
ServerSocket(int Portnum);
~ServerSocket();
void accept_clients();
// This depends on clients sending their id as int.
// Has to be thread-safe.
int get_connection_socket(int number);
void close_socket();
};
#endif /* NETWORKING_SERVERSOCKET_H_ */

45
Networking/data.h Normal file
View File

@@ -0,0 +1,45 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Data
#define _Data
#include <string.h>
#include "Exceptions/Exceptions.h"
typedef unsigned char octet;
// Assumes word is a 64 bit value
#ifdef WIN32
typedef unsigned __int64 word;
#else
typedef unsigned long word;
#endif
#define BROADCAST 0
#define ROUTE 1
#define TERMINATE 2
#define GO 3
void encode_length(octet *buff,int len);
int decode_length(octet *buff);
inline void encode_length(octet *buff,int len)
{
if (len<0) { throw invalid_length(); }
buff[0]=len&255;
buff[1]=(len>>8)&255;
buff[2]=(len>>16)&255;
buff[3]=(len>>24)&255;
}
inline int decode_length(octet *buff)
{
int len=buff[0]+256*buff[1]+65536*buff[2]+16777216*buff[3];
if (len<0) { throw invalid_length(); }
return len;
}
#endif

225
Networking/sockets.cpp Normal file
View File

@@ -0,0 +1,225 @@
// (C) 2016 University of Bristol. See License.txt
#include "sockets.h"
#include "Exceptions/Exceptions.h"
#include <iostream>
using namespace std;
void error(const char *str)
{
char err[1000];
gethostname(err,1000);
strcat(err," : ");
strcat(err,str);
perror(err);
throw bad_value();
}
void error(const char *str1,const char *str2)
{
char err[1000];
gethostname(err,1000);
strcat(err," : ");
strcat(err,str1);
strcat(err,str2);
perror(err);
throw bad_value();
}
void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum)
{
struct sockaddr_in serv; /* socket info about our server */
int socksize = sizeof(struct sockaddr_in);
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */
serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */
serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */
serv.sin_port = htons(Portnum); /* set the server port number */
main_socket = socket(AF_INET, SOCK_STREAM, 0);
if (main_socket<0) { error("set_up_socket:socket"); }
int one=1;
int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
/* disable Nagle's algorithm */
fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
/* bind serv information to mysocket
* - Just assume it will eventually wake up
*/
fl=1;
while (fl!=0)
{ fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
if (fl != 0)
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
sleep(1);
}
else
{ cerr << "Bound on port " << Portnum << endl; }
}
if (fl<0) { error("set_up_socket:bind"); }
/* start listening, allowing a queue of up to 1 pending connection */
fl=listen(main_socket, 1);
if (fl<0) { error("set_up_socket:listen"); }
consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
}
void close_server_socket(int consocket,int main_socket)
{
if (close(consocket)) { error("close(socket)"); }
if (close(main_socket)) { error("close(main_socket"); };
}
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
{
mysocket = socket(AF_INET, SOCK_STREAM, 0);
if (mysocket<0) { error("set_up_socket:socket"); }
/* disable Nagle's algorithm */
int one=1;
int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
fl=setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (char*)&one, sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
struct sockaddr_in dest;
dest.sin_family = AF_INET;
dest.sin_port = htons(Portnum); // set destination port number
/*
struct hostent *server;
server=gethostbyname(hostname);
if (server== NULL)
{ error("set_up_socket:gethostbyname"); }
bcopy((char *)server->h_addr,
(char *)&dest.sin_addr.s_addr,
server->h_length); // set destination IP number
*/
struct addrinfo hints, *ai=NULL,*rp;
memset (&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_flags = AI_CANONNAME;
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
int erp;
for (int i = 0; i < 60; i++)
{ erp=getaddrinfo (hostname, NULL, &hints, &ai);
if (erp == 0)
{ break; }
else
{ cerr << "getaddrinfo on " << my_name << " has returned '" << gai_strerror(erp) <<
"' for " << hostname << ", trying again in a second ..." << endl;
if (ai)
freeaddrinfo(ai);
sleep(1);
}
}
if (erp!=0)
{ error("set_up_socket:getaddrinfo"); }
for (rp=ai; rp!=NULL; rp=rp->ai_next)
{ const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr;
if (ai->ai_family == AF_INET)
{ memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr));
continue;
}
}
freeaddrinfo(ai);
do
{ fl=1;
while (fl==1 || errno==EINPROGRESS)
{ fl=connect(mysocket, (struct sockaddr *)&dest, sizeof(struct sockaddr)); }
}
while (fl==-1 && errno==ECONNREFUSED);
if (fl<0) { error("set_up_socket:connect:",hostname); }
}
void close_client_socket(int socket)
{
if (close(socket))
{
char tmp[1000];
sprintf(tmp, "close(%d)", socket);
error(tmp);
}
}
unsigned long long sent_amount = 0, sent_counter = 0;
void send(int socket,int a)
{
unsigned char msg[1];
msg[0]=a&255;
if (send(socket,msg,1,0)!=1)
{ error("Send error - 2 "); }
}
void receive(int socket,int& a)
{
unsigned char msg[1];
int i=0;
while (i==0)
{ i=recv(socket,msg,1,0);
if (i<0) { error("Receiving error - 2"); }
}
a=msg[0];
}
void send_ack(int socket)
{
char msg[]="OK";
if (send(socket,msg,2,0)!=2)
{ error("Send Ack"); }
}
int get_ack(int socket)
{
char msg[]="OK";
char msg_r[2];
int i=0,j;
while (2-i>0)
{ j=recv(socket,msg_r+i,2-i,0);
i=i+j;
}
if (msg_r[0]!=msg[0] || msg_r[1]!=msg[1]) { return 1; }
return 0;
}

67
Networking/sockets.h Normal file
View File

@@ -0,0 +1,67 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _sockets
#define _sockets
#include "Networking/data.h"
#include <errno.h> /* Errors */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h> /* Wait for Process Termination */
void error(const char *str1,const char *str2);
void error(const char *str);
void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum);
void close_server_socket(int consocket,int main_socket);
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum);
void close_client_socket(int socket);
void send(int socket,octet *msg,int len);
void receive(int socket,octet *msg,int len);
/* Send and receive 8 bit integers */
void send(int socket,int a);
void receive(int socket,int& a);
void send_ack(int socket);
int get_ack(int socket);
extern unsigned long long sent_amount, sent_counter;
inline void send(int socket,octet *msg,int len)
{
if (send(socket,msg,len,0)!=len)
{ error("Send error - 1 "); }
sent_amount += len;
sent_counter++;
}
inline void receive(int socket,octet *msg,int len)
{
int i=0,j;
while (len-i>0)
{ j=recv(socket,msg+i,len-i,0);
if (j<0) { error("Receiving error - 1"); }
i=i+j;
}
}
#endif

294
OT/BaseOT.cpp Normal file
View File

@@ -0,0 +1,294 @@
// (C) 2016 University of Bristol. See License.txt
#include "OT/BaseOT.h"
#include "Tools/random.h"
#include <stdio.h>
#include <iostream>
#include <fstream>
#include <pthread.h>
extern "C" {
#include "SimpleOT/ot_sender.h"
#include "SimpleOT/ot_receiver.h"
}
using namespace std;
const char* role_to_str(OT_ROLE role)
{
if (role == RECEIVER)
return "RECEIVER";
if (role == SENDER)
return "SENDER";
return "BOTH";
}
OT_ROLE INV_ROLE(OT_ROLE role)
{
if (role == RECEIVER)
return SENDER;
if (role == SENDER)
return RECEIVER;
else
return BOTH;
}
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
{
if (role == SENDER)
{
P->send(os[0]);
}
else if (role == RECEIVER)
{
P->receive(os[1]);
}
else
{
// both sender + receiver
P->send_receive_player(os);
}
}
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
{
if (role == RECEIVER)
{
P->send(os[0]);
}
else if (role == SENDER)
{
P->receive(os[1]);
}
else
{
// both
P->send_receive_player(os);
}
}
void BaseOT::exec_base(bool new_receiver_inputs)
{
int i, j, k, len;
PRNG G;
G.ReSeed();
vector<octetStream> os(2);
SIMPLEOT_SENDER sender;
SIMPLEOT_RECEIVER receiver;
unsigned char S_pack[ PACKBYTES ];
unsigned char Rs_pack[ 2 ][ 4 * PACKBYTES ];
unsigned char sender_keys[ 2 ][ 4 ][ HASHBYTES ];
unsigned char receiver_keys[ 4 ][ HASHBYTES ];
unsigned char cs[ 4 ];
if (ot_role & SENDER)
{
sender_genS(&sender, S_pack);
os[0].store_bytes(S_pack, sizeof(S_pack));
}
send_if_ot_sender(P, os, ot_role);
if (ot_role & RECEIVER)
{
os[1].get_bytes((octet*) receiver.S_pack, len);
if (len != HASHBYTES)
{
cerr << "Received invalid length in base OT\n";
exit(1);
}
receiver_procS(&receiver);
receiver_maketable(&receiver);
}
for (i = 0; i < nOT; i += 4)
{
if (ot_role & RECEIVER)
{
for (j = 0; j < 4; j++)
{
if (new_receiver_inputs)
receiver_inputs[i + j] = G.get_uchar()&1;
cs[j] = receiver_inputs[i + j];
}
receiver_rsgen(&receiver, Rs_pack[0], cs);
os[0].reset_write_head();
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
receiver_keygen(&receiver, receiver_keys);
}
send_if_ot_receiver(P, os, ot_role);
if (ot_role & SENDER)
{
os[1].get_bytes((octet*) Rs_pack[1], len);
if (len != sizeof(Rs_pack[1]))
{
cerr << "Received invalid length in base OT\n";
exit(1);
}
sender_keygen(&sender, Rs_pack[1], sender_keys);
// Copy 128 bits of keys to sender_inputs
for (j = 0; j < 4; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
sender_inputs[i + j][0].set_byte(k, sender_keys[0][j][k]);
sender_inputs[i + j][1].set_byte(k, sender_keys[1][j][k]);
}
}
}
if (ot_role & RECEIVER)
{
// Copy keys to receiver_outputs
for (j = 0; j < 4; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
}
}
}
#ifdef BASE_OT_DEBUG
for (j = 0; j < 4; j++)
{
if (ot_role & SENDER)
{
printf("%4d-th sender keys:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[0][j][k]);
printf(" ");
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]);
printf("\n");
}
if (ot_role & RECEIVER)
{
printf("%4d-th receiver key:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
printf("\n");
}
}
printf("\n");
#endif
}
set_seeds();
}
void BaseOT::set_seeds()
{
for (int i = 0; i < nOT; i++)
{
// Set PRG seeds
if (ot_role & SENDER)
{
G_sender[i][0].SetSeed(sender_inputs[i][0].get_ptr());
G_sender[i][1].SetSeed(sender_inputs[i][1].get_ptr());
}
if (ot_role & RECEIVER)
{
G_receiver[i].SetSeed(receiver_outputs[i].get_ptr());
}
}
extend_length();
}
void BaseOT::extend_length()
{
for (int i = 0; i < nOT; i++)
{
if (ot_role & SENDER)
{
sender_inputs[i][0].randomize(G_sender[i][0]);
sender_inputs[i][1].randomize(G_sender[i][1]);
}
if (ot_role & RECEIVER)
{
receiver_outputs[i].randomize(G_receiver[i]);
}
}
}
void BaseOT::check()
{
vector<octetStream> os(2);
BitVector tmp_vector(8 * AES_BLK_SIZE);
for (int i = 0; i < nOT; i++)
{
if (ot_role == SENDER)
{
// send both inputs over
sender_inputs[i][0].pack(os[0]);
sender_inputs[i][1].pack(os[0]);
P->send(os[0]);
}
else if (ot_role == RECEIVER)
{
P->receive(os[1]);
}
else
{
// both sender + receiver
sender_inputs[i][0].pack(os[0]);
sender_inputs[i][1].pack(os[0]);
P->send_receive_player(os);
}
if (ot_role & RECEIVER)
{
tmp_vector.unpack(os[1]);
if (receiver_inputs[i] == 1)
{
tmp_vector.unpack(os[1]);
}
if (!tmp_vector.equals(receiver_outputs[i]))
{
cerr << "Incorrect OT\n";
exit(1);
}
}
os[0].reset_write_head();
os[1].reset_write_head();
}
}
void FakeOT::exec_base(bool new_receiver_inputs)
{
PRNG G;
G.ReSeed();
vector<octetStream> os(2);
vector<BitVector> bv(2);
if ((ot_role & RECEIVER) && new_receiver_inputs)
{
for (int i = 0; i < nOT; i++)
// Generate my receiver inputs
receiver_inputs[i] = G.get_uchar()&1;
}
if (ot_role & SENDER)
for (int i = 0; i < nOT; i++)
for (int j = 0; j < 2; j++)
{
sender_inputs[i][j].randomize(G);
sender_inputs[i][j].pack(os[0]);
}
send_if_ot_sender(P, os, ot_role);
if (ot_role & RECEIVER)
for (int i = 0; i < nOT; i++)
{
for (int j = 0; j < 2; j++)
bv[j].unpack(os[1]);
receiver_outputs[i] = bv[receiver_inputs[i]];
}
set_seeds();
}

93
OT/BaseOT.h Normal file
View File

@@ -0,0 +1,93 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _BASE_OT
#define _BASE_OT
/* The OT thread uses the Miracl library, which is not thread safe.
* Thus all Miracl based code is contained in this one thread so as
* to avoid locking issues etc.
*
* Thus this thread serves all base OTs to all other threads
*/
#include "Networking/Player.h"
#include "Tools/random.h"
#include "OT/BitVector.h"
// currently always assumes BOTH, i.e. do 2 sets of OT symmetrically,
// use bitwise & to check for role
enum OT_ROLE
{
RECEIVER = 0x01,
SENDER = 0x10,
BOTH = 0x11
};
OT_ROLE INV_ROLE(OT_ROLE role);
const char* role_to_str(OT_ROLE role);
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
class BaseOT
{
public:
vector<int> receiver_inputs;
vector< vector<BitVector> > sender_inputs;
vector<BitVector> receiver_outputs;
TwoPartyPlayer* P;
BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH)
: P(player), nOT(nOT), ot_length(ot_length), ot_role(role)
{
receiver_inputs.resize(nOT);
sender_inputs.resize(nOT, vector<BitVector>(2));
receiver_outputs.resize(nOT);
G_sender.resize(nOT, vector<PRNG>(2));
G_receiver.resize(nOT);
for (int i = 0; i < nOT; i++)
{
sender_inputs[i][0] = BitVector(8 * AES_BLK_SIZE);
sender_inputs[i][1] = BitVector(8 * AES_BLK_SIZE);
receiver_outputs[i] = BitVector(8 * AES_BLK_SIZE);
}
}
virtual ~BaseOT() {}
int length() { return ot_length; }
void set_receiver_inputs(const vector<int>& new_inputs)
{
if ((int)new_inputs.size() != nOT)
throw invalid_length();
receiver_inputs = new_inputs;
}
// do the OTs -- generate fresh random choice bits by default
virtual void exec_base(bool new_receiver_inputs=true);
// use PRG to get the next ot_length bits
void extend_length();
void check();
protected:
int nOT, ot_length;
OT_ROLE ot_role;
vector< vector<PRNG> > G_sender;
vector<PRNG> G_receiver;
bool is_sender() { return (bool) (ot_role & SENDER); }
bool is_receiver() { return (bool) (ot_role & RECEIVER); }
void set_seeds();
};
class FakeOT : public BaseOT
{
public:
FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
BaseOT(nOT, ot_length, player, role) {}
void exec_base(bool new_receiver_inputs=true);
};
#endif

646
OT/BitMatrix.cpp Normal file
View File

@@ -0,0 +1,646 @@
// (C) 2016 University of Bristol. See License.txt
/*
* BitMatrix.cpp
*
*/
#include <smmintrin.h>
#include <immintrin.h>
#include <mpirxx.h>
#include "BitMatrix.h"
#include "Math/gf2n.h"
#include "Math/gfp.h"
union matrix16x8
{
__m128i whole;
octet rows[16];
bool get_bit(int x, int y)
{ return (rows[x] >> y) & 1; }
void input(square128& input, int x, int y);
void transpose(square128& output, int x, int y);
};
class square16
{
public:
// 16x16 in two halves, 128 bits each
matrix16x8 halves[2];
bool get_bit(int x, int y)
{ return halves[y/8].get_bit(x, y % 8); }
void input(square128& output, int x, int y);
void transpose(square128& output, int x, int y);
void check_transpose(square16& dual);
void print();
};
__attribute__((optimize("unroll-loops")))
inline void matrix16x8::input(square128& input, int x, int y)
{
for (int l = 0; l < 16; l++)
rows[l] = input.bytes[16*x+l][y];
}
__attribute__((optimize("unroll-loops")))
inline void square16::input(square128& input, int x, int y)
{
for (int i = 0; i < 2; i++)
halves[i].input(input, x, 2 * y + i);
}
__attribute__((optimize("unroll-loops")))
inline void matrix16x8::transpose(square128& output, int x, int y)
{
for (int j = 0; j < 8; j++)
{
int row = _mm_movemask_epi8(whole);
whole = _mm_slli_epi64(whole, 1);
// _mm_movemask_epi8 uses most significant bit, hence +7-j
output.doublebytes[8*x+7-j][y] = row;
}
}
__attribute__((optimize("unroll-loops")))
inline void square16::transpose(square128& output, int x, int y)
{
for (int i = 0; i < 2; i++)
halves[i].transpose(output, 2 * x + i, y);
}
#ifdef __AVX2__
union matrix32x8
{
__m256i whole;
octet rows[32];
void input(square128& input, int x, int y);
void transpose(square128& output, int x, int y);
};
class square32
{
public:
matrix32x8 quarters[4];
void input(square128& input, int x, int y);
void transpose(square128& output, int x, int y);
};
__attribute__((optimize("unroll-loops")))
inline void matrix32x8::input(square128& input, int x, int y)
{
for (int l = 0; l < 32; l++)
rows[l] = input.bytes[32*x+l][y];
}
__attribute__((optimize("unroll-loops")))
inline void square32::input(square128& input, int x, int y)
{
for (int i = 0; i < 4; i++)
quarters[i].input(input, x, 4 * y + i);
}
__attribute__((optimize("unroll-loops")))
inline void matrix32x8::transpose(square128& output, int x, int y)
{
for (int j = 0; j < 8; j++)
{
int row = _mm256_movemask_epi8(whole);
whole = _mm256_slli_epi64(whole, 1);
// _mm_movemask_epi8 uses most significant bit, hence +7-j
output.words[8*x+7-j][y] = row;
}
}
__attribute__((optimize("unroll-loops")))
inline void square32::transpose(square128& output, int x, int y)
{
for (int i = 0; i < 4; i++)
quarters[i].transpose(output, 4 * x + i, y);
}
#endif
#ifdef __AVX2__
#warning Using AVX2 for transpose
typedef square32 subsquare;
#define N_SUBSQUARES 4
#else
typedef square16 subsquare;
#define N_SUBSQUARES 8
#endif
__attribute__((optimize("unroll-loops")))
void square128::transpose()
{
for (int j = 0; j < N_SUBSQUARES; j++)
for (int k = 0; k < j; k++)
{
subsquare a, b;
a.input(*this, k, j);
b.input(*this, j, k);
a.transpose(*this, j, k);
b.transpose(*this, k, j);
}
for (int j = 0; j < N_SUBSQUARES; j++)
{
subsquare a;
a.input(*this, j, j);
a.transpose(*this, j, j);
}
}
void square128::randomize(PRNG& G)
{
G.get_octets((octet*)&rows, sizeof(rows));
}
template <>
void square128::randomize<gf2n>(int row, PRNG& G)
{
rows[row] = G.get_doubleword();
}
template <>
void square128::randomize<gfp>(int row, PRNG& G)
{
rows[row] = gfp::get_ZpD().get_random128(G);
}
void gfp_iadd(__m128i& a, __m128i& b)
{
gfp::get_ZpD().Add((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b);
}
void gfp_isub(__m128i& a, __m128i& b)
{
gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b);
}
void gfp_irsub(__m128i& a, __m128i& b)
{
gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&b, (mp_limb_t*)&a);
}
template<>
void square128::conditional_add<gf2n>(BitVector& conditions, square128& other, int offset)
{
for (int i = 0; i < 128; i++)
if (conditions.get_bit(128 * offset + i))
rows[i] ^= other.rows[i];
}
template<>
void square128::conditional_add<gfp>(BitVector& conditions, square128& other, int offset)
{
for (int i = 0; i < 128; i++)
if (conditions.get_bit(128 * offset + i))
gfp_iadd(rows[i], other.rows[i]);
}
template <class T>
void square128::hash_row_wise(MMO& mmo, square128& input)
{
mmo.hashBlockWise<T,128>((octet*)rows, (octet*)input.rows);
}
template <>
void square128::to(gf2n_long& result)
{
int128 high, low;
for (int i = 0; i < 128; i++)
{
low ^= int128(rows[i]) << i;
high ^= int128(rows[i]) >> (128 - i);
}
result.reduce(high, low);
}
template <>
void square128::to(gfp& result)
{
mp_limb_t product[4], sum[4], tmp[2][4];
memset(tmp, 0, sizeof(tmp));
memset(sum, 0, sizeof(sum));
for (int i = 0; i < 128; i++)
{
memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i]));
mpn_lshift(product, tmp[i/64], 4, i % 64);
mpn_add_n(sum, product, sum, 4);
}
mp_limb_t q[4], ans[4];
mpn_tdiv_qr(q, ans, 0, sum, 4, gfp::get_ZpD().get_prA(), 2);
result = *(__m128i*)ans;
}
void square128::check_transpose(square128& dual, int i, int k)
{
for (int j = 0; j < 16; j++)
for (int l = 0; l < 16; l++)
if (get_bit(16 * i + j, 16 * k + l) != dual.get_bit(16 * k + l, 16 * i + j))
{
cout << "Error in 16x16 square (" << i << "," << k << ")" << endl;
print(i, k);
dual.print(i, k);
exit(1);
}
}
void square16::print()
{
for (int i = 0; i < 2; i++)
{
for (int j = 0; j < 8; j++)
{
for (int k = 0; k < 2; k++)
{
for (int l = 0; l < 8; l++)
cout << halves[k].get_bit(8 * i + j, l);
cout << " ";
}
cout << endl;
}
cout << endl;
}
}
void square128::print(int i, int k)
{
square16 a;
a.input(*this, i, k);
a.print();
}
void square128::print()
{
for (int i = 0; i < 128; i++)
{
for (int j = 0; j < 128; j++)
cout << get_bit(i, j);
cout << endl;
}
}
void square128::set_zero()
{
for (int i = 0; i < 128; i++)
rows[i] = _mm_setzero_si128();
}
square128& square128::operator^=(square128& other)
{
for (int i = 0; i < 128; i++)
rows[i] ^= other.rows[i];
return *this;
}
template<>
square128& square128::add<gf2n>(square128& other)
{
return *this ^= other;
}
template<>
square128& square128::add<gfp>(square128& other)
{
for (int i = 0; i < 128; i++)
gfp_iadd(rows[i], other.rows[i]);
return *this;
}
template<>
square128& square128::sub<gf2n>(square128& other)
{
return *this ^= other;
}
template<>
square128& square128::sub<gfp>(square128& other)
{
for (int i = 0; i < 128; i++)
gfp_isub(rows[i], other.rows[i]);
return *this;
}
template<>
square128& square128::rsub<gf2n>(square128& other)
{
return *this ^= other;
}
template<>
square128& square128::rsub<gfp>(square128& other)
{
for (int i = 0; i < 128; i++)
gfp_irsub(rows[i], other.rows[i]);
return *this;
}
square128& square128::operator^=(__m128i* other)
{
__m128i value = _mm_loadu_si128(other);
for (int i = 0; i < 128; i++)
rows[i] ^= value;
return *this;
}
template <>
square128& square128::sub<gf2n>(__m128i* other)
{
return *this ^= other;
}
template <>
square128& square128::sub<gfp>(__m128i* other)
{
__m128i value = _mm_loadu_si128(other);
for (int i = 0; i < 128; i++)
gfp_isub(rows[i], value);
return *this;
}
square128& square128::operator^=(BitVector& other)
{
return *this ^= (__m128i*)other.get_ptr();
}
bool square128::operator==(square128& other)
{
for (int i = 0; i < 128; i++)
{
__m128i tmp = rows[i] ^ other.rows[i];
if (not _mm_test_all_zeros(tmp, tmp))
return false;
}
return true;
}
void square128::pack(octetStream& o) const
{
o.append((octet*)this->bytes, sizeof(bytes));
}
void square128::unpack(octetStream &o)
{
o.consume((octet*)this->bytes, sizeof(bytes));
}
BitMatrix::BitMatrix(int length)
{
resize(length);
}
void BitMatrix::resize(int length)
{
if (length % 128 != 0)
throw invalid_length();
squares.resize(length / 128);
}
int BitMatrix::size()
{
return squares.size() * 128;
}
template <class T>
BitMatrix& BitMatrix::add(BitMatrix& other)
{
if (squares.size() != other.squares.size())
throw invalid_length();
for (size_t i = 0; i < squares.size(); i++)
squares[i].add<T>(other.squares[i]);
return *this;
}
template <class T>
BitMatrix& BitMatrix::sub(BitMatrix& other)
{
if (squares.size() != other.squares.size())
throw invalid_length();
for (size_t i = 0; i < squares.size(); i++)
squares[i].sub<T>(other.squares[i]);
return *this;
}
template <class T>
BitMatrix& BitMatrix::rsub(BitMatrixSlice& other)
{
if (squares.size() < other.end)
throw invalid_length();
for (size_t i = other.start; i < other.end; i++)
squares[i].rsub<T>(other.bm.squares[i]);
return *this;
}
template <class T>
BitMatrix& BitMatrix::sub(BitVector& other)
{
if (squares.size() * 128 != other.size())
throw invalid_length();
for (size_t i = 0; i < squares.size(); i++)
squares[i].sub<T>((__m128i*)other.get_ptr() + i);
return *this;
}
bool BitMatrix::operator==(BitMatrix& other)
{
if (squares.size() != other.squares.size())
throw invalid_length();
for (size_t i = 0; i < squares.size(); i++)
if (not(squares[i] == other.squares[i]))
return false;
return true;
}
bool BitMatrix::operator!=(BitMatrix& other)
{
return not (*this == other);
}
void BitMatrix::randomize(PRNG& G)
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].randomize(G);
}
void BitMatrix::randomize(int row, PRNG& G)
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].randomize<gf2n>(row, G);
}
void BitMatrix::transpose()
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].transpose();
}
void BitMatrix::check_transpose(BitMatrix& dual)
{
for (size_t i = 0; i < squares.size(); i++)
{
for (int j = 0; j < 128; j++)
for (int k = 0; k < 128; k++)
if (squares[i].get_bit(j, k) != dual.squares[i].get_bit(k, j))
{
cout << "First error in square " << i << " row " << j
<< " column " << k << endl;
squares[i].print(i / 8, j / 8);
dual.squares[i].print(i / 8, j / 8);
return;
}
}
cout << "No errors in transpose" << endl;
}
void BitMatrix::print_side_by_side(BitMatrix& other)
{
for (int i = 0; i < 32; i++)
{
for (int j = 0; j < 64; j++)
cout << squares[0].get_bit(i,j);
cout << " ";
for (int j = 0; j < 64; j++)
cout << other.squares[0].get_bit(i,j);
cout << endl;
}
}
void BitMatrix::print_conditional(BitVector& conditions)
{
for (int i = 0; i < 32; i++)
{
if (conditions.get_bit(i))
for (int j = 0; j < 65; j++)
cout << " ";
for (int j = 0; j < 64; j++)
cout << squares[0].get_bit(i,j);
if (!conditions.get_bit(i))
for (int j = 0; j < 65; j++)
cout << " ";
cout << endl;
}
}
void BitMatrix::pack(octetStream& os) const
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].pack(os);
}
void BitMatrix::unpack(octetStream& os)
{
for (size_t i = 0; i < squares.size(); i++)
squares[i].unpack(os);
}
void BitMatrix::to(vector<BitVector>& output)
{
output.resize(128);
for (int i = 0; i < 128; i++)
{
output[i].resize(128 * squares.size());
for (size_t j = 0; j < squares.size(); j++)
output[i].set_int128(j, squares[j].rows[i]);
}
}
BitMatrixSlice::BitMatrixSlice(BitMatrix& bm, size_t start, size_t size) :
bm(bm), start(start), size(size)
{
end = start + size;
if (end > bm.squares.size())
{
stringstream ss;
ss << "Matrix slice (" << start << "," << end << ") larger than matrix (" << bm.squares.size() << ")";
throw invalid_argument(ss.str());
}
}
template <class T>
BitMatrixSlice& BitMatrixSlice::rsub(BitMatrixSlice& other)
{
bm.rsub<T>(other);
return *this;
}
template <class T>
BitMatrixSlice& BitMatrixSlice::add(BitVector& other, int repeat)
{
if (end * 128 > other.size() * repeat)
throw invalid_length();
for (size_t i = start; i < end; i++)
bm.squares[i].sub<T>((__m128i*)other.get_ptr() + i / repeat);
return *this;
}
template <class T>
void BitMatrixSlice::randomize(int row, PRNG& G)
{
for (size_t i = start; i < end; i++)
bm.squares[i].randomize<T>(row, G);
}
template <class T>
void BitMatrixSlice::conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset)
{
for (size_t i = start; i < end; i++)
bm.squares[i].conditional_add<T>(conditions, other.squares[i], useOffset * i);
}
void BitMatrixSlice::transpose()
{
for (size_t i = start; i < end; i++)
bm.squares[i].transpose();
}
template <class T>
void BitMatrixSlice::print()
{
cout << "hex / value" << endl;
for (int i = 0; i < 16; i++)
{
cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl;
}
cout << endl;
}
void BitMatrixSlice::pack(octetStream& os) const
{
for (size_t i = start; i < end; i++)
bm.squares[i].pack(os);
}
void BitMatrixSlice::unpack(octetStream& os)
{
for (size_t i = start; i < end; i++)
bm.squares[i].unpack(os);
}
template void BitMatrixSlice::conditional_add<gf2n>(BitVector& conditions, BitMatrix& other, bool useOffset);
template void BitMatrixSlice::conditional_add<gfp>(BitVector& conditions, BitMatrix& other, bool useOffset);
template BitMatrixSlice& BitMatrixSlice::rsub<gf2n>(BitMatrixSlice& other);
template BitMatrixSlice& BitMatrixSlice::rsub<gfp>(BitMatrixSlice& other);
template BitMatrixSlice& BitMatrixSlice::add<gf2n>(BitVector& other, int repeat);
template BitMatrixSlice& BitMatrixSlice::add<gfp>(BitVector& other, int repeat);
template BitMatrix& BitMatrix::add<gf2n>(BitMatrix& other);
template BitMatrix& BitMatrix::add<gfp>(BitMatrix& other);
template BitMatrix& BitMatrix::sub<gf2n>(BitMatrix& other);
template BitMatrix& BitMatrix::sub<gfp>(BitMatrix& other);
template void BitMatrixSlice::print<gf2n>();
template void BitMatrixSlice::print<gfp>();
template void BitMatrixSlice::randomize<gf2n>(int row, PRNG& G);
template void BitMatrixSlice::randomize<gfp>(int row, PRNG& G);
template void square128::hash_row_wise<gf2n>(MMO& mmo, square128& input);
template void square128::hash_row_wise<gfp>(MMO& mmo, square128& input);

136
OT/BitMatrix.h Normal file
View File

@@ -0,0 +1,136 @@
// (C) 2016 University of Bristol. See License.txt
/*
* BitMatrix.h
*
*/
#ifndef OT_BITMATRIX_H_
#define OT_BITMATRIX_H_
#include <vector>
#include <emmintrin.h>
#include "BitVector.h"
#include "Tools/random.h"
#include "Tools/MMO.h"
#include "Math/gf2nlong.h"
using namespace std;
union square128 {
__m128i rows[128];
octet bytes[128][16];
int16_t doublebytes[128][8];
int32_t words[128][4];
bool get_bit(int x, int y)
{ return (bytes[x][y/8] >> (y % 8)) & 1; }
void set_zero();
square128& operator^=(square128& other);
square128& operator^=(__m128i* other);
square128& operator^=(BitVector& other);
bool operator==(square128& other);
template <class T>
square128& add(square128& other);
template <class T>
square128& sub(square128& other);
template <class T>
square128& rsub(square128& other);
template <class T>
square128& sub(__m128i* other);
void randomize(PRNG& G);
template <class T>
void randomize(int row, PRNG& G);
template <class T>
void conditional_add(BitVector& conditions, square128& other, int offset);
void transpose();
template <class T>
void hash_row_wise(MMO& mmo, square128& input);
template <class T>
void to(T& result);
void check_transpose(square128& dual, int i, int k);
void print(int i, int k);
void print();
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const;
void unpack(octetStream& o);
};
class BitMatrixSlice;
class BitMatrix
{
public:
vector<square128> squares;
BitMatrix() {}
BitMatrix(int length);
void resize(int length);
int size();
template <class T>
BitMatrix& add(BitMatrix& other);
template <class T>
BitMatrix& sub(BitMatrix& other);
template <class T>
BitMatrix& rsub(BitMatrixSlice& other);
template <class T>
BitMatrix& sub(BitVector& other);
bool operator==(BitMatrix& other);
bool operator!=(BitMatrix& other);
void randomize(PRNG& G);
void randomize(int row, PRNG& G);
void transpose();
void check_transpose(BitMatrix& dual);
void print_side_by_side(BitMatrix& other);
void print_conditional(BitVector& conditions);
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const;
void unpack(octetStream& o);
void to(vector<BitVector>& output);
};
class BitMatrixSlice
{
friend class BitMatrix;
BitMatrix& bm;
size_t start, size, end;
public:
BitMatrixSlice(BitMatrix& bm, size_t start, size_t size);
template <class T>
BitMatrixSlice& rsub(BitMatrixSlice& other);
template <class T>
BitMatrixSlice& add(BitVector& other, int repeat = 1);
template <class T>
void randomize(int row, PRNG& G);
template <class T>
void conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset = false);
void transpose();
template <class T>
void print();
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const;
void unpack(octetStream& o);
};
#endif /* OT_BITMATRIX_H_ */

107
OT/BitVector.cpp Normal file
View File

@@ -0,0 +1,107 @@
// (C) 2016 University of Bristol. See License.txt
#include "OT/BitVector.h"
#include "Tools/random.h"
#include "Tools/octetStream.h"
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include <fstream>
void BitVector::randomize(PRNG& G)
{
G.get_octets(bytes, nbytes);
}
template<>
void BitVector::randomize_blocks<gf2n>(PRNG& G)
{
randomize(G);
}
template<>
void BitVector::randomize_blocks<gfp>(PRNG& G)
{
gfp tmp;
for (size_t i = 0; i < (nbits / 128); i++)
{
tmp.randomize(G);
for (int j = 0; j < 2; j++)
((mp_limb_t*)bytes)[2*i+j] = tmp.get().get_limb(j);
}
}
void BitVector::randomize_at(int a, int nb, PRNG& G)
{
if (nb < 1)
throw invalid_length();
G.get_octets(bytes + a, nb);
}
/*
*/
void BitVector::output(ostream& s,bool human) const
{
if (human)
{
s << nbits << " " << hex;
for (unsigned int i = 0; i < nbytes; i++)
{
s << int(bytes[i]) << " ";
}
s << dec << endl;
}
else
{
int len = nbits;
s.write((char*) &len, sizeof(int));
s.write((char*) bytes, nbytes);
}
}
void BitVector::input(istream& s,bool human)
{
if (s.peek() == EOF)
{
if (s.tellg() == 0)
{
cout << "IO problem. Empty file?" << endl;
throw file_error();
}
throw end_of_file();
}
int len;
if (human)
{
s >> len >> hex;
resize(len);
for (size_t i = 0; i < nbytes; i++)
{
s >> bytes[i];
}
s >> dec;
}
else
{
s.read((char*) &len, sizeof(int));
resize(len);
s.read((char*) bytes, nbytes);
}
}
void BitVector::pack(octetStream& o) const
{
o.append((octet*)bytes, nbytes);
}
void BitVector::unpack(octetStream& o)
{
o.consume((octet*)bytes, nbytes);
}

212
OT/BitVector.h Normal file
View File

@@ -0,0 +1,212 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _BITVECTOR
#define _BITVECTOR
/* Vector of bits */
#include <iostream>
#include <vector>
using namespace std;
#include <stdlib.h>
#include <pmmintrin.h>
#include "Exceptions/Exceptions.h"
#include "Networking/data.h"
// just for util functions
#include "Math/bigint.h"
#include "Math/gf2nlong.h"
class PRNG;
class octetStream;
class BitVector
{
octet* bytes;
size_t nbytes;
size_t nbits;
size_t length;
public:
void assign(const BitVector& K)
{
if (nbits != K.nbits)
{
resize(K.nbits);
}
memcpy(bytes, K.bytes, nbytes);
}
void assign_bytes(char* new_bytes, int len)
{
resize(len*8);
memcpy(bytes, new_bytes, len);
}
void assign_zero()
{
memset(bytes, 0, nbytes);
}
// only grows, never destroys
void resize(size_t new_nbits)
{
if (nbits != new_nbits)
{
int new_nbytes = DIV_CEIL(new_nbits,8);
if (nbits < new_nbits)
{
octet* tmp = new octet[new_nbytes];
memcpy(tmp, bytes, nbytes);
delete[] bytes;
bytes = tmp;
}
nbits = new_nbits;
nbytes = new_nbytes;
/*
// use realloc to preserve original contents
if (new_nbits < nbits)
{
memcpy(tmp, bytes, new_nbytes);
}
else
{
memset(tmp, 0, new_nbytes);
memcpy(tmp, bytes, nbytes);
}*/
// realloc may fail on size 0
/*if (new_nbits == 0)
{
free(bytes);
bytes = (octet*) malloc(0);//new octet[0];
//free(bytes);
return;
}
bytes = (octet*)realloc(bytes, nbytes);
if (bytes == NULL)
{
cerr << "realloc failed\n";
exit(1);
}*/
/*delete[] bytes;
nbits = new_nbits;
nbytes = DIV_CEIL(nbits, 8);
bytes = new octet[nbytes];*/
}
}
unsigned int size() const { return nbits; }
unsigned int size_bytes() const { return nbytes; }
octet* get_ptr() { return bytes; }
BitVector(size_t n=128)
{
nbits = n;
nbytes = DIV_CEIL(nbits, 8);
bytes = new octet[nbytes];
length = n;
assign_zero();
}
BitVector(const BitVector& K)
{
bytes = new octet[K.nbytes];
nbytes = K.nbytes;
nbits = K.nbits;
assign(K);
}
~BitVector() {
//cout << "Destroy, size = " << nbytes << endl;
delete[] bytes;
}
BitVector& operator=(const BitVector& K)
{
if (this!=&K) { assign(K); }
return *this;
}
octet get_byte(int i) const { return bytes[i]; }
void set_byte(int i, octet b) { bytes[i] = b; }
// get the i-th 64-bit word
word get_word(int i) const { return *(word*)(bytes + i*8); }
void set_word(int i, word w)
{
int offset = i * sizeof(word);
memcpy(bytes + offset, (octet*)&w, sizeof(word));
}
int128 get_int128(int i) const { return _mm_lddqu_si128((__m128i*)bytes + i); }
void set_int128(int i, int128 a) { *((__m128i*)bytes + i) = a.a; }
int get_bit(int i) const
{
return (bytes[i/8] >> (i % 8)) & 1;
}
void set_bit(int i,unsigned int a)
{
int j = i/8, k = i&7;
if (a==1)
{ bytes[j] |= (octet)(1UL<<k); }
else
{ bytes[j] &= (octet)~(1UL<<k); }
}
void add(const BitVector& A, const BitVector& B)
{
if (A.nbits != B.nbits)
{ throw invalid_length(); }
resize(A.nbits);
for (unsigned int i=0; i < nbytes; i++)
{
bytes[i] = A.bytes[i] ^ B.bytes[i];
}
}
void add(const BitVector& A)
{
if (nbits != A.nbits)
{ throw invalid_length(); }
for (unsigned int i = 0; i < nbytes; i++)
{
bytes[i] ^= A.bytes[i];
}
}
bool equals(const BitVector& K) const
{
if (nbits != K.nbits)
{ throw invalid_length(); }
for (unsigned int i = 0; i < nbytes; i++)
{ if (bytes[i] != K.bytes[i]) { return false; } }
return true;
}
void randomize(PRNG& G);
template <class T>
void randomize_blocks(PRNG& G);
// randomize bytes a, ..., a+nb-1
void randomize_at(int a, int nb, PRNG& G);
void output(ostream& s,bool human) const;
void input(istream& s,bool human);
// Pack and unpack in native format
// i.e. Dont care about conversion to human readable form
void pack(octetStream& o) const;
void unpack(octetStream& o);
string str()
{
stringstream ss;
ss << hex;
for(size_t i(0);i < nbytes;++i)
ss << (int)bytes[i] << " ";
return ss.str();
}
};
#endif

View File

@@ -0,0 +1,555 @@
// (C) 2016 University of Bristol. See License.txt
#include "NPartyTripleGenerator.h"
#include "OT/OTExtensionWithMatrix.h"
#include "OT/OTMultiplier.h"
#include "Math/gfp.h"
#include "Math/Share.h"
#include "Math/operators.h"
#include "Auth/Subroutines.h"
#include "Auth/MAC_Check.h"
#include <sstream>
#include <fstream>
#include <math.h>
template <class T, int N>
class Triple
{
public:
T a[N];
T b;
T c[N];
int repeat(int l)
{
switch (l)
{
case 0:
case 2:
return N;
case 1:
return 1;
default:
throw bad_value();
}
}
T& byIndex(int l, int j)
{
switch (l)
{
case 0:
return a[j];
case 1:
return b;
case 2:
return c[j];
default:
throw bad_value();
}
}
template <int M>
void amplify(const Triple<T,M>& uncheckedTriple, PRNG& G)
{
b = uncheckedTriple.b;
for (int i = 0; i < N; i++)
for (int j = 0; j < M; j++)
{
typename T::value_type r;
r.randomize(G);
a[i] += r * uncheckedTriple.a[j];
c[i] += r * uncheckedTriple.c[j];
}
}
void output(ostream& outputStream, int n = N, bool human = false)
{
for (int i = 0; i < n; i++)
{
a[i].output(outputStream, human);
b.output(outputStream, human);
c[i].output(outputStream, human);
}
}
};
template <class T, int N>
class PlainTriple : public Triple<T,N>
{
public:
// this assumes that valueBits[1] is still set to the bits of b
void to(vector<BitVector>& valueBits, int i)
{
for (int j = 0; j < N; j++)
{
valueBits[0].set_int128(i * N + j, this->a[j].to_m128i());
valueBits[2].set_int128(i * N + j, this->c[j].to_m128i());
}
}
};
template <class T, int N>
class ShareTriple : public Triple<Share<T>, N>
{
public:
void from(PlainTriple<T,N>& triple, vector<OTMultiplier<T>*>& ot_multipliers,
int iTriple, const NPartyTripleGenerator& generator)
{
for (int l = 0; l < 3; l++)
{
int repeat = this->repeat(l);
for (int j = 0; j < repeat; j++)
{
T value = triple.byIndex(l,j);
T mac = value * generator.machine.get_mac_key<T>();
for (int i = 0; i < generator.nparties-1; i++)
mac += ot_multipliers[i]->macs[l][iTriple * repeat + j];
Share<T>& share = this->byIndex(l,j);
share.set_share(value);
share.set_mac(mac);
}
}
}
T computeCheckMAC(const T& maskedA)
{
return this->c[0].get_mac() - maskedA * this->b.get_mac();
}
};
/*
* Copies the relevant base OTs from setup
* N.B. setup must not be stored as it will be used by other threads
*/
NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
TripleMachine& machine) :
globalPlayer(names, - thread_num * machine.nplayers * machine.nplayers),
thread_num(thread_num),
my_num(setup.get_my_num()),
nloops(nloops),
nparties(setup.get_nparties()),
machine(machine)
{
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
nTriples = nTriplesPerLoop * nloops;
field_size = 128;
nAmplify = machine.amplify ? N_AMPLIFY : 1;
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
int n = nparties;
//baseReceiverInput = machines[0]->baseReceiverInput;
//baseSenderInputs.resize(n-1);
//baseReceiverOutputs.resize(n-1);
nbase = setup.get_nbase();
baseReceiverInput.resize(nbase);
baseReceiverOutputs.resize(n - 1);
baseSenderInputs.resize(n - 1);
players.resize(n-1);
gf2n_long::init_field(128);
for (int i = 0; i < n-1; i++)
{
// i for indexing, other_player is actual number
int other_player, id;
if (i >= my_num)
other_player = i + 1;
else
other_player = i;
// copy base OT inputs + outputs
for (int j = 0; j < 128; j++)
{
baseReceiverInput.set_bit(j, (unsigned int)setup.get_base_receiver_input(j));
}
baseReceiverOutputs[i] = setup.baseOTs[i]->receiver_outputs;
baseSenderInputs[i] = setup.baseOTs[i]->sender_inputs;
// new TwoPartyPlayer with unique id for each thread + pair of players
if (my_num < other_player)
id = (thread_num+1)*n*n + my_num*n + other_player;
else
id = (thread_num+1)*n*n + other_player*n + my_num;
players[i] = new TwoPartyPlayer(names, other_player, id);
cout << "Set up with player " << other_player << " in thread " << thread_num << " with id " << id << endl;
}
pthread_mutex_init(&mutex, 0);
pthread_cond_init(&ready, 0);
}
NPartyTripleGenerator::~NPartyTripleGenerator()
{
for (size_t i = 0; i < players.size(); i++)
delete players[i];
//delete nplayer;
pthread_mutex_destroy(&mutex);
pthread_cond_destroy(&ready);
}
template<class T>
void* run_ot_thread(void* ptr)
{
((OTMultiplier<T>*)ptr)->multiply();
return NULL;
}
template<class T>
void NPartyTripleGenerator::generate()
{
vector< OTMultiplier<T>* > ot_multipliers(nparties-1);
timers["Generator thread"].start();
for (int i = 0; i < nparties-1; i++)
{
ot_multipliers[i] = new OTMultiplier<T>(*this, i);
pthread_mutex_lock(&ot_multipliers[i]->mutex);
pthread_create(&(ot_multipliers[i]->thread), 0, run_ot_thread<T>, ot_multipliers[i]);
}
// add up the shares from each thread and write to file
stringstream ss;
ss << machine.prep_data_dir;
if (machine.generateBits)
ss << "Bits-";
else
ss << "Triples-";
ss << T::type_char() << "-P" << my_num;
if (thread_num != 0)
ss << "-" << thread_num;
ofstream outputFile(ss.str().c_str());
if (machine.generateBits)
generateBits(ot_multipliers, outputFile);
else
generateTriples(ot_multipliers, outputFile);
timers["Generator thread"].stop();
if (machine.output)
cout << "Written " << nTriples << " outputs to " << ss.str() << endl;
else
cout << "Generated " << nTriples << " outputs" << endl;
// wait for threads to finish
for (int i = 0; i < nparties-1; i++)
{
pthread_mutex_unlock(&ot_multipliers[i]->mutex);
pthread_join(ot_multipliers[i]->thread, NULL);
cout << "OT thread " << i << " finished\n" << flush;
}
cout << "OT threads finished\n";
for (size_t i = 0; i < ot_multipliers.size(); i++)
delete ot_multipliers[i];
}
template<>
void NPartyTripleGenerator::generateBits(vector< OTMultiplier<gf2n>* >& ot_multipliers,
ofstream& outputFile)
{
PRNG share_prg;
share_prg.ReSeed();
int nBitsToCheck = nTriplesPerLoop + field_size;
valueBits.resize(1);
valueBits[0].resize(ceil(1.0 * nBitsToCheck / field_size) * field_size);
MAC_Check<gf2n> MC(machine.get_mac_key<gf2n>());
vector< Share<gf2n> > bits(nBitsToCheck);
vector< Share<gf2n> > to_open(1);
vector<gf2n> opened(1);
start_progress(ot_multipliers);
for (int k = 0; k < nloops; k++)
{
print_progress(k);
valueBits[0].randomize_blocks<gf2n>(share_prg);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
timers["Authentication OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
timers["Authentication OTs"].stop();
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
PRNG G;
G.SetSeed(seed);
Share<gf2n> check_sum;
gf2n r;
for (int j = 0; j < nBitsToCheck; j++)
{
gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key<gf2n>();
for (int i = 0; i < nparties-1; i++)
mac_sum += ot_multipliers[i]->macs[0][j];
bits[j].set_share(valueBits[0].get_bit(j));
bits[j].set_mac(mac_sum);
r.randomize(G);
check_sum += r * bits[j];
}
to_open[0] = check_sum;
MC.POpen_Begin(opened, to_open, globalPlayer);
MC.POpen_End(opened, to_open, globalPlayer);
MC.Check(globalPlayer);
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
bits[j].output(outputFile, false);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
}
}
template<>
void NPartyTripleGenerator::generateBits(vector< OTMultiplier<gfp>* >& ot_multipliers,
ofstream& outputFile)
{
generateTriples(ot_multipliers, outputFile);
}
template<class T>
void NPartyTripleGenerator::generateTriples(vector< OTMultiplier<T>* >& ot_multipliers,
ofstream& outputFile)
{
PRNG share_prg;
share_prg.ReSeed();
valueBits.resize(3);
for (int i = 0; i < 2; i++)
valueBits[2*i].resize(field_size * nPreampTriplesPerLoop);
valueBits[1].resize(field_size * nTriplesPerLoop);
vector< PlainTriple<T,N_AMPLIFY> > preampTriples;
vector< PlainTriple<T,2> > amplifiedTriples;
vector< ShareTriple<T,2> > uncheckedTriples;
MAC_Check<T> MC(machine.get_mac_key<T>());
if (machine.amplify)
preampTriples.resize(nTriplesPerLoop);
if (machine.generateMACs)
{
amplifiedTriples.resize(nTriplesPerLoop);
uncheckedTriples.resize(nTriplesPerLoop);
}
start_progress(ot_multipliers);
for (int k = 0; k < nloops; k++)
{
print_progress(k);
for (int j = 0; j < 2; j++)
valueBits[j].randomize_blocks<T>(share_prg);
timers["OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
timers["OTs"].stop();
for (int j = 0; j < nPreampTriplesPerLoop; j++)
{
T a(valueBits[0].get_int128(j));
T b(valueBits[1].get_int128(j / nAmplify));
T c = a * b;
timers["Triple computation"].start();
for (int i = 0; i < nparties-1; i++)
{
c += ot_multipliers[i]->c_output[j];
}
timers["Triple computation"].stop();
if (machine.amplify)
{
preampTriples[j/nAmplify].a[j%nAmplify] = a;
preampTriples[j/nAmplify].b = b;
preampTriples[j/nAmplify].c[j%nAmplify] = c;
}
else
{
timers["Writing"].start();
a.output(outputFile, false);
b.output(outputFile, false);
c.output(outputFile, false);
timers["Writing"].stop();
}
}
if (machine.amplify)
{
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
PRNG G;
G.SetSeed(seed);
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
{
PlainTriple<T,2> triple;
triple.amplify(preampTriples[iTriple], G);
if (machine.generateMACs)
amplifiedTriples[iTriple] = triple;
else
{
timers["Writing"].start();
triple.output(outputFile);
timers["Writing"].stop();
}
}
if (machine.generateMACs)
{
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
amplifiedTriples[iTriple].to(valueBits, iTriple);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
timers["Authentication OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
timers["Authentication OTs"].stop();
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
{
uncheckedTriples[iTriple].from(amplifiedTriples[iTriple], ot_multipliers, iTriple, *this);
if (!machine.check)
{
timers["Writing"].start();
amplifiedTriples[iTriple].output(outputFile);
timers["Writing"].stop();
}
}
if (machine.check)
{
vector< Share<T> > maskedAs(nTriplesPerLoop);
vector< ShareTriple<T,1> > maskedTriples(nTriplesPerLoop);
for (int j = 0; j < nTriplesPerLoop; j++)
{
maskedTriples[j].amplify(uncheckedTriples[j], G);
maskedAs[j] = maskedTriples[j].a[0];
}
vector<T> openedAs(nTriplesPerLoop);
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
MC.POpen_End(openedAs, maskedAs, globalPlayer);
for (int j = 0; j < nTriplesPerLoop; j++)
MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), int128(0), globalPlayer);
MC.Check(globalPlayer);
if (machine.generateBits)
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
else
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
uncheckedTriples[j].output(outputFile, 1);
}
}
}
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
}
}
template<>
void NPartyTripleGenerator::generateBitsFromTriples(
vector< ShareTriple<gfp,2> >& triples, MAC_Check<gfp>& MC, ofstream& outputFile)
{
vector< Share<gfp> > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop);
for (int i = 0; i < nTriplesPerLoop; i++)
a_plus_b[i] = triples[i].a[0] + triples[i].b;
vector<gfp> opened(nTriplesPerLoop);
MC.POpen_Begin(opened, a_plus_b, globalPlayer);
MC.POpen_End(opened, a_plus_b, globalPlayer);
for (int i = 0; i < nTriplesPerLoop; i++)
a_squared[i] = triples[i].a[0] * opened[i] - triples[i].c[0];
MC.POpen_Begin(opened, a_squared, globalPlayer);
MC.POpen_End(opened, a_squared, globalPlayer);
Share<gfp> one(gfp(1), globalPlayer.my_num(), MC.get_alphai());
for (int i = 0; i < nTriplesPerLoop; i++)
{
gfp root = opened[i].sqrRoot();
if (root.is_zero())
continue;
Share<gfp> bit = (triples[i].a[0] / root + one) / gfp(2);
if (machine.output)
bit.output(outputFile, false);
}
}
template<>
void NPartyTripleGenerator::generateBitsFromTriples(
vector< ShareTriple<gf2n,2> >& triples, MAC_Check<gf2n>& MC, ofstream& outputFile)
{
throw how_would_that_work();
// warning gymnastics
triples[0];
MC.number();
outputFile << "";
}
template <class T>
void NPartyTripleGenerator::start_progress(vector< OTMultiplier<T>* >& ot_multipliers)
{
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
lock();
signal();
wait();
gettimeofday(&last_lap, 0);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
}
void NPartyTripleGenerator::print_progress(int k)
{
if (thread_num == 0 && my_num == 0)
{
struct timeval stop;
gettimeofday(&stop, 0);
if (timeval_diff_in_seconds(&last_lap, &stop) > 1)
{
double diff = timeval_diff_in_seconds(&machine.start, &stop);
double throughput = k * nTriplesPerLoop * machine.nthreads / diff;
double remaining = diff * (nloops - k) / k;
cout << k << '/' << nloops << ", throughput: " << throughput
<< ", time left: " << remaining << ", elapsed: " << diff
<< ", estimated total: " << (diff + remaining) << endl;
last_lap = stop;
}
}
}
void NPartyTripleGenerator::lock()
{
pthread_mutex_lock(&mutex);
}
void NPartyTripleGenerator::unlock()
{
pthread_mutex_unlock(&mutex);
}
void NPartyTripleGenerator::signal()
{
pthread_cond_signal(&ready);
}
void NPartyTripleGenerator::wait()
{
pthread_cond_wait(&ready, &mutex);
}
template void NPartyTripleGenerator::generate<gf2n>();
template void NPartyTripleGenerator::generate<gfp>();

View File

@@ -0,0 +1,84 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef OT_NPARTYTRIPLEGENERATOR_H_
#define OT_NPARTYTRIPLEGENERATOR_H_
#include "Networking/Player.h"
#include "OT/BaseOT.h"
#include "Tools/random.h"
#include "Tools/time-func.h"
#include "Math/gfp.h"
#include "Auth/MAC_Check.h"
#include "OT/OTTripleSetup.h"
#include "OT/TripleMachine.h"
#include "OT/OTMultiplier.h"
#include <map>
#include <vector>
#define N_AMPLIFY 3
template <class T, int N>
class ShareTriple;
class NPartyTripleGenerator
{
//OTTripleSetup* setup;
Player globalPlayer;
int thread_num;
int my_num;
int nbase;
struct timeval last_lap;
pthread_mutex_t mutex;
pthread_cond_t ready;
template <class T>
void generateTriples(vector< OTMultiplier<T>* >& ot_multipliers, ofstream& outputFile);
template <class T>
void generateBits(vector< OTMultiplier<T>* >& ot_multipliers, ofstream& outputFile);
template <class T, int N>
void generateBitsFromTriples(vector<ShareTriple<T, N> >& triples,
MAC_Check<T>& MC, ofstream& outputFile);
template <class T>
void start_progress(vector< OTMultiplier<T>* >& ot_multipliers);
void print_progress(int k);
public:
// TwoPartyPlayer's for OTs, n-party Player for sacrificing
vector<TwoPartyPlayer*> players;
//vector<OTMachine*> machines;
BitVector baseReceiverInput; // same for every set of OTs
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
vector<BitVector> valueBits;
int nTriples;
int nTriplesPerLoop;
int nloops;
int field_size;
int nAmplify;
int nPreampTriplesPerLoop;
int repeat[3];
int nparties;
TripleMachine& machine;
map<string,Timer> timers;
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, TripleMachine& machine);
~NPartyTripleGenerator();
template <class T>
void generate();
void lock();
void unlock();
void signal();
void wait();
};
#endif

791
OT/OTExtension.cpp Normal file
View File

@@ -0,0 +1,791 @@
// (C) 2016 University of Bristol. See License.txt
#include "OTExtension.h"
#include "OT/Tools.h"
#include "Math/gf2n.h"
#include "Tools/aes.h"
#include "Tools/MMO.h"
#include <wmmintrin.h>
#include <emmintrin.h>
word TRANSPOSE_MASKS128[7][2] = {
{ 0x0000000000000000, 0xFFFFFFFFFFFFFFFF },
{ 0x00000000FFFFFFFF, 0x00000000FFFFFFFF },
{ 0x0000FFFF0000FFFF, 0x0000FFFF0000FFFF },
{ 0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF },
{ 0x0F0F0F0F0F0F0F0F, 0x0F0F0F0F0F0F0F0F },
{ 0x3333333333333333, 0x3333333333333333 },
{ 0x5555555555555555, 0x5555555555555555 }
};
string word_to_str(word a)
{
stringstream ss;
ss << hex;
for(int i = 0;i < 8; i++)
ss << ((a >> (i*8)) & 255) << " ";
return ss.str();
}
// Transpose 16x16 matrix starting at bv[x][y] in-place using SSE2
void sse_transpose16(vector<BitVector>& bv, int x, int y)
{
__m128i input[2];
// 16x16 in two halves, 128 bits each
for (int i = 0; i < 2; i++)
for (int j = 0; j < 16; j++)
((octet*)&input[i])[j] = bv[x+j].get_byte(y / 8 + i);
for (int i = 0; i < 2; i++)
for (int j = 0; j < 8; j++)
{
int output = _mm_movemask_epi8(input[i]);
input[i] = _mm_slli_epi64(input[i], 1);
for (int k = 0; k < 2; k++)
// _mm_movemask_epi8 uses most significant bit, hence +7-j
bv[x+8*i+7-j].set_byte(y / 8 + k, ((octet*)&output)[k]);
}
}
/*
* Transpose 128x128 bit-matrix using Eklundh's algorithm
*
* Input is in input[i] [ bits <offset> to <offset+127> ], i = 0, ..., 127
* Output is in output[i + offset] (entire 128-bit vector), i = 0, ..., 127.
*
* Transposes 128-bit vectors in little-endian format.
*/
//void eklundh_transpose64(vector<word>& output, const vector<Big_Keys>& input, int offset)
void eklundh_transpose128(vector<BitVector>& output, const vector<BitVector>& input,
int offset)
{
int width = 64;
int logn = 7, nswaps = 1;
#ifdef TRANSPOSE_DEBUG
stringstream input_ss[128];
stringstream output_ss[128];
#endif
// first copy input to output
for (int i = 0; i < 128; i++)
{
//output[i + offset*64] = input[i].get(offset);
output[i + offset].set_word(0, input[i].get_word(offset/64));
output[i + offset].set_word(1, input[i].get_word(offset/64 + 1));
#ifdef TRANSPOSE_DEBUG
for (int j = 0; j < 128; j++)
{
input_ss[j] << input[i].get_bit(offset + j);
}
#endif
}
// now transpose output in-place
for (int i = 0; i < logn; i++)
{
word mask1 = TRANSPOSE_MASKS128[i][1], mask2 = TRANSPOSE_MASKS128[i][0];
word inv_mask1 = ~mask1, inv_mask2 = ~mask2;
if (width == 8)
{
for (int j = 0; j < 8; j++)
for (int k = 0; k < 8; k++)
sse_transpose16(output, offset + 16 * j, 16 * k);
break;
}
else
// for width >= 64, shift is undefined so treat as a special case
// (and avoid branching in inner loop)
if (width < 64)
{
for (int j = 0; j < nswaps; j++)
{
for (int k = 0; k < width; k++)
{
int i1 = k + 2*width*j;
int i2 = k + width + 2*width*j;
// t1 is lower 64 bits, t2 is upper 64 bits
// (remember we're transposing in little-endian format)
word t1 = output[i1 + offset].get_word(0);
word t2 = output[i1 + offset].get_word(1);
word tt1 = output[i2 + offset].get_word(0);
word tt2 = output[i2 + offset].get_word(1);
// swap operations due to little endian-ness
output[i1 + offset].set_word(0, (t1 & mask1) ^
((tt1 & mask1) << width));
output[i1 + offset].set_word(1, (t2 & mask2) ^
((tt2 & mask2) << width) ^
((tt1 & mask1) >> (64 - width)));
output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^
((t1 & inv_mask1) >> width) ^
((t2 & inv_mask2)) << (64 - width));
output[i2 + offset].set_word(1, (tt2 & inv_mask2) ^
((t2 & inv_mask2) >> width));
}
}
}
else
{
for (int j = 0; j < nswaps; j++)
{
for (int k = 0; k < width; k++)
{
int i1 = k + 2*width*j;
int i2 = k + width + 2*width*j;
// t1 is lower 64 bits, t2 is upper 64 bits
// (remember we're transposing in little-endian format)
word t1 = output[i1 + offset].get_word(0);
word t2 = output[i1 + offset].get_word(1);
word tt1 = output[i2 + offset].get_word(0);
word tt2 = output[i2 + offset].get_word(1);
output[i1 + offset].set_word(0, (t1 & mask1));
output[i1 + offset].set_word(1, (t2 & mask2) ^
((tt1 & mask1) >> (64 - width)));
output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^
((t2 & inv_mask2)) << (64 - width));
output[i2 + offset].set_word(1, (tt2 & inv_mask2));
}
}
}
nswaps *= 2;
width /= 2;
}
#ifdef TRANSPOSE_DEBUG
for (int i = 0; i < 128; i++)
{
for (int j = 0; j < 128; j++)
{
output_ss[j] << output[offset + j].get_bit(i);
}
}
for (int i = 0; i < 128; i++)
{
if (output_ss[i].str().compare(input_ss[i].str()) != 0)
{
cerr << "String " << i << " failed. offset = " << offset << endl;
cerr << input_ss[i].str() << endl;
cerr << output_ss[i].str() << endl;
exit(1);
}
}
cout << "\ttranspose with offset " << offset << " ok\n";
#endif
}
// get bit, starting from MSB as bit 0
int get_bit(word x, int b)
{
return (x >> (63 - b)) & 1;
}
int get_bit128(word x1, word x2, int b)
{
if (b < 64)
{
return (x1 >> (b - 64)) & 1;
}
else
{
return (x2 >> b) & 1;
}
}
void naive_transpose128(vector<BitVector>& output, const vector<BitVector>& input,
int offset)
{
for (int i = 0; i < 128; i++)
{
// NB: words are read from input in big-endian format
word w1 = input[i].get_word(offset/64);
word w2 = input[i].get_word(offset/64 + 1);
for (int j = 0; j < 128; j++)
{
//output[j + offset].set_bit(i, input[i].get_bit(j + offset));
if (j < 64)
output[j + offset].set_bit(i, (w1 >> j) & 1);
else
output[j + offset].set_bit(i, w2 >> (j-64) & 1);
}
}
}
void transpose64(
vector<BitVector>::iterator& output_it,
vector<BitVector>::iterator& input_it)
{
for (int i = 0; i < 64; i++)
{
for (int j = 0; j < 64; j++)
{
(output_it + j)->set_bit(i, (input_it + i)->get_bit(j));
}
}
}
// Naive 64x64 bit matrix transpose
void naive_transpose64(vector<BitVector>& output, const vector<BitVector>& input,
int xoffset, int yoffset)
{
int word_size = 64;
for (int i = 0; i < word_size; i++)
{
word w = input[i + yoffset].get_word(xoffset);
for (int j = 0; j < word_size; j++)
{
//cout << j + xoffset*word_size << ", " << yoffset/word_size << endl;
//int wbit = (((w >> j) & 1) << i); cout << "wbit " << wbit << endl;
// set i-th bit of output to j-th bit of w
// scale yoffset by 64 since we're selecting words from the BitVector
word tmp = output[j + xoffset*word_size].get_word(yoffset/word_size);
output[j + xoffset*word_size].set_word(yoffset/word_size, tmp ^ ((w >> j) & 1) << i);
// set i-th bit of output to j-th bit of w
//output[j + offset*word_size] ^= ((w >> j) & 1) << i;
}
}
}
void OTExtension::transfer(int nOTs,
const BitVector& receiverInput)
{
#ifdef OTEXT_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl;
// add k + s to account for discarding k OTs
nOTs += 2 * 128;
if (nOTs % nbaseOTs != 0)
throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
if (nOTs == 0)
return;
vector<BitVector> t0(nbaseOTs, BitVector(nOTs)), tmp(nbaseOTs, BitVector(nOTs)), t1(nbaseOTs, BitVector(nOTs));
BitVector u(nOTs);
senderOutput.resize(2, vector<BitVector>(nOTs, BitVector(nbaseOTs)));
// resize to account for extra k OTs that are discarded
PRNG G;
G.ReSeed();
BitVector newReceiverInput(nOTs);
for (unsigned int i = 0; i < receiverInput.size_bytes(); i++)
{
newReceiverInput.set_byte(i, receiverInput.get_byte(i));
}
//BitVector newReceiverInput(receiverInput);
newReceiverInput.resize(nOTs);
receiverOutput.resize(nOTs, BitVector(nbaseOTs));
for (int loop = 0; loop < nloops; loop++)
{
vector<octetStream> os(2), tmp_os(2);
// randomize last 128 + 128 bits that will be discarded
for (int i = 0; i < 4; i++)
newReceiverInput.set_word(nOTs/64 - i, G.get_word());
// expand with PRG and create correlation
if (ot_role & RECEIVER)
{
for (int i = 0; i < nbaseOTs; i++)
{
t0[i].randomize(G_sender[i][0]);
t1[i].randomize(G_sender[i][1]);
tmp[i].assign(t1[i]);
tmp[i].add(t0[i]);
tmp[i].add(newReceiverInput);
tmp[i].pack(os[0]);
/*cout << "t0: " << t0[i].str() << endl;
cout << "t1: " << t1[i].str() << endl;
cout << "Sending tmp: " << tmp[i].str() << endl;*/
}
}
#ifdef OTEXT_TIMER
timeval commst1, commst2;
gettimeofday(&commst1, NULL);
#endif
// send t0 + t1 + x
send_if_ot_receiver(player, os, ot_role);
// sender adjusts using base receiver bits
if (ot_role & SENDER)
{
for (int i = 0; i < nbaseOTs; i++)
{
// randomize base receiver output
tmp[i].randomize(G_receiver[i]);
// u = t0 + t1 + x
u.unpack(os[1]);
if (baseReceiverInput.get_bit(i) == 1)
{
// now tmp is q[i] = t0[i] + Delta[i] * x
tmp[i].add(u);
}
}
}
#ifdef OTEXT_TIMER
gettimeofday(&commst2, NULL);
double commstime = timeval_diff(&commst1, &commst2);
cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush;
times["Communication"] += timeval_diff(&commst1, &commst2);
#endif
// transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0]
// stupid transpose
/*for (int j = 0; j < nOTs; j++)
{
for (int i = 0; i < nbaseOTs; i++)
{
senderOutput[0][j].set_bit(i, t0[i].get_bit(j));
receiverOutput[j].set_bit(i, tmp[i].get_bit(j));
}
}*/
cout << "Starting matrix transpose\n" << flush << endl;
#ifdef OTEXT_TIMER
timeval transt1, transt2;
gettimeofday(&transt1, NULL);
#endif
// transpose in 128-bit chunks with Eklundh's algorithm
for (int i = 0; i < nOTs / 128; i++)
{
if (ot_role & RECEIVER)
{
eklundh_transpose128(receiverOutput, t0, i*128);
//naive_transpose128(receiverOutput, t0, i*128);
}
if (ot_role & SENDER)
{
eklundh_transpose128(senderOutput[0], tmp, i*128);
//naive_transpose128(senderOutput[0], tmp, i*128);
}
}
#ifdef OTEXT_TIMER
gettimeofday(&transt2, NULL);
double transtime = timeval_diff(&transt1, &transt2);
cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush;
times["Matrix transpose"] += timeval_diff(&transt1, &transt2);
#endif
#ifdef OTEXT_DEBUG
// verify correctness of the OTs
// i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i]
// (where Delta = baseReceiverOutput)
BitVector tmp_vector1(nbaseOTs), tmp_vector2(nOTs);//nbaseOTs);
cout << "\tVerifying OT extensions (debugging)\n";
for (int i = 0; i < nOTs; i++)
{
os[0].reset_write_head();
os[1].reset_write_head();
if (ot_role & RECEIVER)
{
// send t0 and x over
receiverOutput[i].pack(os[0]);
//t0[i].pack(os[0]);
newReceiverInput.pack(os[0]);
}
send_if_ot_receiver(player, os, ot_role);
if (ot_role & SENDER)
{
tmp_vector1.unpack(os[1]);
tmp_vector2.unpack(os[1]);
// if x_i = 1, add Delta
if (tmp_vector2.get_bit(i) == 1)
{
tmp_vector1.add(baseReceiverInput);
}
if (!tmp_vector1.equals(senderOutput[0][i]))
{
cerr << "Incorrect OT at " << i << "\n";
exit(1);
}
}
}
cout << "Correlated OTs all OK\n";
#endif
#ifdef OTEXT_TIMER
double elapsed;
#endif
// correlation check
if (!passive_only)
{
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
check_correlation(nOTs, newReceiverInput);
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
elapsed = timeval_diff(&startv, &endv);
cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush;
times["Total correlation check"] += timeval_diff(&startv, &endv);
#endif
}
hash_outputs(nOTs, receiverOutput);
#ifdef OTEXT_TIMER
gettimeofday(&totalendv, NULL);
elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush;
#endif
#ifdef OTEXT_DEBUG
// verify correctness of the random OTs
// i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i]
// (where Delta = baseReceiverOutput)
cout << "Verifying random OTs (debugging)\n";
for (int i = 0; i < nOTs; i++)
{
os[0].reset_write_head();
os[1].reset_write_head();
if (ot_role & RECEIVER)
{
// send receiver's input/output over
receiverOutput[i].pack(os[0]);
newReceiverInput.pack(os[0]);
}
send_if_ot_receiver(player, os, ot_role);
//player->send_receive_player(os);
if (ot_role & SENDER)
{
tmp_vector1.unpack(os[1]);
tmp_vector2.unpack(os[1]);
// if x_i = 1, comp with sender output[1]
if ((tmp_vector2.get_bit(i) == 1))
{
if (!tmp_vector1.equals(senderOutput[1][i]))
{
cerr << "Incorrect OT\n";
exit(1);
}
}
// else should be sender output[0]
else if (!tmp_vector1.equals(senderOutput[0][i]))
{
cerr << "Incorrect OT\n";
exit(1);
}
}
}
cout << "Random OTs all OK\n";
#endif
}
#ifdef OTEXT_TIMER
gettimeofday(&totalendv, NULL);
times["Total thread"] += timeval_diff(&totalstartv, &totalendv);
#endif
receiverOutput.resize(nOTs - 2 * 128);
senderOutput[0].resize(nOTs - 2 * 128);
senderOutput[1].resize(nOTs - 2 * 128);
}
/*
* Hash outputs to make into random OT
*/
void OTExtension::hash_outputs(int nOTs, vector<BitVector>& receiverOutput)
{
cout << "Hashing... " << flush;
octetStream os, h_os(HASH_SIZE);
BitVector tmp(nbaseOTs);
MMO mmo;
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
for (int i = 0; i < nOTs; i++)
{
if (ot_role & SENDER)
{
tmp.add(senderOutput[0][i], baseReceiverInput);
if (senderOutput[0][i].size() == 128)
{
mmo.hashOneBlock<gf2n>(senderOutput[0][i].get_ptr(), senderOutput[0][i].get_ptr());
mmo.hashOneBlock<gf2n>(senderOutput[1][i].get_ptr(), tmp.get_ptr());
}
else
{
os.reset_write_head();
h_os.reset_write_head();
senderOutput[0][i].pack(os);
os.hash(h_os);
senderOutput[0][i].unpack(h_os);
os.reset_write_head();
h_os.reset_write_head();
tmp.pack(os);
os.hash(h_os);
senderOutput[1][i].unpack(h_os);
}
}
if (ot_role & RECEIVER)
{
if (receiverOutput[i].size() == 128)
mmo.hashOneBlock<gf2n>(receiverOutput[i].get_ptr(), receiverOutput[i].get_ptr());
else
{
os.reset_write_head();
h_os.reset_write_head();
receiverOutput[i].pack(os);
os.hash(h_os);
receiverOutput[i].unpack(h_os);
}
}
}
cout << "done.\n";
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
double elapsed = timeval_diff(&startv, &endv);
cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush;
times["Hashing"] += timeval_diff(&startv, &endv);
#endif
}
// test if a == b
int eq_m128i(__m128i a, __m128i b)
{
__m128i vcmp = _mm_cmpeq_epi8(a, b);
uint16_t vmask = _mm_movemask_epi8(vcmp);
return (vmask == 0xffff);
}
void random_m128i(PRNG& G, __m128i *r)
{
BitVector rv(128);
rv.randomize(G);
*r = _mm_load_si128((__m128i*)&(rv.get_ptr()[0]));
}
void test_mul()
{
cout << "Testing GF(2^128) multiplication\n";
__m128i t1, t2, t3, t4, t5, t6, t7, t8;
PRNG G;
G.ReSeed();
BitVector r(128);
for (int i = 0; i < 1000; i++)
{
random_m128i(G, &t1);
random_m128i(G, &t2);
// test commutativity
gfmul128(t1, t2, &t3);
gfmul128(t2, t1, &t4);
if (!eq_m128i(t3, t4))
{
cerr << "Incorrect multiplication:\n";
cerr << "t1 * t2 = " << __m128i_toString<octet>(t3) << endl;
cerr << "t2 * t1 = " << __m128i_toString<octet>(t4) << endl;
}
// test distributivity: t1*t3 + t2*t3 = (t1 + t2) * t3
random_m128i(G, &t1);
random_m128i(G, &t2);
random_m128i(G, &t3);
gfmul128(t1, t3, &t4);
gfmul128(t2, t3, &t5);
t6 = _mm_xor_si128(t4, t5);
t7 = _mm_xor_si128(t1, t2);
gfmul128(t7, t3, &t8);
if (!eq_m128i(t6, t8))
{
cerr << "Incorrect multiplication:\n";
cerr << "t1 * t3 + t2 * t3 = " << __m128i_toString<octet>(t6) << endl;
cerr << "(t1 + t2) * t3 = " << __m128i_toString<octet>(t8) << endl;
}
}
t1 = _mm_set_epi32(0, 0, 0, 03);
t2 = _mm_set_epi32(0, 0, 0, 11);
//gfmul128(t1, t2, &t3);
mul128(t1, t2, &t3, &t4);
cout << "t1 = " << __m128i_toString<octet>(t1) << endl;
cout << "t2 = " << __m128i_toString<octet>(t2) << endl;
cout << "t3 = " << __m128i_toString<octet>(t3) << endl;
cout << "t4 = " << __m128i_toString<octet>(t4) << endl;
uint64_t cc[] __attribute__((aligned (16))) = { 0,0 };
_mm_store_si128((__m128i*)cc, t1);
word t1w = cc[0];
_mm_store_si128((__m128i*)cc, t2);
word t2w = cc[0];
cout << "t1w = " << t1w << endl;
cout << "t1 = " << word_to_bytes(t1w) << endl;
cout << "t2 = " << word_to_bytes(t2w) << endl;
cout << "t1 * t2 = " << word_to_bytes(t1w*t2w) << endl;
}
void OTExtension::check_correlation(int nOTs,
const BitVector& receiverInput)
{
//cout << "\tStarting correlation check\n" << flush;
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
if (nbaseOTs != 128)
{
cerr << "Correlation check not implemented for length != 128\n";
throw not_implemented();
}
PRNG G;
octet* seed = new octet[SEED_SIZE];
random_seed_commit(seed, *player, SEED_SIZE);
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
double elapsed = timeval_diff(&startv, &endv);
cout << "\t\tCommitment for seed took time " << elapsed/1000000 << endl << flush;
times["Commitment for seed"] += timeval_diff(&startv, &endv);
gettimeofday(&startv, NULL);
#endif
G.SetSeed(seed);
delete[] seed;
vector<octetStream> os(2);
if (!Check_CPU_support_AES())
{
cerr << "Not implemented GF(2^128) multiplication in C\n";
throw not_implemented();
}
__m128i Delta, x128i;
Delta = _mm_load_si128((__m128i*)&(baseReceiverInput.get_ptr()[0]));
BitVector chi(nbaseOTs);
BitVector x(nbaseOTs);
__m128i t = _mm_setzero_si128();
__m128i q = _mm_setzero_si128();
__m128i t2 = _mm_setzero_si128();
__m128i q2 = _mm_setzero_si128();
__m128i chii, ti, qi, ti2, qi2;
x128i = _mm_setzero_si128();
for (int i = 0; i < nOTs; i++)
{
// chi.randomize(G);
// chii = _mm_load_si128((__m128i*)&(chi.get_ptr()[0]));
chii = G.get_doubleword();
if (ot_role & RECEIVER)
{
if (receiverInput.get_bit(i) == 1)
{
x128i = _mm_xor_si128(x128i, chii);
}
ti = _mm_loadu_si128((__m128i*)get_receiver_output(i));
// multiply over polynomial ring to avoid reduction
mul128(ti, chii, &ti, &ti2);
t = _mm_xor_si128(t, ti);
t2 = _mm_xor_si128(t2, ti2);
}
if (ot_role & SENDER)
{
qi = _mm_loadu_si128((__m128i*)(get_sender_output(0, i)));
mul128(qi, chii, &qi, &qi2);
q = _mm_xor_si128(q, qi);
q2 = _mm_xor_si128(q2, qi2);
}
}
#ifdef OTEXT_DEBUG
if (ot_role & RECEIVER)
{
cout << "\tSending x,t\n";
cout << "\tsend x = " << __m128i_toString<octet>(x128i) << endl;
cout << "\tsend t = " << __m128i_toString<octet>(t) << endl;
cout << "\tsend t2 = " << __m128i_toString<octet>(t2) << endl;
}
#endif
check_iteration(Delta, q, q2, t, t2, x128i);
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
elapsed = timeval_diff(&startv, &endv);
cout << "\t\tChecking correlation took time " << elapsed/1000000 << endl << flush;
times["Checking correlation"] += timeval_diff(&startv, &endv);
#endif
}
void OTExtension::check_iteration(__m128i delta, __m128i q, __m128i q2,
__m128i t, __m128i t2, __m128i x)
{
vector<octetStream> os(2);
// send x, t;
__m128i received_t, received_t2, received_x, tmp1, tmp2;
if (ot_role & RECEIVER)
{
os[0].append((octet*)&x, sizeof(x));
os[0].append((octet*)&t, sizeof(t));
os[0].append((octet*)&t2, sizeof(t2));
}
send_if_ot_receiver(player, os, ot_role);
if (ot_role & SENDER)
{
os[1].consume((octet*)&received_x, sizeof(received_x));
os[1].consume((octet*)&received_t, sizeof(received_t));
os[1].consume((octet*)&received_t2, sizeof(received_t2));
// check t = x * Delta + q
//gfmul128(received_x, delta, &tmp1);
mul128(received_x, delta, &tmp1, &tmp2);
tmp1 = _mm_xor_si128(tmp1, q);
tmp2 = _mm_xor_si128(tmp2, q2);
if (eq_m128i(tmp1, received_t) && eq_m128i(tmp2, received_t2))
{
//cout << "\tCheck passed\n";
}
else
{
cerr << "Correlation check failed\n";
cout << "rec t = " << __m128i_toString<octet>(received_t) << endl;
cout << "tmp1 = " << __m128i_toString<octet>(tmp1) << endl;
cout << "q = " << __m128i_toString<octet>(q) << endl;
exit(1);
}
}
}
octet* OTExtension::get_receiver_output(int i)
{
return receiverOutput[i].get_ptr();
}
octet* OTExtension::get_sender_output(int choice, int i)
{
return senderOutput[choice][i].get_ptr();
}

122
OT/OTExtension.h Normal file
View File

@@ -0,0 +1,122 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _OTEXTENSION
#define _OTEXTENSION
#include "OT/BaseOT.h"
#include "Exceptions/Exceptions.h"
#include "Networking/Player.h"
#include "Tools/time-func.h"
#include <stdlib.h>
#include <assert.h>
#include <sstream>
#include <fstream>
#include <iostream>
#include <map>
using namespace std;
//#define OTEXT_TIMER
//#define OTEXT_DEBUG
class OTExtension
{
public:
BitVector baseReceiverInput;
vector< vector<BitVector> > senderOutput;
vector<BitVector> receiverOutput;
map<string,long long> times;
OTExtension(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
TwoPartyPlayer* player,
BitVector& baseReceiverInput,
vector< vector<BitVector> >& baseSenderInput,
vector<BitVector>& baseReceiverOutput,
OT_ROLE role=BOTH,
bool passive=false)
: baseReceiverInput(baseReceiverInput), passive_only(passive), nbaseOTs(nbaseOTs),
baseLength(baseLength), nloops(nloops), nsubloops(nsubloops), ot_role(role), player(player)
{
G_sender.resize(nbaseOTs, vector<PRNG>(2));
G_receiver.resize(nbaseOTs);
// set up PRGs for expanding the seed OTs
for (int i = 0; i < nbaseOTs; i++)
{
assert(baseSenderInput[i][0].size_bytes() >= AES_BLK_SIZE);
assert(baseSenderInput[i][1].size_bytes() >= AES_BLK_SIZE);
assert(baseReceiverOutput[i].size_bytes() >= AES_BLK_SIZE);
if (ot_role & RECEIVER)
{
G_sender[i][0].SetSeed(baseSenderInput[i][0].get_ptr());
G_sender[i][1].SetSeed(baseSenderInput[i][1].get_ptr());
}
if (ot_role & SENDER)
{
G_receiver[i].SetSeed(baseReceiverOutput[i].get_ptr());
}
#ifdef OTEXT_DEBUG
// sanity check for base OTs
vector<octetStream> os(2);
BitVector t0(128);
if (ot_role & RECEIVER)
{
// send both inputs to test
baseSenderInput[i][0].pack(os[0]);
baseSenderInput[i][1].pack(os[0]);
}
send_if_ot_receiver(player, os, ot_role);
if (ot_role & SENDER)
{
// sender checks results
t0.unpack(os[1]);
if (baseReceiverInput.get_bit(i) == 1)
t0.unpack(os[1]);
if (!t0.equals(baseReceiverOutput[i]))
{
cerr << "Incorrect base OT\n";
exit(1);
}
}
os[0].reset_write_head();
os[1].reset_write_head();
#endif
}
}
virtual ~OTExtension() {}
virtual void transfer(int nOTs, const BitVector& receiverInput);
virtual octet* get_receiver_output(int i);
virtual octet* get_sender_output(int choice, int i);
protected:
bool passive_only;
int nbaseOTs, baseLength, nloops, nsubloops;
OT_ROLE ot_role;
TwoPartyPlayer* player;
vector< vector<PRNG> > G_sender;
vector<PRNG> G_receiver;
void check_correlation(int nOTs,
const BitVector& receiverInput);
void check_iteration(__m128i delta, __m128i q, __m128i q2,
__m128i t, __m128i t2, __m128i x);
void hash_outputs(int nOTs, vector<BitVector>& receiverOutput);
};
#endif

View File

@@ -0,0 +1,466 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OTExtensionWithMatrix.cpp
*
*/
#include "OTExtensionWithMatrix.h"
#include "Math/gfp.h"
void OTExtensionWithMatrix::seed(vector<BitMatrix>& baseSenderInput,
BitMatrix& baseReceiverOutput)
{
nbaseOTs = baseReceiverInput.size();
//cout << "nbaseOTs " << nbaseOTs << endl;
G_sender.resize(nbaseOTs, vector<PRNG>(2));
G_receiver.resize(nbaseOTs);
// set up PRGs for expanding the seed OTs
for (int i = 0; i < nbaseOTs; i++)
{
if (ot_role & RECEIVER)
{
G_sender[i][0].SetSeed((octet*)&baseSenderInput[0].squares[i/128].rows[i%128]);
G_sender[i][1].SetSeed((octet*)&baseSenderInput[1].squares[i/128].rows[i%128]);
}
if (ot_role & SENDER)
{
G_receiver[i].SetSeed((octet*)&baseReceiverOutput.squares[i/128].rows[i%128]);
}
}
}
void OTExtensionWithMatrix::transfer(int nOTs,
const BitVector& receiverInput)
{
#ifdef OTEXT_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl;
// resize to account for extra k OTs that are discarded
BitVector newReceiverInput(nOTs);
for (unsigned int i = 0; i < receiverInput.size_bytes(); i++)
{
newReceiverInput.set_byte(i, receiverInput.get_byte(i));
}
for (int loop = 0; loop < nloops; loop++)
{
extend<gf2n>(nOTs, newReceiverInput);
#ifdef OTEXT_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush;
#endif
}
#ifdef OTEXT_TIMER
gettimeofday(&totalendv, NULL);
times["Total thread"] += timeval_diff(&totalstartv, &totalendv);
#endif
}
void OTExtensionWithMatrix::resize(int nOTs)
{
t1.resize(nOTs);
u.resize(nOTs);
senderOutputMatrices.resize(2);
for (int i = 0; i < 2; i++)
senderOutputMatrices[i].resize(nOTs);
receiverOutputMatrix.resize(nOTs);
}
// the template is used to denote the field of the hash output
template <class T>
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
{
// if (nOTs % nbaseOTs != 0)
// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
if (nOTs_requested == 0)
return;
// add k + s to account for discarding k OTs
int nOTs = nOTs_requested + 2 * 128;
int slice = nOTs / nsubloops / 128;
nOTs = slice * nsubloops * 128;
resize(nOTs);
newReceiverInput.resize(nOTs);
// randomize last 128 + 128 bits that will be discarded
for (int i = 0; i < 4; i++)
newReceiverInput.set_word(nOTs/64 - i - 1, G.get_word());
// subloop for first part to interleave communication with computation
for (int start = 0; start < nOTs / 128; start += slice)
{
expand<gf2n>(start, slice);
correlate<gf2n>(start, slice, newReceiverInput, true);
transpose(start, slice);
}
#ifdef OTEXT_TIMER
double elapsed;
#endif
// correlation check
if (!passive_only)
{
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
check_correlation(nOTs, newReceiverInput);
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
elapsed = timeval_diff(&startv, &endv);
cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush;
times["Total correlation check"] += timeval_diff(&startv, &endv);
#endif
}
hash_outputs<T>(nOTs);
receiverOutputMatrix.resize(nOTs_requested);
senderOutputMatrices[0].resize(nOTs_requested);
senderOutputMatrices[1].resize(nOTs_requested);
newReceiverInput.resize(nOTs_requested);
}
template <class T>
void OTExtensionWithMatrix::expand(int start, int slice)
{
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
BitMatrixSlice senderOutputSlices[2] = {
BitMatrixSlice(senderOutputMatrices[0], start, slice),
BitMatrixSlice(senderOutputMatrices[1], start, slice)
};
BitMatrixSlice t1Slice(t1, start, slice);
// expand with PRG
if (ot_role & RECEIVER)
{
for (int i = 0; i < nbaseOTs; i++)
{
receiverOutputSlice.randomize<T>(i, G_sender[i][0]);
t1Slice.randomize<T>(i, G_sender[i][1]);
}
}
if (ot_role & SENDER)
{
for (int i = 0; i < nbaseOTs; i++)
// randomize base receiver output
senderOutputSlices[0].randomize<T>(i, G_receiver[i]);
}
}
template <class T>
void OTExtensionWithMatrix::expand_transposed()
{
for (int i = 0; i < nbaseOTs; i++)
{
if (ot_role & RECEIVER)
{
receiverOutputMatrix.squares[i/128].randomize<T>(i % 128, G_sender[i][0]);
t1.squares[i/128].randomize<T>(i % 128, G_sender[i][1]);
}
if (ot_role & SENDER)
{
senderOutputMatrices[0].squares[i/128].randomize<T>(i % 128, G_receiver[i]);
}
}
}
void OTExtensionWithMatrix::setup_for_correlation(vector<BitMatrix>& baseSenderOutputs, BitMatrix& baseReceiverOutput)
{
receiverOutputMatrix = baseSenderOutputs[0];
t1 = baseSenderOutputs[1];
u.resize(t1.size());
senderOutputMatrices.resize(2);
senderOutputMatrices[0] = baseReceiverOutput;
}
template <class T>
void OTExtensionWithMatrix::correlate(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat)
{
vector<octetStream> os(2);
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
BitMatrixSlice senderOutputSlices[2] = {
BitMatrixSlice(senderOutputMatrices[0], start, slice),
BitMatrixSlice(senderOutputMatrices[1], start, slice)
};
BitMatrixSlice t1Slice(t1, start, slice);
BitMatrixSlice uSlice(u, start, slice);
// create correlation
if (ot_role & RECEIVER)
{
t1Slice.rsub<T>(receiverOutputSlice);
t1Slice.add<T>(newReceiverInput, repeat);
t1Slice.pack(os[0]);
// t1 = receiverOutputMatrix;
// t1 ^= newReceiverInput;
// receiverOutputMatrix.print_side_by_side(t1);
}
#ifdef OTEXT_TIMER
timeval commst1, commst2;
gettimeofday(&commst1, NULL);
#endif
// send t0 + t1 + x
send_if_ot_receiver(player, os, ot_role);
// sender adjusts using base receiver bits
if (ot_role & SENDER)
{
// u = t0 + t1 + x
uSlice.unpack(os[1]);
senderOutputSlices[0].conditional_add<T>(baseReceiverInput, u, !useConstantBase);
}
#ifdef OTEXT_TIMER
gettimeofday(&commst2, NULL);
double commstime = timeval_diff(&commst1, &commst2);
cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush;
times["Communication"] += timeval_diff(&commst1, &commst2);
#endif
}
void OTExtensionWithMatrix::transpose(int start, int slice)
{
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
BitMatrixSlice senderOutputSlices[2] = {
BitMatrixSlice(senderOutputMatrices[0], start, slice),
BitMatrixSlice(senderOutputMatrices[1], start, slice)
};
// transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0]
//cout << "Starting matrix transpose\n" << flush << endl;
#ifdef OTEXT_TIMER
timeval transt1, transt2;
gettimeofday(&transt1, NULL);
#endif
// transpose in 128-bit chunks
if (ot_role & RECEIVER)
receiverOutputSlice.transpose();
if (ot_role & SENDER)
senderOutputSlices[0].transpose();
#ifdef OTEXT_TIMER
gettimeofday(&transt2, NULL);
double transtime = timeval_diff(&transt1, &transt2);
cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush;
times["Matrix transpose"] += timeval_diff(&transt1, &transt2);
#endif
}
/*
* Hash outputs to make into random OT
*/
template <class T>
void OTExtensionWithMatrix::hash_outputs(int nOTs)
{
//cout << "Hashing... " << flush;
octetStream os, h_os(HASH_SIZE);
square128 tmp;
MMO mmo;
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
for (int i = 0; i < nOTs / 128; i++)
{
if (ot_role & SENDER)
{
tmp = senderOutputMatrices[0].squares[i];
tmp ^= baseReceiverInput;
senderOutputMatrices[0].squares[i].hash_row_wise<T>(mmo, senderOutputMatrices[0].squares[i]);
senderOutputMatrices[1].squares[i].hash_row_wise<T>(mmo, tmp);
}
if (ot_role & RECEIVER)
{
receiverOutputMatrix.squares[i].hash_row_wise<T>(mmo, receiverOutputMatrix.squares[i]);
}
}
//cout << "done.\n";
#ifdef OTEXT_TIMER
gettimeofday(&endv, NULL);
double elapsed = timeval_diff(&startv, &endv);
cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush;
times["Hashing"] += timeval_diff(&startv, &endv);
#endif
}
template <class T>
void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples, vector<T>& output)
{
if (receiverOutputMatrix.squares.size() < nTriples)
throw invalid_length();
output.resize(nTriples);
for (unsigned int j = 0; j < nTriples; j++)
{
T c1, c2;
receiverOutputMatrix.squares[j].to(c1);
senderOutputMatrices[0].squares[j].to(c2);
output[j] = c1 - c2;
}
}
octet* OTExtensionWithMatrix::get_receiver_output(int i)
{
return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128];
}
octet* OTExtensionWithMatrix::get_sender_output(int choice, int i)
{
return (octet*)&senderOutputMatrices[choice].squares[i/128].rows[i%128];
}
void OTExtensionWithMatrix::print(BitVector& newReceiverInput, int i)
{
if (player->my_num() == 0)
{
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix, i);
print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]);
}
else
{
print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]);
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix, i);
}
}
template <class T>
void OTExtensionWithMatrix::print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int k, int offset)
{
if (ot_role & RECEIVER)
{
for (int i = 0; i < 16; i++)
{
if (newReceiverInput.get_bit((offset + k) * 128 + i))
{
for (int j = 0; j < 33; j++)
cout << " ";
cout << T(matrix.squares[k].rows[i]);
}
else
cout << int128(matrix.squares[k].rows[i]);
cout << endl;
}
cout << endl;
}
}
void OTExtensionWithMatrix::print_sender(square128& square0, square128& square1)
{
if (ot_role & SENDER)
{
for (int i = 0; i < 16; i++)
{
cout << int128(square0.rows[i]) << " ";
cout << int128(square1.rows[i]) << " ";
cout << endl;
}
cout << endl;
}
}
template <class T>
void OTExtensionWithMatrix::print_post_correlate(BitVector& newReceiverInput, int j, int offset, int sender)
{
cout << "post correlate, sender" << sender << endl;
if (player->my_num() == sender)
{
T delta = newReceiverInput.get_int128(offset + j);
for (int i = 0; i < 16; i++)
{
cout << (int128(receiverOutputMatrix.squares[j].rows[i]));
cout << " ";
cout << (T(receiverOutputMatrix.squares[j].rows[i]) - delta);
cout << endl;
}
cout << endl;
}
else
{
print_receiver<T>(baseReceiverInput, senderOutputMatrices[0], j);
}
}
void OTExtensionWithMatrix::print_pre_correlate(int i)
{
cout << "pre correlate" << endl;
if (player->my_num() == 0)
print_sender(receiverOutputMatrix.squares[i], t1.squares[i]);
else
print_receiver<gf2n>(baseReceiverInput, senderOutputMatrices[0], i);
}
void OTExtensionWithMatrix::print_post_transpose(BitVector& newReceiverInput, int i, int sender)
{
cout << "post transpose, sender " << sender << endl;
if (player->my_num() == sender)
{
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix);
}
else
{
square128 tmp = senderOutputMatrices[0].squares[i];
tmp ^= baseReceiverInput;
print_sender(senderOutputMatrices[0].squares[i], tmp);
}
}
void OTExtensionWithMatrix::print_pre_expand()
{
cout << "pre expand" << endl;
if (player->my_num() == 0)
{
for (int i = 0; i < 16; i++)
{
for (int j = 0; j < 2; j++)
cout << int128(_mm_loadu_si128((__m128i*)G_sender[i][j].get_seed())) << " ";
cout << endl;
}
cout << endl;
}
else
{
for (int i = 0; i < 16; i++)
{
if (baseReceiverInput.get_bit(i))
{
for (int j = 0; j < 33; j++)
cout << " ";
}
cout << int128(_mm_loadu_si128((__m128i*)G_receiver[i].get_seed())) << endl;
}
cout << endl;
}
}
template void OTExtensionWithMatrix::correlate<gf2n>(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat);
template void OTExtensionWithMatrix::correlate<gfp>(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat);
template void OTExtensionWithMatrix::print_post_correlate<gf2n>(
BitVector& newReceiverInput, int j, int offset, int sender);
template void OTExtensionWithMatrix::print_post_correlate<gfp>(
BitVector& newReceiverInput, int j, int offset, int sender);
template void OTExtensionWithMatrix::extend<gf2n>(int nOTs_requested,
BitVector& newReceiverInput);
template void OTExtensionWithMatrix::extend<gfp>(int nOTs_requested,
BitVector& newReceiverInput);
template void OTExtensionWithMatrix::expand<gf2n>(int start, int slice);
template void OTExtensionWithMatrix::expand<gfp>(int start, int slice);
template void OTExtensionWithMatrix::expand_transposed<gf2n>();
template void OTExtensionWithMatrix::expand_transposed<gfp>();
template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples,
vector<gf2n>& output);
template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples,
vector<gfp>& output);

View File

@@ -0,0 +1,71 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OTExtensionWithMatrix.h
*
*/
#ifndef OT_OTEXTENSIONWITHMATRIX_H_
#define OT_OTEXTENSIONWITHMATRIX_H_
#include "OTExtension.h"
#include "BitMatrix.h"
#include "Math/gf2n.h"
class OTExtensionWithMatrix : public OTExtension
{
public:
vector<BitMatrix> senderOutputMatrices;
BitMatrix receiverOutputMatrix;
BitMatrix t1, u;
PRNG G;
OTExtensionWithMatrix(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
TwoPartyPlayer* player,
BitVector& baseReceiverInput,
vector< vector<BitVector> >& baseSenderInput,
vector<BitVector>& baseReceiverOutput,
OT_ROLE role=BOTH,
bool passive=false)
: OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
baseSenderInput, baseReceiverOutput, role, passive) {
G.ReSeed();
}
void seed(vector<BitMatrix>& baseSenderInput,
BitMatrix& baseReceiverOutput);
void transfer(int nOTs, const BitVector& receiverInput);
void resize(int nOTs);
template <class T>
void extend(int nOTs, BitVector& newReceiverInput);
template <class T>
void expand(int start, int slice);
template <class T>
void expand_transposed();
template <class T>
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
void transpose(int start, int slice);
void setup_for_correlation(vector<BitMatrix>& baseSenderOutputs, BitMatrix& baseReceiverOutput);
template <class T>
void reduce_squares(unsigned int nTriples, vector<T>& output);
void print(BitVector& newReceiverInput, int i = 0);
template <class T>
void print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int i = 0, int offset = 0);
void print_sender(square128& square0, square128& square);
template <class T>
void print_post_correlate(BitVector& newReceiverInput, int i = 0, int offset = 0, int sender = 0);
void print_pre_correlate(int i = 0);
void print_post_transpose(BitVector& newReceiverInput, int i = 0, int sender = 0);
void print_pre_expand();
octet* get_receiver_output(int i);
octet* get_sender_output(int choice, int i);
protected:
template <class T>
void hash_outputs(int nOTs);
};
#endif /* OT_OTEXTENSIONWITHMATRIX_H_ */

401
OT/OTMachine.cpp Normal file
View File

@@ -0,0 +1,401 @@
// (C) 2016 University of Bristol. See License.txt
#include "Networking/Player.h"
#include "OT/OTExtension.h"
#include "OT/OTExtensionWithMatrix.h"
#include "Exceptions/Exceptions.h"
#include "Tools/time-func.h"
#include <stdlib.h>
#include <sstream>
#include <fstream>
#include <iostream>
#include <sys/time.h>
#include "OutputCheck.h"
#include "OTMachine.h"
//#define BASE_OT_DEBUG
class OT_thread_info
{
public:
int thread_num;
bool stop;
int other_player_num;
OTExtension* ot_ext;
int nOTs, nbase;
BitVector receiverInput;
};
void* run_otext_thread(void* ptr)
{
OT_thread_info *tinfo = (OT_thread_info*) ptr;
//int num = tinfo->thread_num;
//int other_player_num = tinfo->other_player_num;
printf("\tI am in thread %d\n", tinfo->thread_num);
tinfo->ot_ext->transfer(tinfo->nOTs, tinfo->receiverInput);
return NULL;
}
OTMachine::OTMachine(int argc, const char** argv)
{
opt.add(
"", // Default.
1, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number, 0/1 (required).", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.add(
"5000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Base port number (default: 5000).", // Help description.
"-pn", // Flag token.
"--portnum" // Flag token.
);
opt.add(
"localhost", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Host name(s) that player 0 is running on (default: localhost). Split with commas.", // Help description.
"-h", // Flag token.
"--hostname" // Flag token.
);
opt.add(
"1024",
0,
1,
0,
"Number of extended OTs to run (default: 1024).",
"-n",
"--nOTs"
);
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of base OTs to run (default: 128).", // Help description.
"-b", // Flag token.
"--nbase" // Flag token.
);
opt.add(
"s",
0,
1,
0,
"Mode for OT. a (asymmetric) or s (symmetric, i.e. play both sender/receiver) (default: s).",
"-m",
"--mode"
);
opt.add(
"1",
0,
1,
0,
"Number of threads (default: 1).",
"-x",
"--nthreads"
);
opt.add(
"1",
0,
1,
0,
"Number of loops (default: 1).",
"-l",
"--nloops"
);
opt.add(
"1",
0,
1,
0,
"Number of subloops (default: 1).",
"-s",
"--nsubloops"
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Run in passive security mode.", // Help description.
"-pas", // Flag token.
"--passive" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Write results to files.", // Help description.
"-o", // Flag token.
"--output" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Real base OT.", // Help description.
"-r", // Flag token.
"--real" // Flag token.
);
opt.parse(argc, argv);
string hostname, ot_mode, usage;
passive = false;
opt.get("-p")->getInt(my_num);
opt.get("-pn")->getInt(portnum_base);
opt.get("-h")->getString(hostname);
opt.get("-n")->getLong(nOTs);
opt.get("-m")->getString(ot_mode);
opt.get("--nthreads")->getInt(nthreads);
opt.get("--nloops")->getInt(nloops);
opt.get("--nsubloops")->getInt(nsubloops);
opt.get("--nbase")->getInt(nbase);
if (opt.isSet("-pas"))
passive = true;
if (!opt.isSet("-p"))
{
opt.getUsage(usage);
cout << usage;
exit(0);
}
cout << "Player 0 host name = " << hostname << endl;
cout << "Creating " << nOTs << " extended OTs in " << nthreads << " threads\n";
cout << "Running in mode " << ot_mode << endl;
if (passive)
cout << "Running with PASSIVE security only\n";
if (nbase < 128)
cout << "WARNING: only using " << nbase << " seed OTs, using these for OT extensions is insecure.\n";
if (ot_mode.compare("s") == 0)
ot_role = BOTH;
else if (ot_mode.compare("a") == 0)
{
if (my_num == 0)
ot_role = SENDER;
else
ot_role = RECEIVER;
}
else
{
cerr << "Invalid OT mode argument: " << ot_mode << endl;
exit(1);
}
// Several names for multiplexing
unsigned int pos = 0;
while (pos < hostname.length())
{
string::size_type new_pos = hostname.find(',', pos);
if (new_pos == string::npos)
new_pos = hostname.length();
int len = new_pos - pos;
string name = hostname.substr(pos, len);
pos = new_pos + 1;
vector<string> names(2);
names[my_num] = "localhost";
names[1-my_num] = name;
N.resize(N.size() + 1);
N[N.size()-1].init(my_num, portnum_base, names);
}
P = new TwoPartyPlayer(N[0], 1 - my_num, 500);
timeval baseOTstart, baseOTend;
gettimeofday(&baseOTstart, NULL);
// swap role for base OTs
if (opt.isSet("-r"))
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
else
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
cout << "real mode " << opt.isSet("-r") << endl;
BaseOT& bot = *bot_;
bot.exec_base();
gettimeofday(&baseOTend, NULL);
double basetime = timeval_diff(&baseOTstart, &baseOTend);
cout << "\t\tBaseTime (" << role_to_str(ot_role) << "): " << basetime/1000000 << endl << flush;
// Receiver send something to force synchronization
// (since Sender finishes baseOTs before Receiver)
int a = 3;
vector<octetStream> os(2);
os[0].store(a);
P->send_receive_player(os);
os[1].get(a);
cout << a << endl;
#ifdef BASE_OT_DEBUG
// check base OTs
bot.check();
// check after extending with PRG a few times
for (int i = 0; i < 8; i++)
{
bot.extend_length();
bot.check();
}
cout << "Verifying base OTs (debugging)\n";
#endif
// convert baseOT selection bits to BitVector
// (not already BitVector due to legacy PVW code)
baseReceiverInput.resize(nbase);
for (int i = 0; i < nbase; i++)
{
baseReceiverInput.set_bit(i, bot.receiver_inputs[i]);
}
}
OTMachine::~OTMachine()
{
delete bot_;
delete P;
}
void OTMachine::run()
{
// divide nOTs between threads and loops
nOTs = DIV_CEIL(nOTs, nthreads * nloops);
// round up to multiple of base OTs and subloops
// discount for discarded OTs
nOTs = DIV_CEIL(nOTs + 2 * 128, nbase * nsubloops) * nbase * nsubloops - 2 * 128;
cout << "Running " << nOTs << " OT extensions per thread and loop\n" << flush;
// PRG for generating inputs etc
PRNG G;
G.ReSeed();
BitVector receiverInput(nOTs);
receiverInput.randomize(G);
BaseOT& bot = *bot_;
cout << "Initialize OT Extension\n";
vector<OT_thread_info> tinfos(nthreads);
vector<pthread_t> threads(nthreads);
timeval OTextstart, OTextend;
gettimeofday(&OTextstart, NULL);
// copy base inputs/outputs for each thread
vector<BitVector> base_receiver_input_copy(nthreads);
vector<vector< vector<BitVector> > > base_sender_inputs_copy(nthreads, vector<vector<BitVector> >(nbase, vector<BitVector>(2)));
vector< vector<BitVector> > base_receiver_outputs_copy(nthreads, vector<BitVector>(nbase));
vector<TwoPartyPlayer*> players(nthreads);
for (int i = 0; i < nthreads; i++)
{
tinfos[i].receiverInput.assign(receiverInput);
base_receiver_input_copy[i].assign(baseReceiverInput);
for (int j = 0; j < nbase; j++)
{
base_sender_inputs_copy[i][j][0].assign(bot.sender_inputs[j][0]);
base_sender_inputs_copy[i][j][1].assign(bot.sender_inputs[j][1]);
base_receiver_outputs_copy[i][j].assign(bot.receiver_outputs[j]);
}
// now setup resources for each thread
// round robin with the names
players[i] = new TwoPartyPlayer(N[i%N.size()], 1 - my_num, (i+1) * 1000);
tinfos[i].thread_num = i+1;
tinfos[i].other_player_num = 1 - my_num;
tinfos[i].nOTs = nOTs;
tinfos[i].ot_ext = new OTExtensionWithMatrix(nbase, bot.length(),
nloops, nsubloops,
players[i],
base_receiver_input_copy[i],
base_sender_inputs_copy[i],
base_receiver_outputs_copy[i],
ot_role,
passive);
// create the thread
pthread_create(&threads[i], NULL, run_otext_thread, &tinfos[i]);
// extend base OTs with PRG for the next thread
bot.extend_length();
}
// wait for threads to finish
for (int i = 0; i < nthreads; i++)
{
pthread_join(threads[i],NULL);
cout << "thread " << i+1 << " finished\n" << flush;
}
map<string,long long>& times = tinfos[0].ot_ext->times;
for (map<string,long long>::iterator it = times.begin(); it != times.end(); it++)
{
long long sum = 0;
for (int i = 0; i < nthreads; i++)
sum += tinfos[i].ot_ext->times[it->first];
cout << it->first << " on average took time "
<< double(sum) / nthreads / 1e6 << endl;
}
gettimeofday(&OTextend, NULL);
double totaltime = timeval_diff(&OTextstart, &OTextend);
cout << "Time for OTExt threads (" << role_to_str(ot_role) << "): " << totaltime/1000000 << endl << flush;
if (opt.isSet("-o"))
{
BitVector receiver_output, sender_output;
char filename[1024];
sprintf(filename, RECEIVER_INPUT, my_num);
ofstream outf(filename);
receiverInput.output(outf, false);
outf.close();
sprintf(filename, RECEIVER_OUTPUT, my_num);
outf.open(filename);
for (unsigned int i = 0; i < nOTs; i++)
{
receiver_output.assign_bytes((char*) tinfos[0].ot_ext->get_receiver_output(i), sizeof(__m128i));
receiver_output.output(outf, false);
}
outf.close();
for (int i = 0; i < 2; i++)
{
sprintf(filename, SENDER_OUTPUT, my_num, i);
outf.open(filename);
for (int j = 0; j < nOTs; j++)
{
sender_output.assign_bytes((char*) tinfos[0].ot_ext->get_sender_output(i, j), sizeof(__m128i));
sender_output.output(outf, false);
}
outf.close();
}
}
for (int i = 0; i < nthreads; i++)
{
delete players[i];
delete tinfos[i].ot_ext;
}
}

33
OT/OTMachine.h Normal file
View File

@@ -0,0 +1,33 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OTMachine.h
*
*/
#ifndef OT_OTMACHINE_H_
#define OT_OTMACHINE_H_
#include "OT/OTExtension.h"
#include "Tools/ezOptionParser.h"
class OTMachine
{
ez::ezOptionParser opt;
OT_ROLE ot_role;
public:
int my_num, portnum_base, nthreads, nloops, nsubloops, nbase;
long nOTs;
bool passive;
TwoPartyPlayer* P;
BitVector baseReceiverInput;
BaseOT* bot_;
vector<Names> N;
OTMachine(int argc, const char** argv);
~OTMachine();
void run();
};
#endif /* OT_OTMACHINE_H_ */

164
OT/OTMultiplier.cpp Normal file
View File

@@ -0,0 +1,164 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OTMultiplier.cpp
*
*/
#include "OT/OTMultiplier.h"
#include "OT/NPartyTripleGenerator.h"
#include <math.h>
template<class T>
OTMultiplier<T>::OTMultiplier(NPartyTripleGenerator& generator,
int thread_num) :
generator(generator), thread_num(thread_num),
rot_ext(128, 128, 0, 1,
generator.players[thread_num], generator.baseReceiverInput,
generator.baseSenderInputs[thread_num],
generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check)
{
c_output.resize(generator.nTriplesPerLoop);
pthread_mutex_init(&mutex, 0);
pthread_cond_init(&ready, 0);
thread = 0;
}
template<class T>
OTMultiplier<T>::~OTMultiplier()
{
pthread_mutex_destroy(&mutex);
pthread_cond_destroy(&ready);
}
template<class T>
void OTMultiplier<T>::multiply()
{
BitVector keyBits(generator.field_size);
keyBits.set_int128(0, generator.machine.get_mac_key<T>().to_m128i());
rot_ext.extend<T>(generator.field_size, keyBits);
vector< vector<BitVector> > senderOutput(128);
vector<BitVector> receiverOutput;
for (int j = 0; j < 128; j++)
{
senderOutput[j].resize(2);
for (int i = 0; i < 2; i++)
{
senderOutput[j][i].resize(128);
senderOutput[j][i].set_int128(0, rot_ext.senderOutputMatrices[i].squares[0].rows[j]);
}
}
rot_ext.receiverOutputMatrix.to(receiverOutput);
OTExtensionWithMatrix auth_ot_ext(128, 128, 0, 1,
generator.players[thread_num], keyBits, senderOutput,
receiverOutput, BOTH, true);
if (generator.machine.generateBits)
multiplyForBits(auth_ot_ext);
else
multiplyForTriples(auth_ot_ext);
}
template<class T>
void OTMultiplier<T>::multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext)
{
auth_ot_ext.resize(generator.nPreampTriplesPerLoop * generator.field_size);
// dummy input for OT correlator
vector<BitVector> _;
vector< vector<BitVector> > __;
BitVector ___;
OTExtensionWithMatrix otCorrelator(0, 0, 0, 0, generator.players[thread_num],
___, __, _, BOTH, true);
otCorrelator.resize(128 * generator.nPreampTriplesPerLoop);
rot_ext.resize(generator.field_size * generator.nPreampTriplesPerLoop + 2 * 128);
pthread_mutex_lock(&mutex);
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
for (int i = 0; i < generator.nloops; i++)
{
BitVector aBits = generator.valueBits[0];
//timers["Extension"].start();
rot_ext.extend<T>(generator.field_size * generator.nPreampTriplesPerLoop, aBits);
//timers["Extension"].stop();
//timers["Correlation"].start();
otCorrelator.baseReceiverInput = aBits;
otCorrelator.setup_for_correlation(rot_ext.senderOutputMatrices, rot_ext.receiverOutputMatrix);
otCorrelator.correlate<T>(0, generator.nPreampTriplesPerLoop, generator.valueBits[1], false, generator.nAmplify);
//timers["Correlation"].stop();
//timers["Triple computation"].start();
otCorrelator.reduce_squares(generator.nPreampTriplesPerLoop, c_output);
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
if (generator.machine.generateMACs)
{
macs.resize(3);
for (int j = 0; j < 3; j++)
{
int nValues = generator.nTriplesPerLoop;
if (generator.machine.check && (j % 2 == 0))
nValues *= 2;
auth_ot_ext.expand<T>(0, nValues);
auth_ot_ext.correlate<T>(0, nValues, generator.valueBits[j], true);
auth_ot_ext.reduce_squares(nValues, macs[j]);
}
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
}
}
pthread_mutex_unlock(&mutex);
}
template<>
void OTMultiplier<gfp>::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext)
{
multiplyForTriples(auth_ot_ext);
}
template<>
void OTMultiplier<gf2n>::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext)
{
int nBits = generator.nTriplesPerLoop + generator.field_size;
int nBlocks = ceil(1.0 * nBits / generator.field_size);
auth_ot_ext.resize(nBlocks * generator.field_size);
macs.resize(1);
macs[0].resize(nBits);
pthread_mutex_lock(&mutex);
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
for (int i = 0; i < generator.nloops; i++)
{
auth_ot_ext.expand<gf2n>(0, nBlocks);
auth_ot_ext.correlate<gf2n>(0, nBlocks, generator.valueBits[0], true);
auth_ot_ext.transpose(0, nBlocks);
for (int j = 0; j < nBits; j++)
{
int128 r = auth_ot_ext.receiverOutputMatrix.squares[j/128].rows[j%128];
int128 s = auth_ot_ext.senderOutputMatrices[0].squares[j/128].rows[j%128];
macs[0][j] = r ^ s;
}
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
}
pthread_mutex_unlock(&mutex);
}
template class OTMultiplier<gf2n>;
template class OTMultiplier<gfp>;

41
OT/OTMultiplier.h Normal file
View File

@@ -0,0 +1,41 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OTMultiplier.h
*
*/
#ifndef OT_OTMULTIPLIER_H_
#define OT_OTMULTIPLIER_H_
#include <vector>
using namespace std;
#include "OT/OTExtensionWithMatrix.h"
#include "Tools/random.h"
class NPartyTripleGenerator;
template <class T>
class OTMultiplier
{
void multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext);
void multiplyForBits(OTExtensionWithMatrix& auth_ot_ext);
public:
NPartyTripleGenerator& generator;
int thread_num;
OTExtensionWithMatrix rot_ext;
//OTExtensionWithMatrix* auth_ot_ext;
vector<T> c_output;
vector< vector<T> > macs;
pthread_t thread;
pthread_mutex_t mutex;
pthread_cond_t ready;
OTMultiplier(NPartyTripleGenerator& generator, int thread_num);
~OTMultiplier();
void multiply();
};
#endif /* OT_OTMULTIPLIER_H_ */

44
OT/OTTripleSetup.cpp Normal file
View File

@@ -0,0 +1,44 @@
// (C) 2016 University of Bristol. See License.txt
#include "OTTripleSetup.h"
void OTTripleSetup::setup()
{
timeval baseOTstart, baseOTend;
gettimeofday(&baseOTstart, NULL);
G.ReSeed();
for (int i = 0; i < nbase; i++)
{
base_receiver_inputs[i] = G.get_uchar() & 1;
}
//baseReceiverInput.randomize(G);
for (int i = 0; i < nparties - 1; i++)
{
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
baseOTs[i]->exec_base(false);
}
gettimeofday(&baseOTend, NULL);
double basetime = timeval_diff(&baseOTstart, &baseOTend);
cout << "\t\tBaseTime: " << basetime/1000000 << endl << flush;
// Receiver send something to force synchronization
// (since Sender finishes baseOTs before Receiver)
}
void OTTripleSetup::close_connections()
{
for (size_t i = 0; i < players.size(); i++)
{
delete players[i];
}
}
OTTripleSetup::~OTTripleSetup()
{
for (size_t i = 0; i < baseOTs.size(); i++)
{
delete baseOTs[i];
}
}

91
OT/OTTripleSetup.h Normal file
View File

@@ -0,0 +1,91 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef OT_TRIPLESETUP_H_
#define OT_TRIPLESETUP_H_
#include "Networking/Player.h"
#include "OT/BaseOT.h"
#include "OT/OTMachine.h"
#include "Tools/random.h"
#include "Tools/time-func.h"
#include "Math/gfp.h"
/*
* Class for creating and storing base OTs between every pair of parties.
*/
class OTTripleSetup
{
vector<int> base_receiver_inputs;
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
PRNG G;
int nparties;
int my_num;
int nbase;
bool real_OTs;
public:
map<string,Timer> timers;
vector<BaseOT*> baseOTs;
vector<TwoPartyPlayer*> players;
int get_nparties() { return nparties; }
int get_nbase() { return nbase; }
int get_my_num() { return my_num; }
int get_base_receiver_input(int i) { return base_receiver_inputs[i]; }
OTTripleSetup(Names& N, bool real_OTs)
: nparties(N.num_players()), my_num(N.my_num()), nbase(128), real_OTs(real_OTs)
{
base_receiver_inputs.resize(nbase);
players.resize(nparties - 1);
baseOTs.resize(nparties - 1);
baseSenderInputs.resize(nparties - 1);
baseReceiverOutputs.resize(nparties - 1);
if (real_OTs)
cout << "Doing real base OTs\n";
else
cout << "Doing fake base OTs\n";
for (int i = 0; i < nparties - 1; i++)
{
int other_player, id;
// i for indexing, other_player is actual number
if (i >= my_num)
other_player = i + 1;
else
other_player = i;
// unique id per pair of parties (to assign port no.)
if (my_num < other_player)
id = my_num*nparties + other_player;
else
id = other_player*nparties + my_num;
players[i] = new TwoPartyPlayer(N, other_player, id);
// sets up a pair of base OTs, playing both roles
if (real_OTs)
{
baseOTs[i] = new BaseOT(nbase, 128, players[i]);
}
else
{
baseOTs[i] = new FakeOT(nbase, 128, players[i]);
}
}
}
~OTTripleSetup();
// run the Base OTs
void setup();
// close down the sockets
void close_connections();
//template <class T>
//T get_mac_key();
};
#endif

13
OT/OText_main.cpp Normal file
View File

@@ -0,0 +1,13 @@
// (C) 2016 University of Bristol. See License.txt
/*
* OText_main.cpp
*
*/
#include "OTMachine.h"
int main(int argc, const char** argv)
{
OTMachine(argc, argv).run();
}

15
OT/OutputCheck.h Normal file
View File

@@ -0,0 +1,15 @@
// (C) 2016 University of Bristol. See License.txt
/*
* check.h
*
*/
#ifndef OT_OUTPUTCHECK_H_
#define OT_OUTPUTCHECK_H_
#define RECEIVER_INPUT "Player-Data/OT-receiver%d-input"
#define RECEIVER_OUTPUT "Player-Data/OT-receiver%d-output"
#define SENDER_OUTPUT "Player-Data/OT-sender%d-output%d"
#endif /* OT_OUTPUTCHECK_H_ */

107
OT/Tools.cpp Normal file
View File

@@ -0,0 +1,107 @@
// (C) 2016 University of Bristol. See License.txt
#include "Tools.h"
#include "Math/gf2nlong.h"
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len)
{
PRNG G;
G.ReSeed();
vector<octetStream> seed_strm(2);
vector<octetStream> Comm_seed(2);
vector<octetStream> Open_seed(2);
G.get_octetStream(seed_strm[0], len);
Commit(Comm_seed[0], Open_seed[0], seed_strm[0], player.my_num());
player.send_receive_player(Comm_seed);
player.send_receive_player(Open_seed);
memset(seed, 0, len*sizeof(octet));
if (!Open(seed_strm[1], Comm_seed[1], Open_seed[1], player.other_player_num()))
{
throw invalid_commitment();
}
for (int i = 0; i < len; i++)
{
seed[i] = seed_strm[0].get_data()[i] ^ seed_strm[1].get_data()[i];
}
}
void shiftl128(word x1, word x2, word& res1, word& res2, size_t k)
{
if (k > 128)
throw invalid_length();
if (k >= 64) // shifting a 64-bit integer by more than 63 bits is "undefined"
{
x1 = x2;
x2 = 0;
shiftl128(x1, x2, res1, res2, k - 64);
}
else
{
res1 = (x1 << k) | (x2 >> (64-k));
res2 = (x2 << k);
}
}
// reduce modulo x^128 + x^7 + x^2 + x + 1
// NB this is incorrect as it bit-reflects the result as required for
// GCM mode
void gfred128(__m128i tmp3, __m128i tmp6, __m128i *res)
{
__m128i tmp2, tmp4, tmp5, tmp7, tmp8, tmp9;
tmp7 = _mm_srli_epi32(tmp3, 31);
tmp8 = _mm_srli_epi32(tmp6, 31);
tmp3 = _mm_slli_epi32(tmp3, 1);
tmp6 = _mm_slli_epi32(tmp6, 1);
tmp9 = _mm_srli_si128(tmp7, 12);
tmp8 = _mm_slli_si128(tmp8, 4);
tmp7 = _mm_slli_si128(tmp7, 4);
tmp3 = _mm_or_si128(tmp3, tmp7);
tmp6 = _mm_or_si128(tmp6, tmp8);
tmp6 = _mm_or_si128(tmp6, tmp9);
tmp7 = _mm_slli_epi32(tmp3, 31);
tmp8 = _mm_slli_epi32(tmp3, 30);
tmp9 = _mm_slli_epi32(tmp3, 25);
tmp7 = _mm_xor_si128(tmp7, tmp8);
tmp7 = _mm_xor_si128(tmp7, tmp9);
tmp8 = _mm_srli_si128(tmp7, 4);
tmp7 = _mm_slli_si128(tmp7, 12);
tmp3 = _mm_xor_si128(tmp3, tmp7);
tmp2 = _mm_srli_epi32(tmp3, 1);
tmp4 = _mm_srli_epi32(tmp3, 2);
tmp5 = _mm_srli_epi32(tmp3, 7);
tmp2 = _mm_xor_si128(tmp2, tmp4);
tmp2 = _mm_xor_si128(tmp2, tmp5);
tmp2 = _mm_xor_si128(tmp2, tmp8);
tmp3 = _mm_xor_si128(tmp3, tmp2);
tmp6 = _mm_xor_si128(tmp6, tmp3);
*res = tmp6;
}
// Based on Intel's code for GF(2^128) mul, with reduction
void gfmul128 (__m128i a, __m128i b, __m128i *res)
{
__m128i tmp3, tmp6;
mul128(a, b, &tmp3, &tmp6);
// Now do the reduction
gfred128(tmp3, tmp6, res);
}
string word_to_bytes(const word w)
{
stringstream ss;
octet* bytes = (octet*) &w;
ss << hex;
for (unsigned int i = 0; i < sizeof(word); i++)
ss << (int)bytes[i] << " ";
return ss.str();
}

51
OT/Tools.h Normal file
View File

@@ -0,0 +1,51 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _OTTOOLS
#define _OTTOOLS
#include "Networking/Player.h"
#include "Tools/Commit.h"
#include "Tools/random.h"
#define SEED_SIZE_BYTES SEED_SIZE
/*
* Generate a secure, random seed between 2 parties via commitment
*/
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len);
/*
* GF(2^128) multiplication using Intel instructions
* (should this go in gf2n class???)
*/
void gfmul128(__m128i a, __m128i b, __m128i *res);
void gfred128(__m128i a1, __m128i a2, __m128i *res);
//#if defined(__SSE2__)
/*
* Convert __m128i to string of type T
*/
template <typename T>
string __m128i_toString(const __m128i var) {
stringstream sstr;
sstr << hex;
const T* values = (const T*) &var;
if (sizeof(T) == 1) {
for (unsigned int i = 0; i < sizeof(__m128i); i++) {
sstr << (int) values[i] << " ";
}
} else {
for (unsigned int i = 0; i < sizeof(__m128i) / sizeof(T); i++) {
sstr << values[i] << " ";
}
}
return sstr.str();
}
//#endif
string word_to_bytes(const word w);
void shiftl128(word x1, word x2, word& res1, word& res2, size_t k);
#endif

270
OT/TripleMachine.cpp Normal file
View File

@@ -0,0 +1,270 @@
// (C) 2016 University of Bristol. See License.txt
/*
* TripleMachine.cpp
*
*/
#include <OT/TripleMachine.h>
#include "OT/NPartyTripleGenerator.h"
#include "OT/OTMachine.h"
#include "OT/OTTripleSetup.h"
#include "Math/gf2n.h"
#include "Math/Setup.h"
#include "Tools/ezOptionParser.h"
#include "Math/Setup.h"
#include <iostream>
#include <fstream>
using namespace std;
template <class T>
void* run_ngenerator_thread(void* ptr)
{
((NPartyTripleGenerator*)ptr)->generate<T>();
return 0;
}
TripleMachine::TripleMachine(int argc, const char** argv)
{
ez::ezOptionParser opt;
opt.add(
"2", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of parties (default: 2).", // Help description.
"-N", // Flag token.
"--nparties" // Flag token.
);
opt.add(
"", // Default.
1, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number, 0/1 (required).", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.add(
"1",
0,
1,
0,
"Number of threads (default: 1).",
"-x",
"--nthreads"
);
opt.add(
"1000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of triples (default: 1000).", // Help description.
"-n", // Flag token.
"--ntriples" // Flag token.
);
opt.add(
"1", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of loops (default: 1).", // Help description.
"-l", // Flag token.
"--nloops" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Generate MACs (implies -a).", // Help description.
"-m", // Flag token.
"--macs" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Amplify triples.", // Help description.
"-a", // Flag token.
"--amplify" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Check triples (implies -m).", // Help description.
"-c", // Flag token.
"--check" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"GF(p) triples", // Help description.
"-P", // Flag token.
"--prime-field" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Channel bonding", // Help description.
"-b", // Flag token.
"--bonding" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Generate bits", // Help description.
"-B", // Flag token.
"--bits" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Write results to files.", // Help description.
"-o", // Flag token.
"--output" // Flag token.
);
opt.parse(argc, argv);
opt.get("-p")->getInt(my_num);
opt.get("-N")->getInt(nplayers);
opt.get("-x")->getInt(nthreads);
opt.get("-n")->getInt(ntriples);
opt.get("-l")->getInt(nloops);
generateBits = opt.get("-B")->isSet;
check = opt.get("-c")->isSet || generateBits;
generateMACs = opt.get("-m")->isSet || check;
amplify = opt.get("-a")->isSet || generateMACs;
primeField = opt.get("-P")->isSet;
bonding = opt.get("-b")->isSet;
output = opt.get("-o")->isSet;
if (!opt.isSet("-p"))
{
string usage;
opt.getUsage(usage);
cout << usage;
exit(0);
}
nTriplesPerThread = DIV_CEIL(ntriples, nthreads);
prep_data_dir = get_prep_dir(nplayers, 128, 128);
ofstream outf;
bigint p;
generate_online_setup(outf, prep_data_dir, p, 128, 128);
// doesn't work with Montgomery multiplication
gfp::init_field(p, false);
gf2n::init_field(128);
PRNG G;
G.ReSeed();
mac_key2.randomize(G);
mac_keyp.randomize(G);
}
void TripleMachine::run()
{
cout << "my_num: " << my_num << endl;
Names N[2];
N[0].init(my_num, nplayers, 10000, "HOSTS");
int nConnections = 1;
if (bonding)
{
N[1].init(my_num, nplayers, 11000, "HOSTS2");
nConnections = 2;
}
// do the base OTs
OTTripleSetup setup(N[0], false);
setup.setup();
setup.close_connections();
vector<NPartyTripleGenerator*> generators(nthreads);
vector<pthread_t> threads(nthreads);
for (int i = 0; i < nthreads; i++)
{
generators[i] = new NPartyTripleGenerator(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this);
}
ntriples = generators[0]->nTriples * nthreads;
cout <<"Setup generators\n";
for (int i = 0; i < nthreads; i++)
{
// lock before starting thread to avoid race condition
generators[i]->lock();
if (primeField)
pthread_create(&threads[i], 0, run_ngenerator_thread<gfp>, generators[i]);
else
pthread_create(&threads[i], 0, run_ngenerator_thread<gf2n>, generators[i]);
}
// wait for initialization, then start clock and computation
for (int i = 0; i < nthreads; i++)
generators[i]->wait();
cout << "Starting computation" << endl;
gettimeofday(&start, 0);
for (int i = 0; i < nthreads; i++)
{
generators[i]->signal();
generators[i]->unlock();
}
// wait for threads to finish
for (int i = 0; i < nthreads; i++)
{
pthread_join(threads[i],NULL);
cout << "thread " << i+1 << " finished\n" << flush;
}
map<string,Timer>& timers = generators[0]->timers;
for (map<string,Timer>::iterator it = timers.begin(); it != timers.end(); it++)
{
double sum = 0;
for (size_t i = 0; i < generators.size(); i++)
sum += generators[i]->timers[it->first].elapsed();
cout << it->first << " on average took time "
<< sum / generators.size() << endl;
}
gettimeofday(&stop, 0);
double time = timeval_diff_in_seconds(&start, &stop);
cout << "Time: " << time << endl;
cout << "Throughput: " << ntriples / time << endl;
for (size_t i = 0; i < generators.size(); i++)
delete generators[i];
output_mac_keys();
}
void TripleMachine::output_mac_keys()
{
stringstream ss;
ss << prep_data_dir << "Player-MAC-Keys-P" << my_num;
cout << "Writing MAC key to " << ss.str() << endl;
ofstream outputFile(ss.str().c_str());
outputFile << nplayers << endl;
outputFile << mac_keyp << " " << mac_key2 << endl;
}
template<> gf2n TripleMachine::get_mac_key()
{
return mac_key2;
}
template<> gfp TripleMachine::get_mac_key()
{
return mac_keyp;
}

40
OT/TripleMachine.h Normal file
View File

@@ -0,0 +1,40 @@
// (C) 2016 University of Bristol. See License.txt
/*
* TripleMachine.h
*
*/
#ifndef OT_TRIPLEMACHINE_H_
#define OT_TRIPLEMACHINE_H_
#include "Math/gf2n.h"
#include "Math/gfp.h"
class TripleMachine
{
gf2n mac_key2;
gfp mac_keyp;
public:
int my_num, nplayers, nthreads, ntriples, nloops;
int nTriplesPerThread;
string prep_data_dir;
bool generateMACs;
bool amplify;
bool check;
bool primeField;
bool bonding;
bool generateBits;
bool output;
struct timeval start, stop;
TripleMachine(int argc, const char** argv);
void run();
template <class T>
T get_mac_key();
void output_mac_keys();
};
#endif /* OT_TRIPLEMACHINE_H_ */

179
Player-Online.cpp Normal file
View File

@@ -0,0 +1,179 @@
// (C) 2016 University of Bristol. See License.txt
#include "Processor/Machine.h"
#include "Tools/ezOptionParser.h"
#include <iostream>
#include <map>
#include <string>
using namespace std;
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
opt.syntax = "./Player-Online.x [OPTIONS] <playernum> <progname>\n";
opt.example = "./Player-Online.x -lgp 64 -lg2 128 -m new 0 sample-prog\n./Player-Online.x -pn 13000 -h localhost 1 sample-prog\n";
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(p) field (default: 128)", // Help description.
"-lgp", // Flag token.
"--lgp" // Flag token.
);
opt.add(
"40", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(2^n) field (default: 40)", // Help description.
"-lg2", // Flag token.
"--lg2" // Flag token.
);
opt.add(
"5000", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Port number base to attempt to start connections from (default: 5000)", // Help description.
"-pn", // Flag token.
"--portnumbase" // Flag token.
);
opt.add(
"localhost", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Host where Server.x is running (default: localhost)", // Help description.
"-h", // Flag token.
"--hostname" // Flag token.
);
opt.add(
"empty", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Where to obtain memory, new|old|empty (default: empty)\n\t"
"new: copy from Player-Memory-P<i> file\n\t"
"old: reuse previous memory in Memory-P<i>\n\t"
"empty: create new empty memory", // Help description.
"-m", // Flag token.
"--memory" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Direct communication instead of star-shaped", // Help description.
"-d", // Flag token.
"--direct" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Star-shaped communication handled by background threads", // Help description.
"-P", // Flag token.
"--parallel" // Flag token.
);
opt.add(
"0", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Sum at most n shares at once when using indirect communication", // Help description.
"-s", // Flag token.
"--opening-sum" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use player-specific threads for communication", // Help description.
"-t", // Flag token.
"--threads" // Flag token.
);
opt.add(
"0", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Maximum number of parties to send to at once", // Help description.
"-b", // Flag token.
"--max-broadcast" // Flag token.
);
opt.parse(argc, argv);
vector<string*> allArgs(opt.firstArgs);
allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end());
string progname;
int playerno;
string usage;
vector<string> badOptions;
unsigned int i;
if (allArgs.size() != 3)
{
cerr << "ERROR: incorrect number of arguments to Player-Online.x\n";
cerr << "Arguments given were:\n";
for (unsigned int j = 1; j < allArgs.size(); j++)
cout << "'" << *allArgs[j] << "'" << endl;
opt.getUsage(usage);
cout << usage;
return 1;
}
else
{
playerno = atoi(allArgs[1]->c_str());
progname = *allArgs[2];
}
if(!opt.gotRequired(badOptions))
{
for (i=0; i < badOptions.size(); ++i)
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
return 1;
}
if(!opt.gotExpected(badOptions))
{
for(i=0; i < badOptions.size(); ++i)
cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << ".";
opt.getUsage(usage);
cout << usage;
return 1;
}
string memtype, hostname;
int lg2, lgp, pnbase, opening_sum, max_broadcast;
opt.get("--portnumbase")->getInt(pnbase);
opt.get("--lgp")->getInt(lgp);
opt.get("--lg2")->getInt(lg2);
opt.get("--memory")->getString(memtype);
opt.get("--hostname")->getString(hostname);
opt.get("--opening-sum")->getInt(opening_sum);
opt.get("--max-broadcast")->getInt(max_broadcast);
Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2,
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
opt.get("--threads")->isSet, max_broadcast).run();
cerr << "Command line:";
for (int i = 0; i < argc; i++)
cerr << " " << argv[i];
cerr << endl;
}

145
Processor/Buffer.cpp Normal file
View File

@@ -0,0 +1,145 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Buffer.cpp
*
*/
#include "Buffer.h"
#include "Processor/InputTuple.h"
bool BufferBase::rewind = false;
void BufferBase::setup(ifstream* f, int length, const char* type)
{
file = f;
tuple_length = length;
data_type = type;
}
void BufferBase::seekg(int pos)
{
file->seekg(pos * tuple_length);
if (file->eof() || file->fail())
{
file->clear();
file->seekg(0);
if (!rewind)
cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl;
rewind = true;
}
next = BUFFER_SIZE;
}
template<class T, class U>
Buffer<T, U>::~Buffer()
{
if (timer.elapsed() && data_type)
cerr << T::type_string() << " " << data_type << " reading: "
<< timer.elapsed() << endl;
}
template<class T, class U>
void Buffer<T, U>::fill_buffer()
{
if (T::size() == sizeof(T))
{
// read directly
read((char*)buffer);
}
else
{
char read_buffer[sizeof(buffer)];
read(read_buffer);
//memset(buffer, 0, sizeof(buffer));
for (int i = 0; i < BUFFER_SIZE; i++)
buffer[i].assign(&read_buffer[i*T::size()]);
}
}
template<class T, class U>
void Buffer<T, U>::read(char* read_buffer)
{
int size_in_bytes = T::size() * BUFFER_SIZE;
int n_read = 0;
timer.start();
do
{
file->read(read_buffer + n_read, size_in_bytes - n_read);
n_read += file->gcount();
if (file->eof())
{
file->clear(); // unset EOF flag
file->seekg(0);
if (!rewind)
cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl;
rewind = true;
eof = true;
}
if (file->fail())
{
stringstream ss;
ss << "IO problem when buffering " << T::type_string();
if (data_type)
ss << " " << data_type;
throw file_error(ss.str());
}
}
while (n_read < size_in_bytes);
timer.stop();
}
template <class T, class U>
void Buffer<T,U>::input(U& a)
{
if (next == BUFFER_SIZE)
{
fill_buffer();
next = 0;
}
a = buffer[next];
next++;
}
template < template<class T> class U, template<class T> class V >
BufferBase& BufferHelper<U,V>::get_buffer(DataFieldType field_type)
{
if (field_type == DATA_MODP)
return bufferp;
else if (field_type == DATA_GF2N)
return buffer2;
else
throw not_implemented();
}
template < template<class T> class U, template<class T> class V >
void BufferHelper<U,V>::setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type)
{
files[field_type] = new ifstream(filename.c_str(), ios::in | ios::binary);
if (files[field_type]->fail())
throw file_error(filename);
get_buffer(field_type).setup(files[field_type], tuple_length, data_type);
}
template<template<class T> class U, template<class T> class V>
void BufferHelper<U,V>::close()
{
for (int i = 0; i < N_DATA_FIELD_TYPE; i++)
if (files[i])
{
files[i]->close();
delete files[i];
}
}
template class Buffer< Share<gfp>, Share<gfp> >;
template class Buffer< Share<gf2n>, Share<gf2n> >;
template class Buffer< InputTuple<gfp>, RefInputTuple<gfp> >;
template class Buffer< InputTuple<gf2n>, RefInputTuple<gf2n> >;
template class Buffer< gfp, gfp >;
template class Buffer< gf2n, gf2n >;
template class BufferHelper<Share, Share>;
template class BufferHelper<InputTuple, RefInputTuple>;

74
Processor/Buffer.h Normal file
View File

@@ -0,0 +1,74 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Buffer.h
*
*/
#ifndef PROCESSOR_BUFFER_H_
#define PROCESSOR_BUFFER_H_
#include <fstream>
using namespace std;
#include "Math/Share.h"
#include "Math/field_types.h"
#include "Tools/time-func.h"
#ifndef BUFFER_SIZE
#define BUFFER_SIZE 101
#endif
class BufferBase
{
protected:
static bool rewind;
ifstream* file;
int next;
const char* data_type;
Timer timer;
int tuple_length;
public:
bool eof;
BufferBase() : file(0), next(BUFFER_SIZE), data_type(0), tuple_length(-1), eof(false) {};
void setup(ifstream* f, int length, const char* type = 0);
void seekg(int pos);
bool is_up() { return file != 0; }
};
template <class T, class U>
class Buffer : public BufferBase
{
T buffer[BUFFER_SIZE];
void read(char* read_buffer);
public:
~Buffer();
void input(U& a);
void fill_buffer();
};
template < template<class T> class U, template<class T> class V >
class BufferHelper
{
public:
Buffer< U<gfp>, V<gfp> > bufferp;
Buffer< U<gf2n>, V<gf2n> > buffer2;
ifstream* files[N_DATA_FIELD_TYPE];
BufferHelper() { memset(files, 0, sizeof(files)); }
void input(V<gfp>& a) { bufferp.input(a); }
void input(V<gf2n>& a) { buffer2.input(a); }
BufferBase& get_buffer(DataFieldType field_type);
void setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type = 0);
void close();
};
#endif /* PROCESSOR_BUFFER_H_ */

218
Processor/Data_Files.cpp Normal file
View File

@@ -0,0 +1,218 @@
// (C) 2016 University of Bristol. See License.txt
#include "Processor/Data_Files.h"
#include "Processor/Processor.h"
#include <iomanip>
const char* Data_Files::dtype_names[N_DTYPE] = { "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples" };
const char* Data_Files::field_names[] = { "p", "2" };
const char* Data_Files::long_field_names[] = { "gfp", "gf2n" };
const bool Data_Files::implemented[N_DATA_FIELD_TYPE][N_DTYPE] = {
{ true, true, true, true, false, false },
{ true, true, true, true, true, true },
};
const int Data_Files::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 };
Lock Data_Files::tuple_lengths_lock;
map<DataTag, int> Data_Files::tuple_lengths;
void DataPositions::set_num_players(int num_players)
{
files.resize(N_DATA_FIELD_TYPE, vector<int>(N_DTYPE));
inputs.resize(num_players, vector<int>(N_DATA_FIELD_TYPE));
}
void DataPositions::increase(const DataPositions& delta)
{
if (inputs.size() != delta.inputs.size())
throw invalid_length();
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++)
files[field_type][dtype] += delta.files[field_type][dtype];
for (unsigned int j = 0; j < inputs.size(); j++)
inputs[j][field_type] += delta.inputs[j][field_type];
map<DataTag, int>::const_iterator it;
const map<DataTag, int>& delta_ext = delta.extended[field_type];
for (it = delta_ext.begin(); it != delta_ext.end(); it++)
extended[field_type][it->first] += it->second;
}
}
void DataPositions::print_cost() const
{
ifstream file("cost");
double total_cost = 0;
for (int i = 0; i < N_DATA_FIELD_TYPE; i++)
{
cerr << " Type " << Data_Files::field_names[i] << endl;
for (int j = 0; j < N_DTYPE; j++)
{
double cost_per_item = 0;
file >> cost_per_item;
if (cost_per_item < 0)
break;
int items_used = files[i][j];
double cost = items_used * cost_per_item;
total_cost += cost;
cerr.fill(' ');
cerr << " " << setw(10) << cost << " = " << setw(10) << items_used
<< " " << setw(14) << Data_Files::dtype_names[j] << " à " << setw(11)
<< cost_per_item << endl;
}
for (map<DataTag, int>::const_iterator it = extended[i].begin();
it != extended[i].end(); it++)
{
cerr.fill(' ');
cerr << setw(27) << it->second << " " << setw(14) << it->first.get_string() << endl;
}
}
cerr << "Total cost: " << total_cost << endl;
}
int Data_Files::share_length(int field_type)
{
switch (field_type)
{
case DATA_MODP:
return 2 * gfp::t() * sizeof(mp_limb_t);
case DATA_GF2N:
return 2 * sizeof(word);
default:
throw invalid_params();
}
}
int Data_Files::tuple_length(int field_type, int dtype)
{
return tuple_size[dtype] * share_length(field_type);
}
Data_Files::Data_Files(int myn, int n, const string& prep_data_dir) :
usage(n), prep_data_dir(prep_data_dir)
{
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
num_players=n;
my_num=myn;
char filename[1024];
input_buffers = new BufferHelper<Share, Share>[num_players];
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
for (int dtype = 0; dtype < N_DTYPE; dtype++)
{
if (implemented[field_type][dtype])
{
sprintf(filename,(prep_data_dir + "%s-%s-P%d").c_str(),dtype_names[dtype],
field_names[field_type],my_num);
buffers[dtype].setup(DataFieldType(field_type), filename,
tuple_length(field_type, dtype), dtype_names[dtype]);
}
}
for (int i=0; i<num_players; i++)
{
sprintf(filename,(prep_data_dir + "Inputs-%s-P%d-%d").c_str(),
field_names[field_type],my_num,i);
if (i == my_num)
my_input_buffers.setup(DataFieldType(field_type), filename,
share_length(field_type) * 3 / 2);
else
input_buffers[i].setup(DataFieldType(field_type), filename,
share_length(field_type));
}
}
cerr << "done\n";
}
Data_Files::~Data_Files()
{
for (int i = 0; i < N_DTYPE; i++)
buffers[i].close();
for (int i = 0; i < num_players; i++)
input_buffers[i].close();
delete[] input_buffers;
my_input_buffers.close();
for (map<DataTag, BufferHelper<Share, Share> >::iterator it =
extended.begin(); it != extended.end(); it++)
it->second.close();
}
void Data_Files::seekg(DataPositions& pos)
{
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
for (int dtype = 0; dtype < N_DTYPE; dtype++)
if (implemented[field_type][dtype])
buffers[dtype].get_buffer(DataFieldType(field_type)).seekg(pos.files[field_type][dtype]);
for (int j = 0; j < num_players; j++)
if (j == my_num)
my_input_buffers.get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]);
else
input_buffers[j].get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]);
for (map<DataTag, int>::const_iterator it = pos.extended[field_type].begin();
it != pos.extended[field_type].end(); it++)
{
setup_extended(DataFieldType(field_type), it->first);
extended[it->first].get_buffer(DataFieldType(field_type)).seekg(it->second);
}
}
usage = pos;
}
void Data_Files::skip(const DataPositions& pos)
{
DataPositions new_pos = usage;
new_pos.increase(pos);
seekg(new_pos);
}
void Data_Files::setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size)
{
BufferBase& buffer = extended[tag].get_buffer(field_type);
tuple_lengths_lock.lock();
int tuple_length = tuple_lengths[tag];
int my_tuple_length = tuple_size * share_length(field_type);
if (tuple_length > 0)
{
if (tuple_size > 0 && my_tuple_length != tuple_length)
{
stringstream ss;
ss << "Inconsistent size of " << field_names[field_type] << " "
<< tag.get_string() << ": " << my_tuple_length << " vs "
<< tuple_length;
throw Processor_Error(ss.str());
}
}
else
tuple_lengths[tag] = my_tuple_length;
tuple_lengths_lock.unlock();
if (!buffer.is_up())
{
stringstream ss;
ss << prep_data_dir << tag.get_string() << "-" << field_names[field_type] << "-P" << my_num;
extended[tag].setup(field_type, ss.str(), tuple_length);
}
}
template<class T>
void Data_Files::get(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size)
{
usage.extended[T::field_type()][tag] += vector_size;
setup_extended(T::field_type(), tag, regs.size());
for (int j = 0; j < vector_size; j++)
for (unsigned int i = 0; i < regs.size(); i++)
extended[tag].input(proc.get_S_ref<T>(regs[i] + j));
}
template void Data_Files::get<gfp>(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);
template void Data_Files::get<gf2n>(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);

158
Processor/Data_Files.h Normal file
View File

@@ -0,0 +1,158 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Data_Files
#define _Data_Files
/* This class holds the Online data files all in one place
* so the streams are easy to pass around and access
*/
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Math/Share.h"
#include "Math/field_types.h"
#include "Processor/Buffer.h"
#include "Processor/InputTuple.h"
#include "Tools/Lock.h"
#include <fstream>
#include <map>
using namespace std;
enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE };
class DataTag
{
int t[4];
public:
// assume that tag is three integers
DataTag(const int* tag)
{
strncpy((char*)t, (char*)tag, 3 * sizeof(int));
t[3] = 0;
}
string get_string() const
{
return string((char*)t);
}
bool operator<(const DataTag& other) const
{
for (int i = 0; i < 3; i++)
if (t[i] != other.t[i])
return t[i] < other.t[i];
return false;
}
};
struct DataPositions
{
vector< vector<int> > files;
vector< vector<int> > inputs;
map<DataTag, int> extended[N_DATA_FIELD_TYPE];
DataPositions(int num_players = 0) { set_num_players(num_players); }
void set_num_players(int num_players);
void increase(const DataPositions& delta);
void print_cost() const;
};
class Processor;
class Data_Files
{
static map<DataTag, int> tuple_lengths;
static Lock tuple_lengths_lock;
BufferHelper<Share, Share> buffers[N_DTYPE];
BufferHelper<Share, Share>* input_buffers;
BufferHelper<InputTuple, RefInputTuple> my_input_buffers;
map<DataTag, BufferHelper<Share, Share> > extended;
int my_num,num_players;
DataPositions usage;
const string prep_data_dir;
public:
static const char* dtype_names[N_DTYPE];
static const char* field_names[N_DATA_FIELD_TYPE];
static const char* long_field_names[N_DATA_FIELD_TYPE];
static const bool implemented[N_DATA_FIELD_TYPE][N_DTYPE];
static const int tuple_size[N_DTYPE];
static int share_length(int field_type);
static int tuple_length(int field_type, int dtype);
Data_Files(int my_num,int n,const string& prep_data_dir);
~Data_Files();
DataPositions tellg();
void seekg(DataPositions& pos);
void skip(const DataPositions& pos);
template<class T>
bool eof(Dtype dtype);
template<class T>
bool input_eof(int player);
void setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size = 0);
template<class T>
void get(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);
DataPositions get_usage()
{
return usage;
}
template <class T>
void get_three(DataFieldType field_type, Dtype dtype, Share<T>& a, Share<T>& b, Share<T>& c)
{
usage.files[field_type][dtype]++;
buffers[dtype].input(a);
buffers[dtype].input(b);
buffers[dtype].input(c);
}
template <class T>
void get_two(DataFieldType field_type, Dtype dtype, Share<T>& a, Share<T>& b)
{
usage.files[field_type][dtype]++;
buffers[dtype].input(a);
buffers[dtype].input(b);
}
template <class T>
void get_one(DataFieldType field_type, Dtype dtype, Share<T>& a)
{
usage.files[field_type][dtype]++;
buffers[dtype].input(a);
}
template <class T>
void get_input(Share<T>& a,T& x,int i)
{
usage.inputs[i][T::field_type()]++;
RefInputTuple<T> tuple(a, x);
if (i==my_num)
my_input_buffers.input(tuple);
else
input_buffers[i].input(a);
}
};
template<class T> inline
bool Data_Files::eof(Dtype dtype)
{ return buffers[dtype].get_buffer(T::field_type()).eof; }
template<class T> inline
bool Data_Files::input_eof(int player)
{
if (player == my_num)
return my_input_buffers.get_buffer(T::field_type()).eof;
else
return input_buffers[player].get_buffer(T::field_type()).eof;
}
#endif

89
Processor/Input.cpp Normal file
View File

@@ -0,0 +1,89 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Input.cpp
*
*/
#include "Input.h"
#include "Processor.h"
template<class T>
Input<T>::Input(Processor& proc, MAC_Check<T>& mc) :
proc(proc), MC(mc), values_input(0)
{
buffer.setup(&proc.private_input, -1, "private input");
}
template<class T>
Input<T>::~Input()
{
if (timer.elapsed() > 0)
cerr << T::type_string() << " inputs: " << timer.elapsed() << endl;
}
template<class T>
void Input<T>::adjust_mac(Share<T>& share, T& value)
{
T tmp;
tmp.mul(MC.get_alphai(), value);
tmp.add(share.get_mac(),tmp);
share.set_mac(tmp);
}
template<class T>
void Input<T>::start(int player, int n_inputs)
{
if (player == proc.P.my_num())
{
octetStream o;
shares.resize(n_inputs);
for (int i = 0; i < n_inputs; i++)
{
T rr, t;
Share<T>& share = shares[i];
proc.DataF.get_input(share, rr, player);
T xi;
buffer.input(t);
t.sub(t, rr);
t.pack(o);
xi.add(t, share.get_share());
share.set_share(xi);
adjust_mac(share, t);
}
proc.P.send_all(o, true);
values_input += n_inputs;
}
}
template<class T>
void Input<T>::stop(int player, vector<int> targets)
{
T tmp;
if (player == proc.P.my_num())
{
for (unsigned int i = 0; i < targets.size(); i++)
proc.get_S_ref<T>(targets[i]) = shares[i];
}
else
{
T t;
octetStream o;
timer.start();
proc.P.receive_player(player, o, true);
timer.stop();
for (unsigned int i = 0; i < targets.size(); i++)
{
Share<T>& share = proc.get_S_ref<T>(targets[i]);
proc.DataF.get_input(share, t, player);
t.unpack(o);
adjust_mac(share, t);
}
}
}
template class Input<gf2n>;
template class Input<gfp>;

43
Processor/Input.h Normal file
View File

@@ -0,0 +1,43 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Input.h
*
*/
#ifndef PROCESSOR_INPUT_H_
#define PROCESSOR_INPUT_H_
#include <vector>
using namespace std;
#include "Math/Share.h"
#include "Auth/MAC_Check.h"
#include "Processor/Buffer.h"
#include "Tools/time-func.h"
class Processor;
template<class T>
class Input
{
Processor& proc;
MAC_Check<T>& MC;
vector< Share<T> > shares;
Buffer<T,T> buffer;
Timer timer;
void adjust_mac(Share<T>& share, T& value);
public:
int values_input;
Input(Processor& proc, MAC_Check<T>& mc);
~Input();
void start(int player, int n_inputs);
void stop(int player, vector<int> targets);
};
#endif /* PROCESSOR_INPUT_H_ */

42
Processor/InputTuple.h Normal file
View File

@@ -0,0 +1,42 @@
// (C) 2016 University of Bristol. See License.txt
/*
* InputTuple.h
*
*/
#ifndef PROCESSOR_INPUTTUPLE_H_
#define PROCESSOR_INPUTTUPLE_H_
template <class T>
struct InputTuple
{
Share<T> share;
T value;
static int size()
{ return Share<T>::size() + T::size(); }
static string type_string()
{ return T::type_string(); }
void assign(const char* buffer)
{
share.assign(buffer);
value.assign(buffer + Share<T>::size());
}
};
template <class T>
struct RefInputTuple
{
Share<T>& share;
T& value;
RefInputTuple(Share<T>& share, T& value) : share(share), value(value) {}
void operator=(InputTuple<T>& other) { share = other.share; value = other.value; }
};
#endif /* PROCESSOR_INPUTTUPLE_H_ */

1558
Processor/Instruction.cpp Normal file

File diff suppressed because it is too large Load Diff

311
Processor/Instruction.h Normal file
View File

@@ -0,0 +1,311 @@
// (C) 2016 University of Bristol. See License.txt
#ifndef _Instruction
#define _Instruction
/* Class to read and decode an instruction
*/
#include <iostream>
#include <fstream>
#include <vector>
using namespace std;
#include "Processor/Memory.h"
#include "Processor/Data_Files.h"
#include "Networking/Player.h"
#include "Math/Integer.h"
#include "Auth/MAC_Check.h"
class Machine;
class Processor;
/*
* Opcode constants
*
* Whenever these are changed the corresponding dict in Compiler/instructions.py
* MUST also be changed. (+ the documentation)
*/
enum
{
// Load/store
LDI = 0x1,
LDSI = 0x2,
LDMC = 0x3,
LDMS = 0x4,
STMC = 0x5,
STMS = 0x6,
LDMCI = 0x7,
LDMSI = 0x8,
STMCI = 0x9,
STMSI = 0xA,
MOVC = 0xB,
MOVS = 0xC,
PROTECTMEMS = 0xD,
PROTECTMEMC = 0xE,
PROTECTMEMINT = 0xF,
LDMINT = 0xCA,
STMINT = 0xCB,
LDMINTI = 0xCC,
STMINTI = 0xCD,
PUSHINT = 0xCE,
POPINT = 0xCF,
MOVINT = 0xD0,
// Machine
LDTN = 0x10,
LDARG = 0x11,
REQBL = 0x12,
STARG = 0x13,
TIME = 0x14,
START = 0x15,
STOP = 0x16,
USE = 0x17,
USE_INP = 0x18,
RUN_TAPE = 0x19,
JOIN_TAPE = 0x1A,
CRASH = 0x1B,
USE_PREP = 0x1C,
// Addition
ADDC = 0x20,
ADDS = 0x21,
ADDM = 0x22,
ADDCI = 0x23,
ADDSI = 0x24,
SUBC = 0x25,
SUBS = 0x26,
SUBML = 0x27,
SUBMR = 0x28,
SUBCI = 0x29,
SUBSI = 0x2A,
SUBCFI = 0x2B,
SUBSFI = 0x2C,
// Multiplication/division/other arithmetic
MULC = 0x30,
MULM = 0x31,
MULCI = 0x32,
MULSI = 0x33,
DIVC = 0x34,
DIVCI = 0x35,
MODC = 0x36,
MODCI = 0x37,
LEGENDREC = 0x38,
// Open
STARTOPEN = 0xA0,
STOPOPEN = 0xA1,
// Data access
TRIPLE = 0x50,
BIT = 0x51,
SQUARE = 0x52,
INV = 0x53,
INPUTMASK = 0x56,
PREP = 0x57,
// Input
INPUT = 0x60,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
READSOCKETS = 0x64,
WRITESOCKETC = 0x65,
WRITESOCKETS = 0x66,
OPENSOCKET = 0x67,
CLOSESOCKET = 0x68,
// Bitwise logic
ANDC = 0x70,
XORC = 0x71,
ORC = 0x72,
ANDCI = 0x73,
XORCI = 0x74,
ORCI = 0x75,
NOTC = 0x76,
// Bitwise shifts
SHLC = 0x80,
SHRC = 0x81,
SHLCI = 0x82,
SHRCI = 0x83,
// Branching and comparison
JMP = 0x90,
JMPNZ = 0x91,
JMPEQZ = 0x92,
EQZC = 0x93,
LTZC = 0x94,
LTC = 0x95,
GTC = 0x96,
EQC = 0x97,
JMPI = 0x98,
// Integers
LDINT = 0x9A,
ADDINT = 0x9B,
SUBINT = 0x9C,
MULINT = 0x9D,
DIVINT = 0x9E,
// Conversion
CONVINT = 0xC0,
CONVMODP = 0xC1,
// IO
PRINTMEM = 0xB0,
PRINTREG = 0XB1,
RAND = 0xB2,
PRINTREGPLAIN = 0xB3,
PRINTCHR = 0xB4,
PRINTSTR = 0xB5,
PUBINPUT = 0xB6,
RAWOUTPUT = 0xB7,
STARTPRIVATEOUTPUT = 0xB8,
STOPPRIVATEOUTPUT = 0xB9,
PRINTCHRINT = 0xBA,
PRINTSTRINT = 0xBB,
// GF(2^n) versions
// Load/store
GLDI = 0x101,
GLDSI = 0x102,
GLDMC = 0x103,
GLDMS = 0x104,
GSTMC = 0x105,
GSTMS = 0x106,
GLDMCI = 0x107,
GLDMSI = 0x108,
GSTMCI = 0x109,
GSTMSI = 0x10A,
GMOVC = 0x10B,
GMOVS = 0x10C,
GPROTECTMEMS = 0x10D,
GPROTECTMEMC = 0x10E,
// Machine
GREQBL = 0x112,
GUSE_PREP = 0x11C,
// Addition
GADDC = 0x120,
GADDS = 0x121,
GADDM = 0x122,
GADDCI = 0x123,
GADDSI = 0x124,
GSUBC = 0x125,
GSUBS = 0x126,
GSUBML = 0x127,
GSUBMR = 0x128,
GSUBCI = 0x129,
GSUBSI = 0x12A,
GSUBCFI = 0x12B,
GSUBSFI = 0x12C,
// Multiplication/division
GMULC = 0x130,
GMULM = 0x131,
GMULCI = 0x132,
GMULSI = 0x133,
GDIVC = 0x134,
GDIVCI = 0x135,
GMULBITC = 0x136,
GMULBITM = 0x137,
// Open
GSTARTOPEN = 0x1A0,
GSTOPOPEN = 0x1A1,
// Data access
GTRIPLE = 0x150,
GBIT = 0x151,
GSQUARE = 0x152,
GINV = 0x153,
GBITTRIPLE = 0x154,
GBITGF2NTRIPLE = 0x155,
GINPUTMASK = 0x156,
GPREP = 0x157,
// Input
GINPUT = 0x160,
GSTARTINPUT = 0x161,
GSTOPINPUT = 0x162,
GREADSOCKETS = 0x164,
GWRITESOCKETS = 0x166,
// Bitwise logic
GANDC = 0x170,
GXORC = 0x171,
GORC = 0x172,
GANDCI = 0x173,
GXORCI = 0x174,
GORCI = 0x175,
GNOTC = 0x176,
// Bitwise shifts
GSHLCI = 0x182,
GSHRCI = 0x183,
GBITDEC = 0x184,
GBITCOM = 0x185,
// Conversion
GCONVINT = 0x1C0,
GCONVGF2N = 0x1C1,
// IO
GPRINTMEM = 0x1B0,
GPRINTREG = 0X1B1,
GPRINTREGPLAIN = 0x1B3,
GRAWOUTPUT = 0x1B7,
GSTARTPRIVATEOUTPUT = 0x1B8,
GSTOPPRIVATEOUTPUT = 0x1B9,
};
// Register types
enum RegType {
MODP,
GF2N,
INT,
MAX_REG_TYPE,
NONE
};
enum SecrecyType {
SECRET,
CLEAR,
MAX_SECRECY_TYPE
};
struct TempVars {
gf2n ans2; Share<gf2n> Sans2;
gfp ansp; Share<gfp> Sansp;
bigint aa,aa2;
// INPUT and LDSI
gfp rrp,tp,tmpp;
gfp xip;
// GINPUT and GLDSI
gf2n rr2,t2,tmp2;
gf2n xi2;
};
class Instruction
{
int opcode; // The code
int size; // Vector size
int r[3]; // Three possible registers
unsigned int n; // Possible immediate value
vector<int> start; // Values for a start/stop open
public:
// Reads a single instruction from the istream
void parse(istream& s);
// Return whether usage is known
bool get_offline_data_usage(DataPositions& usage);
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
RegType get_reg_type() const;
bool is_direct_memory_access(SecrecyType sec_type) const;
// Returns the maximal register used
int get_max_reg(RegType reg_type) const;
// Returns the memory size used if applicable and known
int get_mem(RegType reg_type, SecrecyType sec_type) const;
friend ostream& operator<<(ostream& s,const Instruction& instr);
// Execute this instruction, updateing the processor and memory
// and streams pointing to the triples etc
void execute(Processor& Proc) const;
};
#endif

362
Processor/Machine.cpp Normal file
View File

@@ -0,0 +1,362 @@
// (C) 2016 University of Bristol. See License.txt
#include "Machine.h"
#include "Exceptions/Exceptions.h"
#include <sys/time.h>
#include "Math/Setup.h"
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <pthread.h>
using namespace std;
Machine::Machine(int my_number, int PortnumBase, string hostname,
string progname_str, string memtype, int lgp, int lg2, bool direct,
int opening_sum, bool parallel, bool receive_threads, int max_broadcast)
: my_number(my_number), nthreads(0), tn(0), numt(0), usage_unknown(false),
progname(progname_str), direct(direct), opening_sum(opening_sum), parallel(parallel),
receive_threads(receive_threads), max_broadcast(max_broadcast)
{
N.init(my_number,PortnumBase,hostname.c_str());
if (opening_sum < 2)
this->opening_sum = N.num_players();
if (max_broadcast < 2)
this->max_broadcast = N.num_players();
// Set up the fields
prep_dir_prefix = get_prep_dir(N.num_players(), lgp, lg2);
read_setup(prep_dir_prefix);
char filename[1024];
int nn;
sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number);
inpf.open(filename);
if (inpf.fail())
{
cerr << "Could not open MAC key file. Perhaps it needs to be generated?\n";
throw file_error(filename);
}
inpf >> nn;
if (nn!=N.num_players())
{ cerr << "KeyGen was last run with " << nn << " players." << endl;
cerr << " - You are running Online with " << N.num_players() << " players." << endl;
exit(1);
}
alphapi.input(inpf,true);
alpha2i.input(inpf,true);
cerr << "MAC Key p = " << alphapi << endl;
cerr << "MAC Key 2 = " << alpha2i << endl;
inpf.close();
// Initialize the global memory
if (memtype.compare("new")==0)
{sprintf(filename, "Player-Data/Player-Memory-P%d", my_number);
ifstream memfile(filename);
if (memfile.fail()) { throw file_error(filename); }
Load_Memory(M2,memfile);
Load_Memory(Mp,memfile);
Load_Memory(Mi,memfile);
memfile.close();
}
else if (memtype.compare("old")==0)
{
sprintf(filename, "Player-Data/Memory-P%d", my_number);
inpf.open(filename,ios::in | ios::binary);
if (inpf.fail()) { throw file_error(); }
inpf >> M2 >> Mp >> Mi;
inpf.close();
}
else if (!(memtype.compare("empty")==0))
{ cerr << "Invalid memory argument" << endl;
exit(1);
}
sprintf(filename, "Programs/Schedules/%s.sch",progname.c_str());
cerr << "Opening file " << filename << endl;
inpf.open(filename);
if (inpf.fail()) { throw file_error("Missing '" + string(filename) + "'. Did you compile '" + progname + "'?"); }
int nprogs;
inpf >> nthreads;
inpf >> nprogs;
// Keep record of used offline data
pos.set_num_players(N.num_players());
cerr << "Number of threads I will run in parallel = " << nthreads << endl;
cerr << "Number of program sequences I need to load = " << nprogs << endl;
// Load in the programs
progs.resize(nprogs,N.num_players());
char threadname[1024];
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
sprintf(filename,"Programs/Bytecode/%s.bc",threadname);
cerr << "Loading program " << i << " from " << filename << endl;
ifstream pinp(filename);
if (pinp.fail()) { throw file_error(filename); }
progs[i].parse(pinp);
pinp.close();
if (progs[i].direct_mem2_s() > M2.size_s())
{
cerr << threadname << " needs more secret mod2 memory, resizing to "
<< progs[i].direct_mem2_s() << endl;
M2.resize_s(progs[i].direct_mem2_s());
}
if (progs[i].direct_memp_s() > Mp.size_s())
{
cerr << threadname << " needs more secret modp memory, resizing to "
<< progs[i].direct_memp_s() << endl;
Mp.resize_s(progs[i].direct_memp_s());
}
if (progs[i].direct_mem2_c() > M2.size_c())
{
cerr << threadname << " needs more clear mod2 memory, resizing to "
<< progs[i].direct_mem2_c() << endl;
M2.resize_c(progs[i].direct_mem2_c());
}
if (progs[i].direct_memp_c() > Mp.size_c())
{
cerr << threadname << " needs more clear modp memory, resizing to "
<< progs[i].direct_memp_c() << endl;
Mp.resize_c(progs[i].direct_memp_c());
}
if (progs[i].direct_memi_c() > Mi.size_c())
{
cerr << threadname << " needs more clear integer memory, resizing to "
<< progs[i].direct_memi_c() << endl;
Mi.resize_c(progs[i].direct_memi_c());
}
}
progs[0].print_offline_cost();
/* Set up the threads */
tinfo.resize(nthreads);
threads.resize(nthreads);
t_mutex.resize(nthreads);
client_ready.resize(nthreads);
server_ready.resize(nthreads);
join_timer.resize(nthreads);
for (int i=0; i<nthreads; i++)
{ pthread_mutex_init(&t_mutex[i],NULL);
pthread_cond_init(&client_ready[i],NULL);
pthread_cond_init(&server_ready[i],NULL);
tinfo[i].thread_num=i;
tinfo[i].Nms=&N;
tinfo[i].alphapi=&alphapi;
tinfo[i].alpha2i=&alpha2i;
tinfo[i].prognum=-2; // Dont do anything until we are ready
tinfo[i].finished=true;
tinfo[i].ready=false;
tinfo[i].machine=this;
// lock for synchronization
pthread_mutex_lock(&t_mutex[i]);
pthread_create(&threads[i],NULL,Main_Func,&tinfo[i]);
}
// synchronize with clients before starting timer
for (int i=0; i<nthreads; i++)
{
while (!tinfo[i].ready)
{
cerr << "Waiting for thread " << i << " to be ready" << endl;
pthread_cond_wait(&client_ready[i],&t_mutex[i]);
}
pthread_mutex_unlock(&t_mutex[i]);
}
}
DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number)
{
pthread_mutex_lock(&t_mutex[thread_number]);
tinfo[thread_number].prognum=tape_number;
tinfo[thread_number].arg=arg;
tinfo[thread_number].pos=pos;
tinfo[thread_number].finished=false;
//printf("Send signal to run program %d in thread %d\n",tape_number,thread_number);
pthread_cond_signal(&server_ready[thread_number]);
pthread_mutex_unlock(&t_mutex[thread_number]);
//printf("Running line %d\n",exec);
if (progs[tape_number].usage_unknown())
{ // only one thread allowed
if (numt>1)
{ cerr << "Line " << line_number << " has " <<
numt << " threads but tape " << tape_number <<
" has unknown offline data usage" << endl;
throw invalid_program();
}
else if (line_number == -1)
{
cerr << "Internally called tape " << tape_number <<
" has unknown offline data usage" << endl;
throw invalid_program();
}
usage_unknown = true;
return DataPositions(N.num_players());
}
else
{
// Bits, Triples, Squares, and Inverses skipping
return progs[tape_number].get_offline_data_used();
}
}
void Machine::join_tape(int i)
{
join_timer[i].start();
pthread_mutex_lock(&t_mutex[i]);
//printf("Waiting for client to terminate\n");
if ((tinfo[i].finished)==false)
{ pthread_cond_wait(&client_ready[i],&t_mutex[i]); }
pthread_mutex_unlock(&t_mutex[i]);
join_timer[i].stop();
}
void Machine::run()
{
Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID);
proc_timer.start();
timer[0].start();
bool flag=true;
usage_unknown=false;
int exec=0;
while (flag)
{ inpf >> numt;
if (numt==0)
{ flag=false; }
else
{ for (int i=0; i<numt; i++)
{
// Now load up data
inpf >> tn;
// Cope with passing an integer parameter to a tape
int arg;
if (inpf.get() == ':')
inpf >> arg;
else
arg = 0;
//cerr << "Run scheduled tape " << tn << " in thread " << i << endl;
pos.increase(run_tape(i, tn, arg, exec));
}
// Make sure all terminate before we continue
for (int i=0; i<numt; i++)
{ join_tape(i);
}
if (usage_unknown)
{ // synchronize files
pos = tinfo[0].pos;
usage_unknown = false;
}
//printf("Finished running line %d\n",exec);
exec++;
}
}
char compiler[1000];
inpf.get();
inpf.getline(compiler, 1000);
if (compiler[0] != 0)
cerr << "Compiler: " << compiler << endl;
inpf.close();
finish_timer.start();
// Tell all C-threads to stop
for (int i=0; i<nthreads; i++)
{ pthread_mutex_lock(&t_mutex[i]);
//printf("Send kill signal to client\n");
tinfo[i].prognum=-1;
tinfo[i].ready = false;
pthread_cond_signal(&server_ready[i]);
pthread_mutex_unlock(&t_mutex[i]);
}
cerr << "Waiting for all clients to finish" << endl;
// Wait until all clients have signed out
for (int i=0; i<nthreads; i++)
{
pthread_mutex_lock(&t_mutex[i]);
tinfo[i].ready = true;
pthread_cond_signal(&server_ready[i]);
pthread_mutex_unlock(&t_mutex[i]);
pthread_join(threads[i],NULL);
pthread_mutex_destroy(&t_mutex[i]);
pthread_cond_destroy(&client_ready[i]);
pthread_cond_destroy(&server_ready[i]);
}
finish_timer.stop();
for (unsigned int i = 0; i < join_timer.size(); i++)
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
cerr << "Process timer: " << proc_timer.elapsed() << endl;
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
timer.erase(0);
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
if (opening_sum < N.num_players() && !direct)
cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl;
else
cerr << "Summed all shares at once" << endl;
if (max_broadcast < N.num_players() && !direct)
cerr << "Send to at most " << max_broadcast << " parties at once" << endl;
else
cerr << "Full broadcast" << endl;
// Reduce memory size to speed up
int max_size = 1 << 20;
if (M2.size_s() > max_size)
M2.resize_s(max_size);
if (Mp.size_s() > max_size)
Mp.resize_s(max_size);
// Write out the memory to use next time
char filename[1024];
sprintf(filename,"Player-Data/Memory-P%d",my_number);
ofstream outf(filename,ios::out | ios::binary);
outf << M2 << Mp << Mi;
outf.close();
extern unsigned long long sent_amount, sent_counter;
cerr << "Data sent = " << sent_amount << " bytes in "
<< sent_counter << " calls,";
cerr << sent_amount / sent_counter / N.num_players()
<< " bytes per call" << endl;
for (int dtype = 0; dtype < N_DTYPE; dtype++)
{
cerr << "Num " << Data_Files::dtype_names[dtype] << "\t=";
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
cerr << " " << pos.files[field_type][dtype];
cerr << endl;
}
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
cerr << "Num " << Data_Files::long_field_names[field_type] << " Inputs\t=";
for (int i = 0; i < N.num_players(); i++)
cerr << " " << pos.inputs[i][field_type];
cerr << endl;
}
cerr << "Total cost of program:" << endl;
pos.print_cost();
cerr << "End of prog" << endl;
}

83
Processor/Machine.h Normal file
View File

@@ -0,0 +1,83 @@
// (C) 2016 University of Bristol. See License.txt
/*
* Machine.h
*
*/
#ifndef MACHINE_H_
#define MACHINE_H_
#include "Processor/Memory.h"
#include "Processor/Program.h"
#include "Processor/Online-Thread.h"
#include "Processor/Data_Files.h"
#include "Math/gfp.h"
#include "Tools/time-func.h"
#include <vector>
#include <map>
using namespace std;
class Machine
{
/* The mutex's lock the C-threads and then only release
* then we an MPC thread is ready to run on the C-thread.
* Control is passed back to the main loop when the
* MPC thread releases the mutex
*/
vector<thread_info> tinfo;
vector<pthread_t> threads;
int my_number;
Names N;
gfp alphapi;
gf2n alpha2i;
int nthreads;
ifstream inpf;
// Keep record of used offline data
DataPositions pos;
int tn,numt;
bool usage_unknown;
public:
vector<pthread_mutex_t> t_mutex;
vector<pthread_cond_t> client_ready;
vector<pthread_cond_t> server_ready;
vector<Program> progs;
Memory<gf2n> M2;
Memory<gfp> Mp;
Memory<Integer> Mi;
std::map<int,Timer> timer;
vector<Timer> join_timer;
Timer finish_timer;
string prep_dir_prefix;
string progname;
bool direct;
int opening_sum;
bool parallel;
bool receive_threads;
int max_broadcast;
Machine(int my_number, int PortnumBase, string hostname, string progname,
string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel,
bool receive_threads, int max_broadcast);
DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number);
void join_tape(int thread_number);
void run();
};
#endif /* MACHINE_H_ */

147
Processor/Memory.cpp Normal file
View File

@@ -0,0 +1,147 @@
// (C) 2016 University of Bristol. See License.txt
#include "Processor/Memory.h"
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Integer.h"
#include <fstream>
#ifdef MEMPROTECT
template<class T>
void Memory<T>::protect_s(unsigned int start, unsigned int end)
{
protected_s.insert(pair<unsigned int,unsigned int>(start, end));
}
template<class T>
void Memory<T>::protect_c(unsigned int start, unsigned int end)
{
protected_c.insert(pair<unsigned int,unsigned int>(start, end));
}
template<class T>
bool Memory<T>::is_protected_s(unsigned int index)
{
for (set< pair<unsigned int,unsigned int> >::iterator it = protected_s.begin();
it != protected_s.end(); it++)
if (it->first <= index and it->second > index)
return true;
return false;
}
template<class T>
bool Memory<T>::is_protected_c(unsigned int index)
{
for (set< pair<unsigned int,unsigned int> >::iterator it = protected_c.begin();
it != protected_c.end(); it++)
if (it->first <= index and it->second > index)
return true;
return false;
}
#endif
template<class T>
ostream& operator<<(ostream& s,const Memory<T>& M)
{
s << M.MS.size() << endl;
s << M.MC.size() << endl;
#ifdef DEBUG
for (unsigned int i=0; i<M.MS.size(); i++)
{ M.MS[i].output(s,true); s << endl; }
s << endl;
for (unsigned int i=0; i<M.MC.size(); i++)
{ M.MC[i].output(s,true); s << endl; }
s << endl;
#else
for (unsigned int i=0; i<M.MS.size(); i++)
{ M.MS[i].output(s,false); }
for (unsigned int i=0; i<M.MC.size(); i++)
{ M.MC[i].output(s,false); }
#endif
return s;
}
template<class T>
istream& operator>>(istream& s,Memory<T>& M)
{
int len;
s >> len;
M.resize_s(len);
s >> len;
M.resize_c(len);
s.seekg(1, istream::cur);
for (unsigned int i=0; i<M.MS.size(); i++)
{ M.MS[i].input(s,false); }
for (unsigned int i=0; i<M.MC.size(); i++)
{ M.MC[i].input(s,false); }
return s;
}
template<class T>
void Load_Memory(Memory<T>& M,ifstream& inpf)
{
int a;
T val;
Share<T> S;
inpf >> a;
M.resize_s(a);
inpf >> a;
M.resize_c(a);
cerr << "Reading Clear Memory" << endl;
// Read clear memory
inpf >> a;
val.input(inpf,true);
while (a!=-1)
{ M.write_C(a,val);
inpf >> a;
val.input(inpf,true);
}
cerr << "Reading Shared Memory" << endl;
// Read shared memory
inpf >> a;
S.input(inpf,true);
while (a!=-1)
{ M.write_S(a,S);
inpf >> a;
S.input(inpf,true);
}
}
template class Memory<gfp>;
template class Memory<gf2n>;
template class Memory<Integer>;
template istream& operator>>(istream& s,Memory<gfp>& M);
template istream& operator>>(istream& s,Memory<gf2n>& M);
template istream& operator>>(istream& s,Memory<Integer>& M);
template ostream& operator<<(ostream& s,const Memory<gfp>& M);
template ostream& operator<<(ostream& s,const Memory<gf2n>& M);
template ostream& operator<<(ostream& s,const Memory<Integer>& M);
template void Load_Memory(Memory<gfp>& M,ifstream& inpf);
template void Load_Memory(Memory<gf2n>& M,ifstream& inpf);
template void Load_Memory(Memory<Integer>& M,ifstream& inpf);
#ifdef USE_GF2N_LONG
template class Memory<gf2n_short>;
template istream& operator>>(istream& s,Memory<gf2n_short>& M);
template ostream& operator<<(ostream& s,const Memory<gf2n_short>& M);
template void Load_Memory(Memory<gf2n_short>& M,ifstream& inpf);
#endif

Some files were not shown because too many files have changed in this diff Show More