mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Initial release.
This commit is contained in:
72
.gitignore
vendored
Normal file
72
.gitignore
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# Offline data, runtime logs #
|
||||
##############################
|
||||
Player-Data/*
|
||||
Prep-Data/*
|
||||
logs/*
|
||||
Language-Definition/main.pdf
|
||||
|
||||
# Personal CONFIG file #
|
||||
##############################
|
||||
CONFIG.mine
|
||||
|
||||
# Compiled source #
|
||||
###################
|
||||
Programs/Bytecode/*
|
||||
Programs/Schedules/*
|
||||
Programs/Public-Input/*
|
||||
*.com
|
||||
*.class
|
||||
*.dll
|
||||
*.exe
|
||||
*.x
|
||||
*.o
|
||||
*.so
|
||||
*.pyc
|
||||
*.bc
|
||||
*.sch
|
||||
*.a
|
||||
|
||||
# Packages #
|
||||
############
|
||||
# it's better to unpack these files and commit the raw source
|
||||
# git has its own built in compression methods
|
||||
*.7z
|
||||
*.dmg
|
||||
*.gz
|
||||
*.iso
|
||||
*.jar
|
||||
*.rar
|
||||
*.tar
|
||||
*.zip
|
||||
|
||||
# Latex #
|
||||
#########
|
||||
*.aux
|
||||
*.lof
|
||||
*.log
|
||||
*.lot
|
||||
*.fls
|
||||
*.out
|
||||
*.toc
|
||||
*.fmt
|
||||
*.bbl
|
||||
*.bcf
|
||||
*.blg
|
||||
|
||||
|
||||
# Logs and databases #
|
||||
######################
|
||||
*.log
|
||||
*.sql
|
||||
*.sqlite
|
||||
|
||||
# OS generated files #
|
||||
######################
|
||||
*~
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "SimpleOT"]
|
||||
path = SimpleOT
|
||||
url = git@github.com:pascholl/SimpleOT.git
|
||||
509
Auth/MAC_Check.cpp
Normal file
509
Auth/MAC_Check.cpp
Normal file
@@ -0,0 +1,509 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Auth/MAC_Check.h"
|
||||
#include "Auth/Subroutines.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/int.h"
|
||||
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
enum mc_timer { SEND, RECV_ADD, BCAST, RECV_SUM, SEED, COMMIT, WAIT_SUMMER, RECV, SUM, SELECT, MAX_TIMER };
|
||||
const char* mc_timer_names[] = {
|
||||
"sending",
|
||||
"receiving and adding",
|
||||
"broadcasting",
|
||||
"receiving summed values",
|
||||
"random seed",
|
||||
"commit and open",
|
||||
"wait for summer thread",
|
||||
"receiving",
|
||||
"summing",
|
||||
"waiting for select()"
|
||||
};
|
||||
|
||||
|
||||
template<class T>
|
||||
MAC_Check<T>::MAC_Check(const T& ai, int opening_sum, int max_broadcast, int send_player) :
|
||||
base_player(send_player), opening_sum(opening_sum), max_broadcast(max_broadcast)
|
||||
{
|
||||
popen_cnt=0;
|
||||
alphai=ai;
|
||||
values_opened=0;
|
||||
timers.resize(MAX_TIMER);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
MAC_Check<T>::~MAC_Check()
|
||||
{
|
||||
for (unsigned int i = 0; i < timers.size(); i++)
|
||||
if (timers[i].elapsed() > 0)
|
||||
cerr << T::type_string() << " " << mc_timer_names[i] << ": "
|
||||
<< timers[i].elapsed() << endl;
|
||||
|
||||
for (unsigned int i = 0; i < player_timers.size(); i++)
|
||||
if (player_timers[i].elapsed() > 0)
|
||||
cerr << T::type_string() << " waiting for " << i << ": "
|
||||
<< player_timers[i].elapsed() << endl;
|
||||
}
|
||||
|
||||
template<class T, int t>
|
||||
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check<T>& MC)
|
||||
{
|
||||
MC.player_timers.resize(P.num_players());
|
||||
vector<octetStream>& oss = MC.oss;
|
||||
oss.resize(P.num_players());
|
||||
vector<int> senders;
|
||||
senders.reserve(P.num_players());
|
||||
|
||||
for (int relative_sender = positive_modulo(P.my_num() - send_player, P.num_players()) + sum_players;
|
||||
relative_sender < last_sum_players; relative_sender += sum_players)
|
||||
{
|
||||
int sender = positive_modulo(send_player + relative_sender, P.num_players());
|
||||
senders.push_back(sender);
|
||||
}
|
||||
|
||||
for (int j = 0; j < (int)senders.size(); j++)
|
||||
P.request_receive(senders[j], oss[j]);
|
||||
|
||||
for (int j = 0; j < (int)senders.size(); j++)
|
||||
{
|
||||
int sender = senders[j];
|
||||
MC.player_timers[sender].start();
|
||||
P.wait_receive(sender, oss[j], true);
|
||||
MC.player_timers[sender].stop();
|
||||
if ((unsigned)oss[j].get_length() < values.size() * T::size())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Not enough information received, expected "
|
||||
<< values.size() * T::size() << " bytes, got "
|
||||
<< oss[j].get_length();
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
MC.timers[SUM].start();
|
||||
for (unsigned int i=0; i<values.size(); i++)
|
||||
{
|
||||
values[i].template add<t>(oss[j].consume(T::size()));
|
||||
}
|
||||
MC.timers[SUM].stop();
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
AddToMacs(S);
|
||||
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
{ values[i]=S[i].get_share(); }
|
||||
|
||||
os.reset_write_head();
|
||||
int sum_players = P.num_players();
|
||||
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
|
||||
while (true)
|
||||
{
|
||||
int last_sum_players = sum_players;
|
||||
sum_players = (sum_players - 2 + opening_sum) / opening_sum;
|
||||
if (sum_players == 0)
|
||||
break;
|
||||
if (my_relative_num >= sum_players && my_relative_num < last_sum_players)
|
||||
{
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
{ values[i].pack(os); }
|
||||
int receiver = positive_modulo(base_player + my_relative_num % sum_players, P.num_players());
|
||||
timers[SEND].start();
|
||||
P.send_to(receiver,os,true);
|
||||
timers[SEND].stop();
|
||||
}
|
||||
|
||||
if (my_relative_num < sum_players)
|
||||
{
|
||||
timers[RECV_ADD].start();
|
||||
if (T::t() == 2)
|
||||
add_openings<T,2>(values, P, sum_players, last_sum_players, base_player, *this);
|
||||
else
|
||||
add_openings<T,0>(values, P, sum_players, last_sum_players, base_player, *this);
|
||||
timers[RECV_ADD].stop();
|
||||
}
|
||||
}
|
||||
|
||||
if (P.my_num() == base_player)
|
||||
{
|
||||
os.reset_write_head();
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
{ values[i].pack(os); }
|
||||
timers[BCAST].start();
|
||||
for (int i = 1; i < max_broadcast && i < P.num_players(); i++)
|
||||
{
|
||||
P.send_to((base_player + i) % P.num_players(), os, true);
|
||||
}
|
||||
timers[BCAST].stop();
|
||||
AddToValues(values);
|
||||
}
|
||||
else if (my_relative_num * max_broadcast < P.num_players())
|
||||
{
|
||||
int sender = (base_player + my_relative_num / max_broadcast) % P.num_players();
|
||||
ReceiveValues(values, P, sender);
|
||||
timers[BCAST].start();
|
||||
for (int i = 0; i < max_broadcast; i++)
|
||||
{
|
||||
int relative_receiver = (my_relative_num * max_broadcast + i);
|
||||
if (relative_receiver < P.num_players())
|
||||
{
|
||||
int receiver = (base_player + relative_receiver) % P.num_players();
|
||||
P.send_to(receiver, os, true);
|
||||
}
|
||||
}
|
||||
timers[BCAST].stop();
|
||||
}
|
||||
|
||||
values_opened += S.size();
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
S.size();
|
||||
int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players());
|
||||
if (my_relative_num * max_broadcast >= P.num_players())
|
||||
{
|
||||
int sender = (base_player + my_relative_num / max_broadcast) % P.num_players();
|
||||
ReceiveValues(values, P, sender);
|
||||
}
|
||||
else
|
||||
GetValues(values);
|
||||
|
||||
popen_cnt += values.size();
|
||||
CheckIfNeeded(P);
|
||||
|
||||
/* not compatible with continuous communication
|
||||
send_player++;
|
||||
if (send_player==P.num_players())
|
||||
{ send_player=0; }
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::AddToMacs(const vector<Share<T> >& shares)
|
||||
{
|
||||
for (unsigned int i = 0; i < shares.size(); i++)
|
||||
macs.push_back(shares[i].get_mac());
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::AddToValues(vector<T>& values)
|
||||
{
|
||||
vals.insert(vals.end(), values.begin(), values.end());
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::ReceiveValues(vector<T>& values, const Player& P, int sender)
|
||||
{
|
||||
timers[RECV_SUM].start();
|
||||
P.receive_player(sender, os, true);
|
||||
timers[RECV_SUM].stop();
|
||||
for (unsigned int i = 0; i < values.size(); i++)
|
||||
values[i].unpack(os);
|
||||
AddToValues(values);
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::GetValues(vector<T>& values)
|
||||
{
|
||||
int size = values.size();
|
||||
if (popen_cnt + size > int(vals.size()))
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "wanted " << values.size() << " values from " << popen_cnt << ", only " << vals.size() << " in store";
|
||||
throw out_of_range(ss.str());
|
||||
}
|
||||
values.clear();
|
||||
typename vector<T>::iterator first = vals.begin() + popen_cnt;
|
||||
values.insert(values.end(), first, first + size);
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::CheckIfNeeded(const Player& P)
|
||||
{
|
||||
if (WaitingForCheck() >= POPEN_MAX)
|
||||
Check(P);
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
void MAC_Check<T>::AddToCheck(const T& mac, const T& value, const Player& P)
|
||||
{
|
||||
CheckIfNeeded(P);
|
||||
macs.push_back(mac);
|
||||
vals.push_back(value);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
void MAC_Check<T>::Check(const Player& P)
|
||||
{
|
||||
if (WaitingForCheck() == 0)
|
||||
return;
|
||||
|
||||
//cerr << "In MAC Check : " << popen_cnt << endl;
|
||||
octet seed[SEED_SIZE];
|
||||
timers[SEED].start();
|
||||
Create_Random_Seed(seed,P,SEED_SIZE);
|
||||
timers[SEED].stop();
|
||||
PRNG G;
|
||||
G.SetSeed(seed);
|
||||
|
||||
Share<T> sj;
|
||||
T a,gami,h,temp;
|
||||
a.assign_zero();
|
||||
gami.assign_zero();
|
||||
vector<T> tau(P.num_players());
|
||||
for (int i=0; i<popen_cnt; i++)
|
||||
{ h.almost_randomize(G);
|
||||
temp.mul(h,vals[i]);
|
||||
a.add(a,temp);
|
||||
|
||||
temp.mul(h,macs[i]);
|
||||
gami.add(gami,temp);
|
||||
}
|
||||
vals.erase(vals.begin(), vals.begin() + popen_cnt);
|
||||
macs.erase(macs.begin(), macs.begin() + popen_cnt);
|
||||
temp.mul(alphai,a);
|
||||
tau[P.my_num()].sub(gami,temp);
|
||||
|
||||
//cerr << "\tCommit and Open" << endl;
|
||||
timers[COMMIT].start();
|
||||
Commit_And_Open(tau,P);
|
||||
timers[COMMIT].stop();
|
||||
|
||||
//cerr << "\tFinal Check" << endl;
|
||||
|
||||
T t;
|
||||
t.assign_zero();
|
||||
for (int i=0; i<P.num_players(); i++)
|
||||
{ t.add(t,tau[i]); }
|
||||
if (!t.is_zero()) { throw mac_fail(); }
|
||||
|
||||
popen_cnt=0;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int mc_base_id(int function_id, int thread_num)
|
||||
{
|
||||
return (function_id << 28) + ((T::field_type() + 1) << 24) + (thread_num << 16);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Separate_MAC_Check<T>::Separate_MAC_Check(const T& ai, Names& Nms,
|
||||
int thread_num, int opening_sum, int max_broadcast, int send_player) :
|
||||
MAC_Check<T>(ai, opening_sum, max_broadcast, send_player),
|
||||
check_player(Nms, mc_base_id<T>(1, thread_num))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Separate_MAC_Check<T>::Check(const Player& P)
|
||||
{
|
||||
P.my_num();
|
||||
MAC_Check<T>::Check(check_player);
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void* run_summer_thread(void* summer)
|
||||
{
|
||||
((Summer<T>*) summer)->run();
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Parallel_MAC_Check<T>::Parallel_MAC_Check(const T& ai, Names& Nms,
|
||||
int thread_num, int opening_sum, int max_broadcast, int base_player) :
|
||||
Separate_MAC_Check<T>(ai, Nms, thread_num, opening_sum, max_broadcast, base_player),
|
||||
send_player(Nms, mc_base_id<T>(2, thread_num)),
|
||||
send_base_player(base_player)
|
||||
{
|
||||
int sum_players = Nms.num_players();
|
||||
Player* summer_send_player = &send_player;
|
||||
for (int i = 0; ; i++)
|
||||
{
|
||||
int last_sum_players = sum_players;
|
||||
sum_players = (sum_players - 2 + opening_sum) / opening_sum;
|
||||
int next_sum_players = (sum_players - 2 + opening_sum) / opening_sum;
|
||||
if (sum_players == 0)
|
||||
break;
|
||||
Player* summer_receive_player = summer_send_player;
|
||||
summer_send_player = new Player(Nms, mc_base_id<T>(3, thread_num));
|
||||
summers.push_back(new Summer<T>(sum_players, last_sum_players, next_sum_players,
|
||||
summer_send_player, summer_receive_player, *this));
|
||||
pthread_create(&(summers[i]->thread), 0, run_summer_thread<T>, summers[i]);
|
||||
}
|
||||
receive_player = summer_send_player;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Parallel_MAC_Check<T>::~Parallel_MAC_Check()
|
||||
{
|
||||
for (unsigned int i = 0; i < summers.size(); i++)
|
||||
{
|
||||
summers[i]->input_queue.stop();
|
||||
pthread_join(summers[i]->thread, 0);
|
||||
delete summers[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Parallel_MAC_Check<T>::POpen_Begin(vector<T>& values,
|
||||
const vector<Share<T> >& S, const Player& P)
|
||||
{
|
||||
values.size();
|
||||
this->AddToMacs(S);
|
||||
|
||||
int my_relative_num = positive_modulo(P.my_num() - send_base_player, P.num_players());
|
||||
int sum_players = (P.num_players() - 2 + this->opening_sum) / this->opening_sum;
|
||||
int receiver = positive_modulo(send_base_player + my_relative_num % sum_players, P.num_players());
|
||||
|
||||
// use queue rather sending to myself
|
||||
if (receiver == P.my_num())
|
||||
{
|
||||
for (unsigned int i = 0; i < S.size(); i++)
|
||||
values[i] = S[i].get_share();
|
||||
summers.front()->share_queue.push(values);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->os.reset_write_head();
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
S[i].get_share().pack(this->os);
|
||||
this->timers[SEND].start();
|
||||
send_player.send_to(receiver,this->os,true);
|
||||
this->timers[SEND].stop();
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < summers.size(); i++)
|
||||
summers[i]->input_queue.push(S.size());
|
||||
|
||||
this->values_opened += S.size();
|
||||
send_base_player = (send_base_player + 1) % send_player.num_players();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Parallel_MAC_Check<T>::POpen_End(vector<T>& values,
|
||||
const vector<Share<T> >& S, const Player& P)
|
||||
{
|
||||
int last_size = 0;
|
||||
this->timers[WAIT_SUMMER].start();
|
||||
summers.back()->output_queue.pop(last_size);
|
||||
this->timers[WAIT_SUMMER].stop();
|
||||
if (int(values.size()) != last_size)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "stopopen wants " << values.size() << " values, but I have " << last_size << endl;
|
||||
throw Processor_Error(ss.str().c_str());
|
||||
}
|
||||
|
||||
if (this->base_player == P.my_num())
|
||||
{
|
||||
value_queue.pop(values);
|
||||
if (int(values.size()) != last_size)
|
||||
throw Processor_Error("wrong number of local values");
|
||||
else
|
||||
this->AddToValues(values);
|
||||
}
|
||||
this->MAC_Check<T>::POpen_End(values, S, *receive_player);
|
||||
this->base_player = (this->base_player + 1) % send_player.num_players();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
Direct_MAC_Check<T>::Direct_MAC_Check(const T& ai, Names& Nms, int num) : Separate_MAC_Check<T>(ai, Nms, num) {
|
||||
open_counter = 0;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
Direct_MAC_Check<T>::~Direct_MAC_Check() {
|
||||
cerr << T::type_string() << " open counter: " << open_counter << endl;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Direct_MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
values.size();
|
||||
this->os.reset_write_head();
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
S[i].get_share().pack(this->os);
|
||||
this->timers[SEND].start();
|
||||
P.send_all(this->os,true);
|
||||
this->timers[SEND].stop();
|
||||
|
||||
this->AddToMacs(S);
|
||||
for (unsigned int i=0; i<S.size(); i++)
|
||||
this->vals.push_back(S[i].get_share());
|
||||
}
|
||||
|
||||
template<class T, int t>
|
||||
void direct_add_openings(vector<T>& values, const Player& P, vector<octetStream>& os)
|
||||
{
|
||||
for (unsigned int i=0; i<values.size(); i++)
|
||||
for (int j=0; j<P.num_players(); j++)
|
||||
if (j!=P.my_num())
|
||||
values[i].template add<t>(os[j].consume(T::size()));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Direct_MAC_Check<T>::POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P)
|
||||
{
|
||||
S.size();
|
||||
oss.resize(P.num_players());
|
||||
this->GetValues(values);
|
||||
|
||||
this->timers[RECV].start();
|
||||
|
||||
for (int j=0; j<P.num_players(); j++)
|
||||
if (j!=P.my_num())
|
||||
P.receive_player(j,oss[j],true);
|
||||
|
||||
this->timers[RECV].stop();
|
||||
open_counter++;
|
||||
|
||||
if (T::t() == 2)
|
||||
direct_add_openings<T,2>(values, P, oss);
|
||||
else
|
||||
direct_add_openings<T,0>(values, P, oss);
|
||||
|
||||
for (unsigned int i = 0; i < values.size(); i++)
|
||||
this->vals[this->popen_cnt+i] = values[i];
|
||||
|
||||
this->popen_cnt += values.size();
|
||||
this->CheckIfNeeded(P);
|
||||
}
|
||||
|
||||
template class MAC_Check<gfp>;
|
||||
template class Direct_MAC_Check<gfp>;
|
||||
template class Parallel_MAC_Check<gfp>;
|
||||
|
||||
template class MAC_Check<gf2n>;
|
||||
template class Direct_MAC_Check<gf2n>;
|
||||
template class Parallel_MAC_Check<gf2n>;
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template class MAC_Check<gf2n_short>;
|
||||
template class Direct_MAC_Check<gf2n_short>;
|
||||
template class Parallel_MAC_Check<gf2n_short>;
|
||||
#endif
|
||||
141
Auth/MAC_Check.h
Normal file
141
Auth/MAC_Check.h
Normal file
@@ -0,0 +1,141 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _MAC_Check
|
||||
#define _MAC_Check
|
||||
|
||||
/* Class for storing MAC Check data and doing the Check */
|
||||
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
using namespace std;
|
||||
|
||||
#include "Math/Share.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Networking/ServerSocket.h"
|
||||
#include "Auth/Summer.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
|
||||
/* The MAX number of things we will partially open before running
|
||||
* a MAC Check
|
||||
*
|
||||
* Keep this at much less than 1MB of data to be able to cope with
|
||||
* multi-threaded players
|
||||
*
|
||||
*/
|
||||
#define POPEN_MAX 1000000
|
||||
|
||||
|
||||
template<class T>
|
||||
class MAC_Check
|
||||
{
|
||||
protected:
|
||||
|
||||
/* POpen Data */
|
||||
int popen_cnt;
|
||||
vector<T> macs;
|
||||
vector<T> vals;
|
||||
int base_player;
|
||||
int opening_sum;
|
||||
int max_broadcast;
|
||||
octetStream os;
|
||||
|
||||
/* MAC Share */
|
||||
T alphai;
|
||||
|
||||
void AddToMacs(const vector< Share<T> >& shares);
|
||||
void AddToValues(vector<T>& values);
|
||||
void ReceiveValues(vector<T>& values, const Player& P, int sender);
|
||||
void GetValues(vector<T>& values);
|
||||
void CheckIfNeeded(const Player& P);
|
||||
int WaitingForCheck()
|
||||
{ return max(macs.size(), vals.size()); }
|
||||
|
||||
public:
|
||||
|
||||
int values_opened;
|
||||
vector<Timer> timers;
|
||||
vector<Timer> player_timers;
|
||||
vector<octetStream> oss;
|
||||
|
||||
MAC_Check(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0);
|
||||
virtual ~MAC_Check();
|
||||
|
||||
/* Run protocols to partially open data and check the MACs are
|
||||
* all OK.
|
||||
* - Implicit assume that the amount of data being sent does
|
||||
* not overload the OS
|
||||
* Begin and End expect the same arrays values and S passed to them
|
||||
* and they expect values to be of the same size as S.
|
||||
*/
|
||||
virtual void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
virtual void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
void AddToCheck(const T& mac, const T& value, const Player& P);
|
||||
virtual void Check(const Player& P);
|
||||
|
||||
int number() const { return values_opened; }
|
||||
|
||||
const T& get_alphai() const { return alphai; }
|
||||
};
|
||||
|
||||
|
||||
template<class T, int t>
|
||||
void add_openings(vector<T>& values, const Player& P, int sum_players, int last_sum_players, int send_player, MAC_Check<T>& MC);
|
||||
|
||||
|
||||
template<class T>
|
||||
class Separate_MAC_Check: public MAC_Check<T>
|
||||
{
|
||||
// Different channel for checks
|
||||
Player check_player;
|
||||
|
||||
protected:
|
||||
// No sense to expose this
|
||||
Separate_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0);
|
||||
virtual ~Separate_MAC_Check() {};
|
||||
|
||||
public:
|
||||
virtual void Check(const Player& P);
|
||||
};
|
||||
|
||||
|
||||
template<class T>
|
||||
class Parallel_MAC_Check: public Separate_MAC_Check<T>
|
||||
{
|
||||
// Different channel for every round
|
||||
Player send_player;
|
||||
// Managed by Summer
|
||||
Player* receive_player;
|
||||
|
||||
vector< Summer<T>* > summers;
|
||||
|
||||
int send_base_player;
|
||||
|
||||
WaitQueue< vector<T> > value_queue;
|
||||
|
||||
public:
|
||||
Parallel_MAC_Check(const T& ai, Names& Nms, int thread_num, int opening_sum=10, int max_broadcast=10, int send_player=0);
|
||||
virtual ~Parallel_MAC_Check();
|
||||
|
||||
virtual void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
virtual void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
|
||||
friend class Summer<T>;
|
||||
};
|
||||
|
||||
|
||||
template<class T>
|
||||
class Direct_MAC_Check: public Separate_MAC_Check<T>
|
||||
{
|
||||
int open_counter;
|
||||
vector<octetStream> oss;
|
||||
|
||||
public:
|
||||
Direct_MAC_Check(const T& ai, Names& Nms, int thread_num);
|
||||
~Direct_MAC_Check();
|
||||
|
||||
void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
|
||||
};
|
||||
|
||||
#endif
|
||||
235
Auth/Subroutines.cpp
Normal file
235
Auth/Subroutines.cpp
Normal file
@@ -0,0 +1,235 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Auth/Subroutines.h"
|
||||
|
||||
#include "Tools/random.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Tools/Commit.h"
|
||||
|
||||
/* To ease readability as I re-write this program the following conventions
|
||||
* will be used.
|
||||
* For a variable v index by a player i
|
||||
* Comm_v[i] is the commitment string for player i
|
||||
* Open_v[i] is the opening data for player i
|
||||
*/
|
||||
|
||||
|
||||
// Special version for octetStreams
|
||||
void Commit(vector< vector<octetStream> >& Comm_data,
|
||||
vector<octetStream>& Open_data,
|
||||
const vector< vector<octetStream> >& data,const Player& P,int num_runs)
|
||||
{
|
||||
int my_number=P.my_num();
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ Comm_data[i].resize(P.num_players());
|
||||
Commit(Comm_data[i][my_number],Open_data[i],data[i][my_number],my_number);
|
||||
P.Broadcast_Receive(Comm_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// Special version for octetStreams
|
||||
void Open(vector< vector<octetStream> >& data,
|
||||
const vector< vector<octetStream> >& Comm_data,
|
||||
const vector<octetStream>& My_Open_data,
|
||||
const Player& P,int num_runs,int dont)
|
||||
{
|
||||
int my_number=P.my_num();
|
||||
int num_players=P.num_players();
|
||||
vector<octetStream> Open_data(num_players);
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ if (i!=dont)
|
||||
{ Open_data[my_number]=My_Open_data[i];
|
||||
P.Broadcast_Receive(Open_data);
|
||||
for (int j=0; j<num_players; j++)
|
||||
{ if (j!=my_number)
|
||||
{ if (!Open(data[i][j],Comm_data[i][j],Open_data[j],j))
|
||||
{ throw invalid_commitment(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Open(vector< vector<octetStream> >& data,
|
||||
const vector< vector<octetStream> >& Comm_data,
|
||||
const vector<octetStream>& My_Open_data,
|
||||
const vector<int> open,
|
||||
const Player& P,int num_runs)
|
||||
{
|
||||
int my_number=P.my_num();
|
||||
int num_players=P.num_players();
|
||||
vector<octetStream> Open_data(num_players);
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ if (open[i]==1)
|
||||
{ Open_data[my_number]=My_Open_data[i];
|
||||
P.Broadcast_Receive(Open_data);
|
||||
for (int j=0; j<num_players; j++)
|
||||
{ if (j!=my_number)
|
||||
{ if (!Open(data[i][j],Comm_data[i][j],Open_data[j],j))
|
||||
{ throw invalid_commitment(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void Commit_To_Challenge(vector<unsigned int>& e,
|
||||
vector<octetStream>& Comm_e,vector<octetStream>& Open_e,
|
||||
const Player& P,int num_runs)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
e.resize(P.num_players());
|
||||
Comm_e.resize(P.num_players());
|
||||
Open_e.resize(P.num_players());
|
||||
|
||||
e[P.my_num()]=G.get_uint()%num_runs;
|
||||
octetStream ee; ee.store(e[P.my_num()]);
|
||||
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num());
|
||||
P.Broadcast_Receive(Comm_e);
|
||||
}
|
||||
|
||||
|
||||
int Open_Challenge(vector<unsigned int>& e,vector<octetStream>& Open_e,
|
||||
const vector<octetStream>& Comm_e,
|
||||
const Player& P,int num_runs)
|
||||
{
|
||||
// Now open the challenge commitments and determine which run was for real
|
||||
P.Broadcast_Receive(Open_e);
|
||||
|
||||
int challenge=0;
|
||||
octetStream ee;
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
{ if (i != P.my_num())
|
||||
{ if (!Open(ee,Comm_e[i],Open_e[i],i))
|
||||
{ throw invalid_commitment(); }
|
||||
ee.get(e[i]);
|
||||
}
|
||||
challenge+=e[i];
|
||||
}
|
||||
challenge = challenge % num_runs;
|
||||
|
||||
return challenge;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Create_Random(T& ans,const Player& P)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
vector<T> e(P.num_players());
|
||||
vector<octetStream> Comm_e(P.num_players());
|
||||
vector<octetStream> Open_e(P.num_players());
|
||||
|
||||
e[P.my_num()].randomize(G);
|
||||
octetStream ee;
|
||||
e[P.my_num()].pack(ee);
|
||||
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],ee,P.my_num());
|
||||
P.Broadcast_Receive(Comm_e);
|
||||
|
||||
P.Broadcast_Receive(Open_e);
|
||||
|
||||
ans.assign_zero();
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
{ if (i != P.my_num())
|
||||
{ if (!Open(ee,Comm_e[i],Open_e[i],i))
|
||||
{ throw invalid_commitment(); }
|
||||
e[i].unpack(ee);
|
||||
}
|
||||
ans.add(ans,e[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Create_Random_Seed(octet* seed,const Player& P,int len)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
vector<octetStream> e(P.num_players());
|
||||
vector<octetStream> Comm_e(P.num_players());
|
||||
vector<octetStream> Open_e(P.num_players());
|
||||
|
||||
G.get_octetStream(e[P.my_num()],len);
|
||||
Commit(Comm_e[P.my_num()],Open_e[P.my_num()],e[P.my_num()],P.my_num());
|
||||
P.Broadcast_Receive(Comm_e);
|
||||
|
||||
P.Broadcast_Receive(Open_e);
|
||||
|
||||
memset(seed,0,len*sizeof(octet));
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
{ if (i != P.my_num())
|
||||
{ if (!Open(e[i],Comm_e[i],Open_e[i],i))
|
||||
{ throw invalid_commitment(); }
|
||||
}
|
||||
for (int j=0; j<len; j++)
|
||||
{ seed[j]=seed[j]^(e[i].get_data()[j]); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Commit_And_Open(vector<T>& data,const Player& P)
|
||||
{
|
||||
vector<octetStream> Comm_data(P.num_players());
|
||||
vector<octetStream> Open_data(P.num_players());
|
||||
|
||||
octetStream ee;
|
||||
data[P.my_num()].pack(ee);
|
||||
Commit(Comm_data[P.my_num()],Open_data[P.my_num()],ee,P.my_num());
|
||||
P.Broadcast_Receive(Comm_data);
|
||||
|
||||
P.Broadcast_Receive(Open_data);
|
||||
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
{ if (i != P.my_num())
|
||||
{ if (!Open(ee,Comm_data[i],Open_data[i],i))
|
||||
{ throw invalid_commitment(); }
|
||||
data[i].unpack(ee);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Commit_To_Seeds(vector<PRNG>& G,
|
||||
vector< vector<octetStream> >& seeds,
|
||||
vector< vector<octetStream> >& Comm_seeds,
|
||||
vector<octetStream>& Open_seeds,
|
||||
const Player& P,int num_runs)
|
||||
{
|
||||
seeds.resize(num_runs);
|
||||
Comm_seeds.resize(num_runs);
|
||||
Open_seeds.resize(num_runs);
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ G[i].ReSeed();
|
||||
seeds[i].resize(P.num_players());
|
||||
Comm_seeds[i].resize(P.num_players());
|
||||
Open_seeds[i].resize(P.num_players());
|
||||
seeds[i][P.my_num()].reset_write_head();
|
||||
seeds[i][P.my_num()].append(G[i].get_seed(),SEED_SIZE);
|
||||
}
|
||||
Commit(Comm_seeds,Open_seeds,seeds,P,num_runs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template void Commit_And_Open(vector<gf2n>& data,const Player& P);
|
||||
template void Create_Random(gf2n& ans,const Player& P);
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template void Commit_And_Open(vector<gf2n_short>& data,const Player& P);
|
||||
template void Create_Random(gf2n_short& ans,const Player& P);
|
||||
#endif
|
||||
|
||||
template void Commit_And_Open(vector<gfp>& data,const Player& P);
|
||||
template void Create_Random(gfp& ans,const Player& P);
|
||||
140
Auth/Subroutines.h
Normal file
140
Auth/Subroutines.h
Normal file
@@ -0,0 +1,140 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Subroutines
|
||||
#define _Subroutines
|
||||
|
||||
/* Defines subroutines for use in both KeyGen and Offline phase
|
||||
* Mainly focused around commiting and decommitting to various
|
||||
* bits of data
|
||||
*/
|
||||
|
||||
#include "Tools/random.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Tools/Commit.h"
|
||||
|
||||
/* Run just the Open Protocol for data[i][j] of type octetStream
|
||||
* 0 <= i < num_runs
|
||||
* 0 <= j < num_players
|
||||
* On output data[i][j] contains all the data
|
||||
* If dont!=-1 then dont open this run
|
||||
*/
|
||||
void Open(vector< vector<octetStream> >& data,
|
||||
const vector< vector<octetStream> >& Comm_data,
|
||||
const vector<octetStream>& My_Open_data,
|
||||
const Player& P,int num_runs,int dont=-1);
|
||||
|
||||
/* This one takes a vector open which contains 0 and 1
|
||||
* If 1 then we open this value, otherwise we do not
|
||||
*/
|
||||
void Open(vector< vector<octetStream> >& data,
|
||||
const vector< vector<octetStream> >& Comm_data,
|
||||
const vector<octetStream>& My_Open_data,
|
||||
const vector<int> open,
|
||||
const Player& P,int num_runs);
|
||||
|
||||
|
||||
|
||||
|
||||
/* This runs the Commit and Open Protocol for data[i][j] of type T
|
||||
* 0 <= i < num_runs
|
||||
* 0 <= j < num_players
|
||||
* On input data[i][j] is only defined for j=my_number
|
||||
*/
|
||||
template<class T>
|
||||
void Commit_And_Open(vector< vector<T> >& data,const Player& P,int num_runs);
|
||||
|
||||
template<class T>
|
||||
void Commit_And_Open(vector<T>& data,const Player& P);
|
||||
|
||||
template<class T>
|
||||
void Transmit_Data(vector< vector<T> >& data,const Player& P,int num_runs);
|
||||
|
||||
/* Functions to Commit and Open a Challenge Value */
|
||||
void Commit_To_Challenge(vector<unsigned int>& e,
|
||||
vector<octetStream>& Comm_e,vector<octetStream>& Open_e,
|
||||
const Player& P,int num_runs);
|
||||
|
||||
int Open_Challenge(vector<unsigned int>& e,vector<octetStream>& Open_e,
|
||||
const vector<octetStream>& Comm_e,
|
||||
const Player& P,int num_runs);
|
||||
|
||||
|
||||
/* Function to create a shared random value for T=gfp/gf2n */
|
||||
template<class T>
|
||||
void Create_Random(T& ans,const Player& P);
|
||||
|
||||
/* Produce a random seed of length len */
|
||||
void Create_Random_Seed(octet* seed,const Player& P,int len);
|
||||
|
||||
|
||||
|
||||
/* Functions to Commit to Seed Values
|
||||
* This also initialises the PRNG's in G
|
||||
*/
|
||||
void Commit_To_Seeds(vector<PRNG>& G,
|
||||
vector< vector<octetStream> >& seeds,
|
||||
vector< vector<octetStream> >& Comm_seeds,
|
||||
vector<octetStream>& Open_seeds,
|
||||
const Player& P,int num_runs);
|
||||
|
||||
|
||||
|
||||
/* Run just the Commit Protocol for data[i][j] of type T
|
||||
* 0 <= i < num_runs
|
||||
* 0 <= j < num_players
|
||||
* On input data[i][j] is only defined for j=my_number
|
||||
*/
|
||||
template<class T>
|
||||
void Commit(vector< vector<octetStream> >& Comm_data,
|
||||
vector<octetStream>& Open_data,
|
||||
const vector< vector<T> >& data,const Player& P,int num_runs)
|
||||
{
|
||||
octetStream os;
|
||||
int my_number=P.my_num();
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ os.reset_write_head();
|
||||
data[i][my_number].pack(os);
|
||||
Comm_data[i].resize(P.num_players());
|
||||
Commit(Comm_data[i][my_number],Open_data[i],os,my_number);
|
||||
P.Broadcast_Receive(Comm_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Run just the Open Protocol for data[i][j] of type T
|
||||
* 0 <= i < num_runs
|
||||
* 0 <= j < num_players
|
||||
* On output data[i][j] contains all the data
|
||||
* If dont!=-1 then dont open this run
|
||||
*/
|
||||
template<class T>
|
||||
void Open(vector< vector<T> >& data,
|
||||
const vector< vector<octetStream> >& Comm_data,
|
||||
const vector<octetStream>& My_Open_data,
|
||||
const Player& P,int num_runs,int dont=-1)
|
||||
{
|
||||
octetStream os;
|
||||
int my_number=P.my_num();
|
||||
int num_players=P.num_players();
|
||||
vector<octetStream> Open_data(num_players);
|
||||
for (int i=0; i<num_runs; i++)
|
||||
{ if (i!=dont)
|
||||
{ Open_data[my_number]=My_Open_data[i];
|
||||
P.Broadcast_Receive(Open_data);
|
||||
for (int j=0; j<num_players; j++)
|
||||
{ if (j!=my_number)
|
||||
{ if (!Open(os,Comm_data[i][j],Open_data[j],j))
|
||||
{ throw invalid_commitment(); }
|
||||
os.reset_read_head();
|
||||
data[i][j].unpack(os);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
94
Auth/Summer.cpp
Normal file
94
Auth/Summer.cpp
Normal file
@@ -0,0 +1,94 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Summer.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Auth/Summer.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
#include "Tools/int.h"
|
||||
|
||||
template<class T>
|
||||
Summer<T>::Summer(int sum_players, int last_sum_players, int next_sum_players,
|
||||
Player* send_player, Player* receive_player, Parallel_MAC_Check<T>& MC) :
|
||||
sum_players(sum_players), last_sum_players(last_sum_players), next_sum_players(next_sum_players),
|
||||
base_player(0), MC(MC),send_player(send_player), receive_player(receive_player),
|
||||
thread(0), stop(false), size(0)
|
||||
{
|
||||
cout << "Setting up summation by " << sum_players << " players" << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Summer<T>::~Summer()
|
||||
{
|
||||
delete send_player;
|
||||
if (timer.elapsed())
|
||||
cout << T::type_string() << " summation by " << sum_players << " players: "
|
||||
<< timer.elapsed() << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Summer<T>::run()
|
||||
{
|
||||
octetStream os;
|
||||
|
||||
while (true)
|
||||
{
|
||||
int size = 0;
|
||||
if (!input_queue.pop(size))
|
||||
break;
|
||||
|
||||
int my_relative_num = positive_modulo(send_player->my_num() - base_player, send_player->num_players());
|
||||
if (my_relative_num < sum_players)
|
||||
{
|
||||
// first summer takes inputs from queue
|
||||
if (last_sum_players == send_player->num_players())
|
||||
share_queue.pop(values);
|
||||
else
|
||||
{
|
||||
values.resize(size);
|
||||
receive_player->receive_player(receive_player->my_num(),os,true);
|
||||
for (int i = 0; i < size; i++)
|
||||
values[i].unpack(os);
|
||||
}
|
||||
|
||||
timer.start();
|
||||
if (T::t() == 2)
|
||||
add_openings<T,2>(values, *receive_player, sum_players,
|
||||
last_sum_players, base_player, MC);
|
||||
else
|
||||
add_openings<T,0>(values, *receive_player, sum_players,
|
||||
last_sum_players, base_player, MC);
|
||||
timer.stop();
|
||||
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < size; i++)
|
||||
values[i].pack(os);
|
||||
|
||||
if (sum_players > 1)
|
||||
{
|
||||
int receiver = positive_modulo(base_player + my_relative_num % next_sum_players,
|
||||
send_player->num_players());
|
||||
send_player->send_to(receiver, os, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
send_player->send_all(os);
|
||||
MC.value_queue.push(values);
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_players == 1)
|
||||
output_queue.push(size);
|
||||
|
||||
base_player = (base_player + 1) % send_player->num_players();
|
||||
}
|
||||
}
|
||||
|
||||
template class Summer<gfp>;
|
||||
template class Summer<gf2n>;
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template class Summer<gf2n_short>;
|
||||
#endif
|
||||
47
Auth/Summer.h
Normal file
47
Auth/Summer.h
Normal file
@@ -0,0 +1,47 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Summer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OFFLINE_SUMMER_H_
|
||||
#define OFFLINE_SUMMER_H_
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
#include <pthread.h>
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
template<class T>
|
||||
class Parallel_MAC_Check;
|
||||
|
||||
template<class T>
|
||||
class Summer
|
||||
{
|
||||
int sum_players, last_sum_players, next_sum_players;
|
||||
int base_player;
|
||||
Parallel_MAC_Check<T>& MC;
|
||||
Player* send_player;
|
||||
Player* receive_player;
|
||||
Timer timer;
|
||||
|
||||
public:
|
||||
vector<T> values;
|
||||
|
||||
pthread_t thread;
|
||||
WaitQueue<int> input_queue, output_queue;
|
||||
bool stop;
|
||||
int size;
|
||||
WaitQueue< vector<T> > share_queue;
|
||||
|
||||
Summer(int sum_players, int last_sum_players, int next_sum_players,
|
||||
Player* send_player, Player* receive_player, Parallel_MAC_Check<T>& MC);
|
||||
~Summer();
|
||||
void run();
|
||||
};
|
||||
|
||||
#endif /* OFFLINE_SUMMER_H_ */
|
||||
171
Auth/fake-stuff.cpp
Normal file
171
Auth/fake-stuff.cpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Share.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
template<class T>
|
||||
void make_share(vector<Share<T> >& Sa,const T& a,int N,const T& key,PRNG& G)
|
||||
{
|
||||
T mac,x,y;
|
||||
mac.mul(a,key);
|
||||
Share<T> S;
|
||||
S.set_share(a);
|
||||
S.set_mac(mac);
|
||||
|
||||
for (int i=0; i<N-1; i++)
|
||||
{ x.randomize(G);
|
||||
y.randomize(G);
|
||||
Sa[i].set_share(x);
|
||||
Sa[i].set_mac(y);
|
||||
S.sub(S,Sa[i]);
|
||||
}
|
||||
Sa[N-1]=S;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key)
|
||||
{
|
||||
value.assign(0);
|
||||
mac.assign(0);
|
||||
|
||||
for (int i=0; i<N; i++)
|
||||
{
|
||||
value.add(Sa[i].get_share());
|
||||
mac.add(Sa[i].get_mac());
|
||||
}
|
||||
|
||||
T res;
|
||||
res.mul(value, key);
|
||||
if (!res.equal(mac))
|
||||
{
|
||||
cout << "Value: " << value << endl;
|
||||
cout << "Input MAC: " << mac << endl;
|
||||
cout << "Actual MAC: " << res << endl;
|
||||
cout << "MAC key: " << key << endl;
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
|
||||
template void make_share(vector<Share<gf2n> >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G);
|
||||
template void make_share(vector<Share<gfp> >& Sa,const gfp& a,int N,const gfp& key,PRNG& G);
|
||||
|
||||
template void check_share(vector<Share<gf2n> >& Sa,gf2n& value,gf2n& mac,int N,const gf2n& key);
|
||||
template void check_share(vector<Share<gfp> >& Sa,gfp& value,gfp& mac,int N,const gfp& key);
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template void make_share(vector<Share<gf2n_short> >& Sa,const gf2n_short& a,int N,const gf2n_short& key,PRNG& G);
|
||||
template void check_share(vector<Share<gf2n_short> >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key);
|
||||
#endif
|
||||
|
||||
// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40)
|
||||
void expand_byte(gf2n_short& a,int b)
|
||||
{
|
||||
gf2n_short x,xp;
|
||||
x.assign(32+1);
|
||||
xp.assign_one();
|
||||
a.assign_zero();
|
||||
|
||||
while (b!=0)
|
||||
{ if ((b&1)==1)
|
||||
{ a.add(a,xp); }
|
||||
xp.mul(x);
|
||||
b>>=1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Have previously worked out the linear equations we need to solve
|
||||
void collapse_byte(int& b,const gf2n_short& aa)
|
||||
{
|
||||
word w=aa.get();
|
||||
int e35=(w>>35)&1;
|
||||
int e30=(w>>30)&1;
|
||||
int e25=(w>>25)&1;
|
||||
int e20=(w>>20)&1;
|
||||
int e15=(w>>15)&1;
|
||||
int e10=(w>>10)&1;
|
||||
int e5=(w>>5)&1;
|
||||
int e0=w&1;
|
||||
int a[8];
|
||||
a[7]=e35;
|
||||
a[6]=e30^a[7];
|
||||
a[5]=e25^a[7];
|
||||
a[4]=e20^a[5]^a[6]^a[7];
|
||||
a[3]=e15^a[7];
|
||||
a[2]=e10^a[3]^a[6]^a[7];
|
||||
a[1]=e5^a[3]^a[5]^a[7];
|
||||
a[0]=e0^a[1]^a[2]^a[3]^a[4]^a[5]^a[6]^a[7];
|
||||
|
||||
b=0;
|
||||
for (int i=7; i>=0; i--)
|
||||
{ b=b<<1;
|
||||
b+=a[i];
|
||||
}
|
||||
}
|
||||
|
||||
void generate_keys(const string& directory, int nplayers)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
gf2n mac2;
|
||||
gfp macp;
|
||||
mac2.assign_zero();
|
||||
macp.assign_zero();
|
||||
|
||||
ofstream outf;
|
||||
|
||||
for (int i = 0; i < nplayers; i++)
|
||||
{
|
||||
stringstream filename;
|
||||
filename << directory << "Player-MAC-Keys-P" << i;
|
||||
mac2.randomize(G);
|
||||
macp.randomize(G);
|
||||
cout << "Writing to " << filename.str().c_str() << endl;
|
||||
outf.open(filename.str().c_str());
|
||||
outf << nplayers << endl;
|
||||
macp.output(outf,true);
|
||||
outf << " ";
|
||||
mac2.output(outf,true);
|
||||
outf << endl;
|
||||
outf.close();
|
||||
}
|
||||
}
|
||||
|
||||
void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers)
|
||||
{
|
||||
gfp sharep;
|
||||
gf2n share2;
|
||||
keyp.assign_zero();
|
||||
key2.assign_zero();
|
||||
int i, tmpN = 0;
|
||||
ifstream inpf;
|
||||
|
||||
for (i = 0; i < nplayers; i++)
|
||||
{
|
||||
stringstream filename;
|
||||
filename << directory << "Player-MAC-Keys-P" << i;
|
||||
inpf.open(filename.str().c_str());
|
||||
if (inpf.fail())
|
||||
{
|
||||
inpf.close();
|
||||
cout << "Error: No MAC key share found for player " << i << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
inpf >> tmpN; // not needed here
|
||||
sharep.input(inpf,true);
|
||||
share2.input(inpf,true);
|
||||
inpf.close();
|
||||
}
|
||||
std::cout << " Key " << i << "\t p: " << sharep << "\n\t 2: " << share2 << std::endl;
|
||||
keyp.add(sharep);
|
||||
key2.add(share2);
|
||||
}
|
||||
std::cout << "Final MAC keys :\t p: " << keyp << "\n\t\t 2: " << key2 << std::endl;
|
||||
}
|
||||
64
Auth/fake-stuff.h
Normal file
64
Auth/fake-stuff.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#ifndef _fake_stuff
|
||||
#define _fake_stuff
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Share.h"
|
||||
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
template<class T>
|
||||
void make_share(vector<Share<T> >& Sa,const T& a,int N,const T& key,PRNG& G);
|
||||
|
||||
template<class T>
|
||||
void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key);
|
||||
|
||||
void expand_byte(gf2n_short& a,int b);
|
||||
void collapse_byte(int& b,const gf2n_short& a);
|
||||
|
||||
// Generate MAC key shares
|
||||
void generate_keys(const string& directory, int nplayers);
|
||||
|
||||
// Read MAC key shares and compute keys
|
||||
void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers);
|
||||
|
||||
template <class T>
|
||||
class Files
|
||||
{
|
||||
public:
|
||||
ofstream* outf;
|
||||
int N;
|
||||
T key;
|
||||
PRNG G;
|
||||
Files(int N, const T& key, const string& prefix) : N(N), key(key)
|
||||
{
|
||||
outf = new ofstream[N];
|
||||
for (int i=0; i<N; i++)
|
||||
{
|
||||
stringstream filename;
|
||||
filename << prefix << "-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail())
|
||||
throw file_error(filename.str().c_str());
|
||||
}
|
||||
G.ReSeed();
|
||||
}
|
||||
~Files()
|
||||
{
|
||||
delete[] outf;
|
||||
}
|
||||
void output_shares(const T& a)
|
||||
{
|
||||
vector<Share<T> > Sa(N);
|
||||
make_share(Sa,a,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
Sa[j].output(outf[j],false);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
46
CONFIG
Normal file
46
CONFIG
Normal file
@@ -0,0 +1,46 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
ROOT = .
|
||||
|
||||
OPTIM= -O3
|
||||
#PROF = -pg
|
||||
#DEBUG = -DDEBUG
|
||||
#MEMPROTECT = -DMEMPROTECT
|
||||
|
||||
# set this to your preferred local storage directory
|
||||
PREP_DIR = '-DPREP_DIR="Player-Data/"'
|
||||
|
||||
# set for 128-bit GF(2^n) and/or OT preprocessing
|
||||
USE_GF2N_LONG = 0
|
||||
|
||||
# set to -march=<architecture> for optimization
|
||||
# AVX2 support (Haswell or later) changes the bit matrix transpose
|
||||
ARCH = -mtune=native
|
||||
|
||||
#use CONFIG.mine to overwrite DIR settings
|
||||
-include CONFIG.mine
|
||||
|
||||
ifeq ($(USE_GF2N_LONG),1)
|
||||
GF2N_LONG = -DUSE_GF2N_LONG
|
||||
endif
|
||||
|
||||
# MAX_MOD_SZ must be at least ceil(len(p)/len(word))+1
|
||||
# Default is 3, which suffices for 128-bit p
|
||||
# MOD = -DMAX_MOD_SZ=3
|
||||
|
||||
LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
LDLIBS := -lntl $(LDLIBS)
|
||||
endif
|
||||
|
||||
OS := $(shell uname -s)
|
||||
ifeq ($(OS), Linux)
|
||||
LDLIBS += -lrt
|
||||
endif
|
||||
|
||||
CXX = g++
|
||||
CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH)
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = g++
|
||||
|
||||
203
Check-Offline.cpp
Normal file
203
Check-Offline.cpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Check-Offline.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Share.h"
|
||||
#include "Auth/fake-stuff.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include "Math/Setup.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
string PREP_DATA_PREFIX;
|
||||
|
||||
template<class T>
|
||||
void check_mult_triples(const T& key,int N,vector<Data_Files*>& dataF,DataFieldType field_type)
|
||||
{
|
||||
T a,b,c,mac,res;
|
||||
vector<Share<T> > Sa(N),Sb(N),Sc(N);
|
||||
int n = 0;
|
||||
|
||||
try {
|
||||
while (!dataF[0]->eof<T>(DATA_TRIPLE))
|
||||
{
|
||||
for (int i = 0; i < N; i++)
|
||||
dataF[i]->get_three(field_type, DATA_TRIPLE, Sa[i], Sb[i], Sc[i]);
|
||||
check_share(Sa, a, mac, N, key);
|
||||
check_share(Sb, b, mac, N, key);
|
||||
check_share(Sc, c, mac, N, key);
|
||||
|
||||
res.mul(a, b);
|
||||
if (!res.equal(c))
|
||||
{
|
||||
cout << n << ": " << c << " != " << a << " * " << b << endl;
|
||||
throw bad_value();
|
||||
}
|
||||
n++;
|
||||
}
|
||||
|
||||
cout << n << " triples of type " << T::type_string() << endl;
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
cout << "Error with triples of type " << T::type_string() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void check_bits(const T& key,int N,vector<Data_Files*>& dataF,DataFieldType field_type)
|
||||
{
|
||||
T a,b,c,mac,res;
|
||||
vector<Share<T> > Sa(N),Sb(N),Sc(N);
|
||||
int n = 0;
|
||||
|
||||
while (!dataF[0]->eof<T>(DATA_BIT))
|
||||
{
|
||||
for (int i = 0; i < N; i++)
|
||||
dataF[i]->get_one(field_type, DATA_BIT, Sa[i]);
|
||||
check_share(Sa, a, mac, N, key);
|
||||
|
||||
if (!(a.is_zero() || a.is_one()))
|
||||
{
|
||||
cout << n << ": " << a << " neither 0 or 1" << endl;
|
||||
throw bad_value();
|
||||
}
|
||||
n++;
|
||||
}
|
||||
|
||||
cout << n << " bits of type " << T::type_string() << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void check_inputs(const T& key,int N,vector<Data_Files*>& dataF)
|
||||
{
|
||||
T a, mac, x;
|
||||
vector< Share<T> > Sa(N);
|
||||
|
||||
for (int player = 0; player < N; player++)
|
||||
{
|
||||
int n = 0;
|
||||
while (!dataF[0]->input_eof<T>(player))
|
||||
{
|
||||
for (int i = 0; i < N; i++)
|
||||
dataF[i]->get_input(Sa[i], x, player);
|
||||
check_share(Sa, a, mac, N, key);
|
||||
if (!a.equal(x))
|
||||
throw bad_value();
|
||||
n++;
|
||||
}
|
||||
cout << n << " input masks for player " << player << " of type " << T::type_string() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
gfp::init_field(gfp::pr(), false);
|
||||
|
||||
opt.syntax = "./Check-Offline.x <nparties> [OPTIONS]\n";
|
||||
opt.example = "./Check-Offline.x 3 -lgp 64 -lg2 128\n";
|
||||
|
||||
opt.add(
|
||||
"128", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(p) field (default: 128)", // Help description.
|
||||
"-lgp", // Flag token.
|
||||
"--lgp" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"40", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(2^n) field (default: 40)", // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Read GF(p) triples in Montgomery representation (default: not set)", // Help description.
|
||||
"-m", // Flag token.
|
||||
"--usemont" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
|
||||
string usage;
|
||||
int lgp, lg2, nparties;
|
||||
bool use_montgomery = false;
|
||||
opt.get("--lgp")->getInt(lgp);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
if (opt.isSet("--usemont"))
|
||||
use_montgomery = true;
|
||||
|
||||
if (opt.firstArgs.size() == 2)
|
||||
nparties = atoi(opt.firstArgs[1]->c_str());
|
||||
else if (opt.lastArgs.size() == 1)
|
||||
nparties = atoi(opt.lastArgs[0]->c_str());
|
||||
else
|
||||
{
|
||||
cerr << "ERROR: invalid number of arguments\n";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
PREP_DATA_PREFIX = get_prep_dir(nparties, lgp, lg2);
|
||||
read_setup(PREP_DATA_PREFIX);
|
||||
|
||||
if (!use_montgomery)
|
||||
{
|
||||
// no montgomery
|
||||
gfp::init_field(gfp::pr(), false);
|
||||
}
|
||||
|
||||
/* Find number players and MAC keys etc*/
|
||||
char filename[1024];
|
||||
gfp keyp,pp; keyp.assign_zero();
|
||||
gf2n key2,p2; key2.assign_zero();
|
||||
int N=1;
|
||||
ifstream inpf;
|
||||
for (int i= 0; i < nparties; i++)
|
||||
{
|
||||
sprintf(filename, (PREP_DATA_PREFIX + "Player-MAC-Keys-P%d").c_str(), i);
|
||||
inpf.open(filename);
|
||||
if (inpf.fail()) { throw file_error(filename); }
|
||||
inpf >> N;
|
||||
pp.input(inpf,true);
|
||||
p2.input(inpf,true);
|
||||
cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl;
|
||||
keyp.add(pp);
|
||||
key2.add(p2);
|
||||
inpf.close();
|
||||
}
|
||||
cout << "--------------\n";
|
||||
cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl;
|
||||
|
||||
vector<Data_Files*> dataF(N);
|
||||
for (int i = 0; i < N; i++)
|
||||
dataF[i] = new Data_Files(i, N, PREP_DATA_PREFIX);
|
||||
check_mult_triples(key2, N, dataF, DATA_GF2N);
|
||||
check_mult_triples(keyp, N, dataF, DATA_MODP);
|
||||
check_inputs(key2, N, dataF);
|
||||
check_inputs(keyp, N, dataF);
|
||||
check_bits(key2, N, dataF, DATA_GF2N);
|
||||
check_bits(keyp, N, dataF, DATA_MODP);
|
||||
for (int i = 0; i < N; i++)
|
||||
delete dataF[i];
|
||||
}
|
||||
31
Compiler/__init__.py
Normal file
31
Compiler/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import compilerLib, program, instructions, types, library, floatingpoint
|
||||
import inspect
|
||||
from config import *
|
||||
from compilerLib import run
|
||||
|
||||
|
||||
# add all instructions to the program VARS dictionary
|
||||
compilerLib.VARS = {}
|
||||
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
|
||||
|
||||
instr_classes += [t[1] for t in inspect.getmembers(types, inspect.isclass)\
|
||||
if t[1].__module__ == types.__name__]
|
||||
|
||||
instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\
|
||||
if t[1].__module__ == library.__name__]
|
||||
|
||||
for op in instr_classes:
|
||||
compilerLib.VARS[op.__name__] = op
|
||||
|
||||
# add open and input separately due to name conflict
|
||||
compilerLib.VARS['open'] = instructions.asm_open
|
||||
compilerLib.VARS['vopen'] = instructions.vasm_open
|
||||
compilerLib.VARS['gopen'] = instructions.gasm_open
|
||||
compilerLib.VARS['vgopen'] = instructions.vgasm_open
|
||||
compilerLib.VARS['input'] = instructions.asm_input
|
||||
compilerLib.VARS['ginput'] = instructions.gasm_input
|
||||
|
||||
compilerLib.VARS['comparison'] = comparison
|
||||
compilerLib.VARS['floatingpoint'] = floatingpoint
|
||||
653
Compiler/allocator.py
Normal file
653
Compiler/allocator.py
Normal file
@@ -0,0 +1,653 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import itertools, time
|
||||
from collections import defaultdict, deque
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.config import *
|
||||
from Compiler.instructions import *
|
||||
from Compiler.instructions_base import *
|
||||
from Compiler.util import *
|
||||
import Compiler.graph
|
||||
import Compiler.program
|
||||
import heapq, itertools
|
||||
import operator
|
||||
|
||||
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
It is based on the precondition that every register is only defined once."""
|
||||
def __init__(self, n):
|
||||
self.free = defaultdict(set)
|
||||
self.alloc = {}
|
||||
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
|
||||
self.defined = {}
|
||||
self.dealloc = set()
|
||||
self.n = n
|
||||
|
||||
def alloc_reg(self, reg, persistent_allocation):
|
||||
base = reg.vectorbase
|
||||
if base in self.alloc:
|
||||
# already allocated
|
||||
return
|
||||
|
||||
reg_type = reg.reg_type
|
||||
size = base.size
|
||||
if not persistent_allocation and self.free[reg_type, size]:
|
||||
res = self.free[reg_type, size].pop()
|
||||
else:
|
||||
if self.usage[reg_type] < self.n:
|
||||
res = self.usage[reg_type]
|
||||
self.usage[reg_type] += size
|
||||
else:
|
||||
raise RegisterOverflowError()
|
||||
self.alloc[base] = res
|
||||
|
||||
if base.vector:
|
||||
for i,r in enumerate(base.vector):
|
||||
r.i = self.alloc[base] + i
|
||||
else:
|
||||
base.i = self.alloc[base]
|
||||
|
||||
def dealloc_reg(self, reg, inst):
|
||||
self.dealloc.add(reg)
|
||||
base = reg.vectorbase
|
||||
|
||||
if base.vector and not inst.is_vec():
|
||||
for i in base.vector:
|
||||
if i not in self.dealloc:
|
||||
# not all vector elements ready for deallocation
|
||||
return
|
||||
self.free[reg.reg_type, base.size].add(self.alloc[base])
|
||||
if inst.is_vec() and base.vector:
|
||||
for i in base.vector:
|
||||
self.defined[i] = inst
|
||||
else:
|
||||
self.defined[reg] = inst
|
||||
|
||||
def process(self, program, persistent_allocation=False):
|
||||
for k,i in enumerate(reversed(program)):
|
||||
unused_regs = []
|
||||
for j in i.get_def():
|
||||
if j.vectorbase in self.alloc:
|
||||
if j in self.defined:
|
||||
raise CompilerError("Double write on register %s " \
|
||||
"assigned by '%s' in %s" % \
|
||||
(j,i,format_trace(i.caller)))
|
||||
else:
|
||||
# unused register
|
||||
self.alloc_reg(j, persistent_allocation)
|
||||
unused_regs.append(j)
|
||||
if unused_regs and len(unused_regs) == len(i.get_def()):
|
||||
# only report if all assigned registers are unused
|
||||
print "Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller))
|
||||
|
||||
for j in i.get_used():
|
||||
self.alloc_reg(j, persistent_allocation)
|
||||
for j in i.get_def():
|
||||
self.dealloc_reg(j, i)
|
||||
|
||||
if k % 1000000 == 0 and k > 0:
|
||||
print "Allocated registers for %d instructions at" % k, time.asctime()
|
||||
|
||||
# print "Successfully allocated registers"
|
||||
# print "modp usage: %d clear, %d secret" % \
|
||||
# (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp])
|
||||
# print "GF2N usage: %d clear, %d secret" % \
|
||||
# (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N])
|
||||
return self.usage
|
||||
|
||||
|
||||
def determine_scope(block):
|
||||
last_def = defaultdict(lambda: -1)
|
||||
used_from_scope = set()
|
||||
|
||||
def find_in_scope(reg, scope):
|
||||
if scope is None:
|
||||
return False
|
||||
elif reg in scope.defined_registers:
|
||||
return True
|
||||
else:
|
||||
return find_in_scope(reg, scope.scope)
|
||||
|
||||
def read(reg, n):
|
||||
if last_def[reg] == -1:
|
||||
if find_in_scope(reg, block.scope):
|
||||
used_from_scope.add(reg)
|
||||
reg.can_eliminate = False
|
||||
else:
|
||||
print 'Warning: read before write at register', reg
|
||||
print '\tline %d: %s' % (n, instr)
|
||||
print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')
|
||||
print '\tregister trace: %s' % format_trace(reg.caller, '\t\t')
|
||||
|
||||
def write(reg, n):
|
||||
if last_def[reg] != -1:
|
||||
print 'Warning: double write at register', reg
|
||||
print '\tline %d: %s' % (n, instr)
|
||||
print '\ttrace: %s' % format_trace(instr.caller, '\t\t')
|
||||
last_def[reg] = n
|
||||
|
||||
for n,instr in enumerate(block.instructions):
|
||||
outputs,inputs = instr.get_def(), instr.get_used()
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
for reg in outputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
write(i, n)
|
||||
else:
|
||||
write(reg, n)
|
||||
|
||||
block.used_from_scope = used_from_scope
|
||||
block.defined_registers = set(last_def.iterkeys())
|
||||
|
||||
class Merger:
|
||||
def __init__(self, block, options):
|
||||
self.block = block
|
||||
self.instructions = block.instructions
|
||||
self.options = options
|
||||
if options.max_parallel_open:
|
||||
self.max_parallel_open = int(options.max_parallel_open)
|
||||
else:
|
||||
self.max_parallel_open = float('inf')
|
||||
self.dependency_graph()
|
||||
|
||||
def do_merge(self, merges_iter):
|
||||
""" Merge an iterable of nodes in G, returning the number of merged
|
||||
instructions and the index of the merged instruction. """
|
||||
instructions = self.instructions
|
||||
mergecount = 0
|
||||
try:
|
||||
n = next(merges_iter)
|
||||
except StopIteration:
|
||||
return mergecount, None
|
||||
|
||||
def expand_vector_args(inst):
|
||||
new_args = []
|
||||
for arg in inst.args:
|
||||
if inst.is_vec():
|
||||
arg.create_vector_elements()
|
||||
for reg in arg:
|
||||
new_args.append(reg)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
for i in merges_iter:
|
||||
if isinstance(instructions[n], startinput_class):
|
||||
instructions[n].args[1] += instructions[i].args[1]
|
||||
elif isinstance(instructions[n], (stopinput, gstopinput)):
|
||||
if instructions[n].get_size() != instructions[i].get_size():
|
||||
raise NotImplemented()
|
||||
else:
|
||||
instructions[n].args += instructions[i].args[1:]
|
||||
else:
|
||||
if instructions[n].get_size() != instructions[i].get_size():
|
||||
# merge as non-vector instruction
|
||||
instructions[n].args = expand_vector_args(instructions[n]) + \
|
||||
expand_vector_args(instructions[i])
|
||||
if instructions[n].is_vec():
|
||||
instructions[n].size = 1
|
||||
else:
|
||||
instructions[n].args += instructions[i].args
|
||||
|
||||
# join arg_formats if not special iterators
|
||||
# if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \
|
||||
# not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)):
|
||||
# instructions[n].arg_format += instructions[i].arg_format
|
||||
instructions[i] = None
|
||||
self.merge_nodes(n, i)
|
||||
mergecount += 1
|
||||
|
||||
return mergecount, n
|
||||
|
||||
def compute_max_depths(self, depth_of):
|
||||
""" Compute the maximum 'depth' at which every instruction can be placed.
|
||||
This is the minimum depth of any merge_node succeeding an instruction.
|
||||
|
||||
Similar to DAG shortest paths algorithm. Traverses the graph in reverse
|
||||
topological order, updating the max depth of each node's predecessors.
|
||||
"""
|
||||
G = self.G
|
||||
merge_nodes_set = self.open_nodes
|
||||
top_order = Compiler.graph.topological_sort(G)
|
||||
max_depth_of = [None] * len(G)
|
||||
max_depth = max(depth_of)
|
||||
|
||||
for i in range(len(max_depth_of)):
|
||||
if i in merge_nodes_set:
|
||||
max_depth_of[i] = depth_of[i] - 1
|
||||
else:
|
||||
max_depth_of[i] = max_depth
|
||||
|
||||
for u in reversed(top_order):
|
||||
for v in G.pred[u]:
|
||||
if v not in merge_nodes_set:
|
||||
max_depth_of[v] = min(max_depth_of[u], max_depth_of[v])
|
||||
return max_depth_of
|
||||
|
||||
def merge_inputs(self):
|
||||
merges = defaultdict(list)
|
||||
remaining_input_nodes = []
|
||||
def do_merge(nodes):
|
||||
if len(nodes) > 1000:
|
||||
print 'Merging %d inputs...' % len(nodes)
|
||||
self.do_merge(iter(nodes))
|
||||
for n in self.input_nodes:
|
||||
inst = self.instructions[n]
|
||||
merge = merges[inst.args[0],inst.__class__]
|
||||
if len(merge) == 0:
|
||||
remaining_input_nodes.append(n)
|
||||
merge.append(n)
|
||||
if len(merge) >= self.max_parallel_open:
|
||||
do_merge(merge)
|
||||
merge[:] = []
|
||||
for merge in merges.itervalues():
|
||||
if merge:
|
||||
do_merge(merge)
|
||||
self.input_nodes = remaining_input_nodes
|
||||
|
||||
def compute_preorder(self, merges, rev_depth_of):
|
||||
# find flexible nodes that can be on several levels
|
||||
# and find sources on level 0
|
||||
G = self.G
|
||||
merge_nodes_set = self.open_nodes
|
||||
depth_of = self.depths
|
||||
instructions = self.instructions
|
||||
flex_nodes = defaultdict(dict)
|
||||
starters = []
|
||||
for n in xrange(len(G)):
|
||||
if n not in merge_nodes_set and \
|
||||
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
|
||||
#print n, depth_of[n], rev_depth_of[n]
|
||||
flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n)
|
||||
elif len(G.pred[n]) == 0 and \
|
||||
not isinstance(self.instructions[n], RawInputInstruction):
|
||||
starters.append(n)
|
||||
if n % 10000000 == 0 and n > 0:
|
||||
print "Processed %d nodes at" % n, time.asctime()
|
||||
|
||||
inputs = defaultdict(list)
|
||||
for node in self.input_nodes:
|
||||
player = self.instructions[node].args[0]
|
||||
inputs[player].append(node)
|
||||
first_inputs = [l[0] for l in inputs.itervalues()]
|
||||
other_inputs = []
|
||||
i = 0
|
||||
while True:
|
||||
i += 1
|
||||
found = False
|
||||
for l in inputs.itervalues():
|
||||
if i < len(l):
|
||||
other_inputs.append(l[i])
|
||||
found = True
|
||||
if not found:
|
||||
break
|
||||
other_inputs.reverse()
|
||||
|
||||
preorder = []
|
||||
# magical preorder for topological search
|
||||
max_depth = max(merges)
|
||||
if max_depth > 10000:
|
||||
print "Computing pre-ordering ..."
|
||||
for i in xrange(max_depth, 0, -1):
|
||||
preorder.append(G.get_attr(merges[i], 'stop'))
|
||||
for j in flex_nodes[i-1].itervalues():
|
||||
preorder.extend(j)
|
||||
preorder.extend(flex_nodes[0].get(i, []))
|
||||
preorder.append(merges[i])
|
||||
if i % 100000 == 0 and i > 0:
|
||||
print "Done level %d at" % i, time.asctime()
|
||||
preorder.extend(other_inputs)
|
||||
preorder.extend(starters)
|
||||
preorder.extend(first_inputs)
|
||||
if max_depth > 10000:
|
||||
print "Done at", time.asctime()
|
||||
return preorder
|
||||
|
||||
def compute_continuous_preorder(self, merges, rev_depth_of):
|
||||
print 'Computing pre-ordering for continuous computation...'
|
||||
preorder = []
|
||||
sources_for = defaultdict(list)
|
||||
stops_in = defaultdict(list)
|
||||
startinputs = []
|
||||
stopinputs = []
|
||||
for source in self.sources:
|
||||
sources_for[rev_depth_of[source]].append(source)
|
||||
for merge in merges.itervalues():
|
||||
stop = self.G.get_attr(merge, 'stop')
|
||||
stops_in[rev_depth_of[stop]].append(stop)
|
||||
for node in self.input_nodes:
|
||||
if isinstance(self.instructions[node], startinput_class):
|
||||
startinputs.append(node)
|
||||
else:
|
||||
stopinputs.append(node)
|
||||
max_round = max(rev_depth_of)
|
||||
for i in xrange(max_round, 0, -1):
|
||||
preorder.extend(reversed(stops_in[i]))
|
||||
preorder.extend(reversed(sources_for[i]))
|
||||
# inputs at the beginning
|
||||
preorder.extend(reversed(stopinputs))
|
||||
preorder.extend(reversed(sources_for[0]))
|
||||
preorder.extend(reversed(startinputs))
|
||||
return preorder
|
||||
|
||||
def longest_paths_merge(self, instruction_type=startopen_class,
|
||||
merge_stopopens=True):
|
||||
""" Attempt to merge instructions of type instruction_type (which are given in
|
||||
merge_nodes) using longest paths algorithm.
|
||||
|
||||
Returns the no. of rounds of communication required after merging (assuming 1 round/instruction).
|
||||
|
||||
If merge_stopopens is True, will also merge associated stop_open instructions.
|
||||
If reorder_between_opens is True, will attempt to place non-opens between start/stop opens.
|
||||
|
||||
Doesn't use networkx.
|
||||
"""
|
||||
G = self.G
|
||||
instructions = self.instructions
|
||||
merge_nodes = self.open_nodes
|
||||
depths = self.depths
|
||||
if instruction_type is not startopen_class and merge_stopopens:
|
||||
raise CompilerError('Cannot merge stopopens whilst merging %s instructions' % instruction_type)
|
||||
if not merge_nodes and not self.input_nodes:
|
||||
return 0
|
||||
|
||||
# merge opens at same depth
|
||||
merges = defaultdict(list)
|
||||
for node in merge_nodes:
|
||||
merges[depths[node]].append(node)
|
||||
|
||||
# after merging, the first element in merges[i] remains for each depth i,
|
||||
# all others are removed from instructions and G
|
||||
last_nodes = [None, None]
|
||||
for i in sorted(merges):
|
||||
merge = merges[i]
|
||||
if len(merge) > 1000:
|
||||
print 'Merging %d opens in round %d/%d' % (len(merge), i, len(merges))
|
||||
nodes = defaultdict(lambda: None)
|
||||
for b in (False, True):
|
||||
my_merge = (m for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b)
|
||||
|
||||
if merge_stopopens:
|
||||
my_stopopen = [G.get_attr(m, 'stop') for m in merge if instructions[m] is not None and instructions[m].is_gf2n() is b]
|
||||
|
||||
mc, nodes[0,b] = self.do_merge(iter(my_merge))
|
||||
|
||||
if merge_stopopens:
|
||||
mc, nodes[1,b] = self.do_merge(iter(my_stopopen))
|
||||
|
||||
# add edges to retain order of gf2n/modp start/stop opens
|
||||
for j in (0,1):
|
||||
node2 = nodes[j,True]
|
||||
nodep = nodes[j,False]
|
||||
if nodep is not None and node2 is not None:
|
||||
G.add_edge(nodep, node2)
|
||||
# add edge to retain order of opens over rounds
|
||||
if last_nodes[j] is not None:
|
||||
G.add_edge(last_nodes[j], node2 if nodep is None else nodep)
|
||||
last_nodes[j] = nodep if node2 is None else node2
|
||||
merges[i] = last_nodes[0]
|
||||
|
||||
self.merge_inputs()
|
||||
|
||||
# compute preorder for topological sort
|
||||
if merge_stopopens and self.options.reorder_between_opens:
|
||||
if self.options.continuous or not merge_nodes:
|
||||
rev_depths = self.compute_max_depths(self.real_depths)
|
||||
preorder = self.compute_continuous_preorder(merges, rev_depths)
|
||||
else:
|
||||
rev_depths = self.compute_max_depths(self.depths)
|
||||
preorder = self.compute_preorder(merges, rev_depths)
|
||||
else:
|
||||
preorder = None
|
||||
|
||||
if len(instructions) > 100000:
|
||||
print "Topological sort ..."
|
||||
order = Compiler.graph.topological_sort(G, preorder)
|
||||
instructions[:] = [instructions[i] for i in order if instructions[i] is not None]
|
||||
if len(instructions) > 100000:
|
||||
print "Done at", time.asctime()
|
||||
|
||||
return len(merges)
|
||||
|
||||
def dependency_graph(self, merge_class=startopen_class):
|
||||
""" Create the program dependency graph. """
|
||||
block = self.block
|
||||
options = self.options
|
||||
open_nodes = set()
|
||||
self.open_nodes = open_nodes
|
||||
self.input_nodes = []
|
||||
colordict = defaultdict(lambda: 'gray', startopen='red', stopopen='red',\
|
||||
ldi='lightblue', ldm='lightblue', stm='blue',\
|
||||
mov='yellow', mulm='orange', mulc='orange',\
|
||||
triple='green', square='green', bit='green',\
|
||||
asm_input='lightgreen')
|
||||
|
||||
G = Compiler.graph.SparseDiGraph(len(block.instructions))
|
||||
self.G = G
|
||||
|
||||
reg_nodes = {}
|
||||
last_def = defaultdict(lambda: -1)
|
||||
last_mem_write = None
|
||||
last_mem_read = None
|
||||
warned_about_mem = []
|
||||
last_mem_write_of = defaultdict(list)
|
||||
last_mem_read_of = defaultdict(list)
|
||||
last_print_str = None
|
||||
last = defaultdict(lambda: defaultdict(lambda: None))
|
||||
last_open = deque()
|
||||
|
||||
depths = [0] * len(block.instructions)
|
||||
self.depths = depths
|
||||
parallel_open = defaultdict(lambda: 0)
|
||||
next_available_depth = {}
|
||||
self.sources = []
|
||||
self.real_depths = [0] * len(block.instructions)
|
||||
|
||||
def add_edge(i, j):
|
||||
from_merge = isinstance(block.instructions[i], merge_class)
|
||||
to_merge = isinstance(block.instructions[j], merge_class)
|
||||
G.add_edge(i, j)
|
||||
is_source = G.get_attr(i, 'is_source') and G.get_attr(j, 'is_source') and not from_merge
|
||||
G.set_attr(j, 'is_source', is_source)
|
||||
for d in (self.depths, self.real_depths):
|
||||
if d[j] < d[i]:
|
||||
d[j] = d[i]
|
||||
|
||||
def read(reg, n):
|
||||
if last_def[reg] != -1:
|
||||
add_edge(last_def[reg], n)
|
||||
|
||||
def write(reg, n):
|
||||
last_def[reg] = n
|
||||
|
||||
def handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind):
|
||||
this = last_access_this_kind[addr,reg_type]
|
||||
other = last_access_other_kind[addr,reg_type]
|
||||
if this and other:
|
||||
if this[-1] < other[0]:
|
||||
del this[:]
|
||||
this.append(n)
|
||||
for inst in other:
|
||||
add_edge(inst, n)
|
||||
|
||||
def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
|
||||
addr = instr.args[1]
|
||||
reg_type = instr.args[0].reg_type
|
||||
if isinstance(addr, int):
|
||||
for i in range(min(instr.get_size(), 100)):
|
||||
addr_i = addr + i
|
||||
handle_mem_access(addr_i, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if not warned_about_mem and (instr.get_size() > 100):
|
||||
print 'WARNING: Order of memory instructions ' \
|
||||
'not preserved due to long vector, errors possible'
|
||||
warned_about_mem.append(True)
|
||||
else:
|
||||
handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction):
|
||||
print 'WARNING: Order of memory instructions ' \
|
||||
'not preserved, errors possible'
|
||||
# hack
|
||||
warned_about_mem.append(True)
|
||||
|
||||
def keep_order(instr, n, t, arg_index=None):
|
||||
if arg_index is None:
|
||||
player = None
|
||||
else:
|
||||
player = instr.args[arg_index]
|
||||
if last[t][player] is not None:
|
||||
add_edge(last[t][player], n)
|
||||
last[t][player] = n
|
||||
|
||||
for n,instr in enumerate(block.instructions):
|
||||
outputs,inputs = instr.get_def(), instr.get_used()
|
||||
|
||||
G.add_node(n, is_source=True)
|
||||
|
||||
# if options.debug:
|
||||
# col = colordict[instr.__class__.__name__]
|
||||
# G.add_node(n, color=col, label=str(instr))
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
|
||||
for reg in outputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
write(i, n)
|
||||
else:
|
||||
write(reg, n)
|
||||
|
||||
if isinstance(instr, merge_class):
|
||||
open_nodes.add(n)
|
||||
last_open.append(n)
|
||||
G.add_node(n, merges=[])
|
||||
# the following must happen after adding the edge
|
||||
self.real_depths[n] += 1
|
||||
depth = depths[n] + 1
|
||||
if int(options.max_parallel_open):
|
||||
skipped_depths = set()
|
||||
while parallel_open[depth] >= int(options.max_parallel_open):
|
||||
skipped_depths.add(depth)
|
||||
depth = next_available_depth.get(depth, depth + 1)
|
||||
for d in skipped_depths:
|
||||
next_available_depth[d] = depth
|
||||
parallel_open[depth] += len(instr.args) * instr.get_size()
|
||||
depths[n] = depth
|
||||
|
||||
if isinstance(instr, stopopen_class):
|
||||
startopen = last_open.popleft()
|
||||
add_edge(startopen, n)
|
||||
G.set_attr(startopen, 'stop', n)
|
||||
G.set_attr(n, 'start', last_open)
|
||||
G.add_node(n, merges=[])
|
||||
|
||||
if isinstance(instr, ReadMemoryInstruction):
|
||||
if options.preserve_mem_order:
|
||||
last_mem_read = n
|
||||
if last_mem_write:
|
||||
add_edge(last_mem_write, n)
|
||||
else:
|
||||
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
|
||||
elif isinstance(instr, WriteMemoryInstruction):
|
||||
if options.preserve_mem_order:
|
||||
last_mem_write = n
|
||||
if last_mem_read:
|
||||
add_edge(last_mem_read, n)
|
||||
else:
|
||||
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
|
||||
# keep I/O instructions in order
|
||||
elif isinstance(instr, IOInstruction):
|
||||
if last_print_str is not None:
|
||||
add_edge(last_print_str, n)
|
||||
last_print_str = n
|
||||
elif isinstance(instr, PublicFileIOInstruction):
|
||||
keep_order(instr, n, instr.__class__)
|
||||
elif isinstance(instr, RawInputInstruction):
|
||||
keep_order(instr, n, instr.__class__, 0)
|
||||
self.input_nodes.append(n)
|
||||
G.add_node(n, merges=[])
|
||||
player = instr.args[0]
|
||||
if isinstance(instr, stopinput):
|
||||
add_edge(last[startinput_class][player], n)
|
||||
elif isinstance(instr, gstopinput):
|
||||
add_edge(last[gstartinput][player], n)
|
||||
elif isinstance(instr, startprivateoutput_class):
|
||||
keep_order(instr, n, startprivateoutput_class, 2)
|
||||
elif isinstance(instr, stopprivateoutput_class):
|
||||
keep_order(instr, n, stopprivateoutput_class, 1)
|
||||
elif isinstance(instr, prep_class):
|
||||
keep_order(instr, n, instr.args[0])
|
||||
|
||||
if not G.pred[n]:
|
||||
self.sources.append(n)
|
||||
|
||||
if n % 100000 == 0 and n > 0:
|
||||
print "Processed dependency of %d/%d instructions at" % \
|
||||
(n, len(block.instructions)), time.asctime()
|
||||
|
||||
if len(open_nodes) > 1000:
|
||||
print "Program has %d %s instructions" % (len(open_nodes), merge_class)
|
||||
|
||||
def merge_nodes(self, i, j):
|
||||
""" Merge node j into i, removing node j """
|
||||
G = self.G
|
||||
if j in G[i]:
|
||||
G.remove_edge(i, j)
|
||||
if i in G[j]:
|
||||
G.remove_edge(j, i)
|
||||
G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))
|
||||
G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))
|
||||
G.get_attr(i, 'merges').append(j)
|
||||
G.remove_node(j)
|
||||
|
||||
def eliminate_dead_code(self):
|
||||
instructions = self.instructions
|
||||
G = self.G
|
||||
merge_nodes = self.open_nodes
|
||||
count = 0
|
||||
open_count = 0
|
||||
for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)):
|
||||
# remove if instruction has result that isn't used
|
||||
unused_result = not G.degree(i) and len(inst.get_def()) \
|
||||
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
|
||||
and not isinstance(inst, (DoNotEliminateInstruction))
|
||||
stop_node = G.get_attr(i, 'stop')
|
||||
unused_startopen = stop_node != -1 and instructions[stop_node] is None
|
||||
if unused_result or unused_startopen:
|
||||
G.remove_node(i)
|
||||
merge_nodes.discard(i)
|
||||
instructions[i] = None
|
||||
count += 1
|
||||
if unused_startopen:
|
||||
open_count += len(inst.args)
|
||||
if count > 0:
|
||||
print 'Eliminated %d dead instructions, among which %d opens' % (count, open_count)
|
||||
|
||||
def print_graph(self, filename):
|
||||
f = open(filename, 'w')
|
||||
print >>f, 'digraph G {'
|
||||
for i in range(self.G.n):
|
||||
for j in self.G[i]:
|
||||
print >>f, '"%d: %s" -> "%d: %s";' % \
|
||||
(i, self.instructions[i], j, self.instructions[j])
|
||||
print >>f, '}'
|
||||
f.close()
|
||||
|
||||
def print_depth(self, filename):
|
||||
f = open(filename, 'w')
|
||||
for i in range(self.G.n):
|
||||
print >>f, '%d: %s' % (self.depths[i], self.instructions[i])
|
||||
f.close()
|
||||
548
Compiler/comparison.py
Normal file
548
Compiler/comparison.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
"""
|
||||
Functions for secure comparison of GF(p) types.
|
||||
Most protocols come from [1], with a few subroutines described in [2].
|
||||
|
||||
Function naming of comparison routines is as in [1,2], with k always
|
||||
representing the integer bit length, and kappa the statistical security
|
||||
parameter.
|
||||
|
||||
Most of these routines were implemented before the cint/sint classes, so use
|
||||
the old-fasioned Register class and assembly instructions instead of operator
|
||||
overloading.
|
||||
|
||||
The PreMulC function has a few variants, depending on whether
|
||||
preprocessing is only triples/bits, or inverse tuples or "special"
|
||||
comparison-specific preprocessing is also available.
|
||||
|
||||
[1] https://www1.cs.fau.de/filepool/publications/octavian_securescm/smcint-scn10.pdf
|
||||
[2] https://www1.cs.fau.de/filepool/publications/octavian_securescm/SecureSCM-D.9.2.pdf
|
||||
"""
|
||||
|
||||
# Use constant rounds protocols instead of log rounds
|
||||
const_rounds = False
|
||||
# Set use_inv to use preprocessed inverse tuples for more efficient
|
||||
# online phase comparisons.
|
||||
use_inv = True
|
||||
# If do_precomp is not set, use_inv uses standard inverse tuples, otherwise if
|
||||
# both are set, use a list of "special" tuples of the form
|
||||
# (r[i], r[i]^-1, r[i] * r[i-1]^-1)
|
||||
do_precomp = True
|
||||
|
||||
import instructions_base
|
||||
|
||||
def set_variant(options):
|
||||
""" Set flags based on the command-line option provided """
|
||||
global const_rounds, do_precomp, use_inv
|
||||
variant = options.comparison
|
||||
if variant == 'log':
|
||||
const_rounds = False
|
||||
elif variant == 'plain':
|
||||
const_rounds = True
|
||||
use_inv = False
|
||||
elif variant == 'inv':
|
||||
const_rounds = True
|
||||
use_inv = True
|
||||
do_precomp = True
|
||||
elif variant == 'sinv':
|
||||
const_rounds = True
|
||||
use_inv = True
|
||||
do_precomp = False
|
||||
elif variant is not None:
|
||||
raise CompilerError('Unknown comparison variant: %s' % variant)
|
||||
|
||||
def ld2i(c, n):
|
||||
""" Load immediate 2^n into clear GF(p) register c """
|
||||
t1 = program.curr_block.new_reg('c')
|
||||
ldi(t1, 2 ** (n % 30))
|
||||
for i in range(n / 30):
|
||||
t2 = program.curr_block.new_reg('c')
|
||||
mulci(t2, t1, 2 ** 30)
|
||||
t1 = t2
|
||||
movc(c, t1)
|
||||
|
||||
inverse_of_two = {}
|
||||
|
||||
def divide_by_two(res, x):
|
||||
""" Faster clear division by two using a cached value of 2^-1 mod p """
|
||||
from program import Program
|
||||
import types
|
||||
tape = Program.prog.curr_tape
|
||||
if len(inverse_of_two) == 0 or tape not in inverse_of_two:
|
||||
inverse_of_two[tape] = types.cint(1) / 2
|
||||
mulc(res, x, inverse_of_two[tape])
|
||||
|
||||
def LTZ(s, a, k, kappa):
|
||||
"""
|
||||
s = (a ?< 0)
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
t = program.curr_block.new_reg('s')
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
|
||||
def Trunc(d, a, k, m, kappa, signed):
|
||||
"""
|
||||
d = a >> m
|
||||
|
||||
k: bit length of a
|
||||
m: compile-time integer
|
||||
signed: True/False, describes a
|
||||
"""
|
||||
a_prime = program.curr_block.new_reg('s')
|
||||
t = program.curr_block.new_reg('s')
|
||||
c = [program.curr_block.new_reg('c') for i in range(3)]
|
||||
c2m = program.curr_block.new_reg('c')
|
||||
if m == 0:
|
||||
movs(d, a)
|
||||
return
|
||||
elif m == 1:
|
||||
Mod2(a_prime, a, k, kappa, signed)
|
||||
else:
|
||||
Mod2m(a_prime, a, k, m, kappa, signed)
|
||||
subs(t, a, a_prime)
|
||||
ldi(c[1], 1)
|
||||
ld2i(c2m, m)
|
||||
divc(c[2], c[1], c2m)
|
||||
mulm(d, t, c[2])
|
||||
|
||||
def TruncRoundNearest(a, k, m, kappa):
|
||||
"""
|
||||
Returns a / 2^m, rounded to the nearest integer.
|
||||
|
||||
k: bit length of m
|
||||
m: compile-time integer
|
||||
"""
|
||||
from types import sint, cint
|
||||
from library import reveal, load_int_to_secret
|
||||
if m == 1:
|
||||
lsb = sint()
|
||||
Mod2(lsb, a, k, kappa, False)
|
||||
return (a + lsb) / 2
|
||||
r_dprime = sint()
|
||||
r_prime = sint()
|
||||
r = [sint() for i in range(m)]
|
||||
u = sint()
|
||||
PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
|
||||
c_prime = c % (cint(1) << (m - 1))
|
||||
if const_rounds:
|
||||
BitLTC1(u, c_prime, r[:-1], kappa)
|
||||
else:
|
||||
BitLTL(u, c_prime, r[:-1], kappa)
|
||||
bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2
|
||||
xor = bit + u - 2 * bit * u
|
||||
prod = xor * r[-1]
|
||||
# u_prime = xor * u + (1 - xor) * r[-1]
|
||||
u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
|
||||
a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
|
||||
d = (a - a_prime) / (cint(1) << m)
|
||||
rounding = xor + r[-1] - 2 * prod
|
||||
return d + rounding
|
||||
|
||||
def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
|
||||
k: bit length of a
|
||||
m: compile-time integer
|
||||
signed: True/False, describes a
|
||||
"""
|
||||
if m >= k:
|
||||
movs(a_prime, a)
|
||||
return
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
r = [program.curr_block.new_reg('s') for i in range(m)]
|
||||
c = program.curr_block.new_reg('c')
|
||||
c_prime = program.curr_block.new_reg('c')
|
||||
v = program.curr_block.new_reg('s')
|
||||
u = program.curr_block.new_reg('s')
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2m = program.curr_block.new_reg('c')
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
ld2i(c2m, m)
|
||||
mulm(t[0], r_dprime, c2m)
|
||||
if signed:
|
||||
ld2i(c2k1, k - 1)
|
||||
addm(t[1], a, c2k1)
|
||||
else:
|
||||
t[1] = a
|
||||
adds(t[2], t[0], t[1])
|
||||
adds(t[3], t[2], r_prime)
|
||||
startopen(t[3])
|
||||
stopopen(c)
|
||||
modc(c_prime, c, c2m)
|
||||
if const_rounds:
|
||||
BitLTC1(u, c_prime, r, kappa)
|
||||
else:
|
||||
BitLTL(u, c_prime, r, kappa)
|
||||
mulm(t[4], u, c2m)
|
||||
submr(t[5], c_prime, r_prime)
|
||||
adds(a_prime, t[5], t[4])
|
||||
return r_dprime, r_prime, c, c_prime, u, t, c2k1
|
||||
|
||||
def PRandM(r_dprime, r_prime, b, k, m, kappa):
|
||||
"""
|
||||
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
|
||||
r_prime = random secret integer in range [0, 2^m - 1]
|
||||
b = array containing bits of r_prime
|
||||
"""
|
||||
t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)]
|
||||
t[0][1] = b[-1]
|
||||
PRandInt(r_dprime, k + kappa - m)
|
||||
# r_dprime is always multiplied by 2^m
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
bit(b[-1])
|
||||
for i in range(1,m):
|
||||
adds(t[i][0], t[i-1][1], t[i-1][1])
|
||||
bit(b[-i-1])
|
||||
adds(t[i][1], t[i][0], b[-i-1])
|
||||
movs(r_prime, t[m-1][1])
|
||||
|
||||
def PRandInt(r, k):
|
||||
"""
|
||||
r = random secret integer in range [0, 2^k - 1]
|
||||
"""
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(3)]
|
||||
t[2][k-1] = r
|
||||
bit(t[2][0])
|
||||
for i in range(1,k):
|
||||
adds(t[0][i], t[2][i-1], t[2][i-1])
|
||||
bit(t[1][i])
|
||||
adds(t[2][i], t[0][i], t[1][i])
|
||||
|
||||
def BitLTC1(u, a, b, kappa):
|
||||
"""
|
||||
u = a <? b
|
||||
|
||||
a: array of clear bits
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
k = len(b)
|
||||
p = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
if instructions_base.get_global_vector_size() == 1:
|
||||
b_vec = program.curr_block.new_reg('s', size=k)
|
||||
for i in range(k):
|
||||
movs(b_vec[i], b[i])
|
||||
a_bits = program.curr_block.new_reg('c', size=k)
|
||||
d = program.curr_block.new_reg('s', size=k)
|
||||
s = program.curr_block.new_reg('s', size=k)
|
||||
t = [program.curr_block.new_reg('s', size=k) for j in range(5)]
|
||||
c = [program.curr_block.new_reg('c', size=k) for j in range(4)]
|
||||
else:
|
||||
a_bits = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
d = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(5)]
|
||||
c = [[program.curr_block.new_reg('c') for i in range(k)] for j in range(4)]
|
||||
c[1] = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
# warning: computer scientists count from 0
|
||||
modci(a_bits[0], a, 2)
|
||||
c[1][0] = a
|
||||
for i in range(1,k):
|
||||
subc(c[0][i], c[1][i-1], a_bits[i-1])
|
||||
divide_by_two(c[1][i], c[0][i])
|
||||
modci(a_bits[i], c[1][i], 2)
|
||||
if instructions_base.get_global_vector_size() == 1:
|
||||
vmulci(k, c[2], a_bits, 2)
|
||||
vmulm(k, t[0], b_vec, c[2])
|
||||
vaddm(k, t[1], b_vec, a_bits)
|
||||
vsubs(k, d, t[1], t[0])
|
||||
vaddsi(k, t[2], d, 1)
|
||||
t[2].create_vector_elements()
|
||||
pre_input = t[2].vector[:]
|
||||
else:
|
||||
for i in range(k):
|
||||
mulci(c[2][i], a_bits[i], 2)
|
||||
mulm(t[0][i], b[i], c[2][i])
|
||||
addm(t[1][i], b[i], a_bits[i])
|
||||
subs(d[i], t[1][i], t[0][i])
|
||||
addsi(t[2][i], d[i], 1)
|
||||
pre_input = t[2][:]
|
||||
pre_input.reverse()
|
||||
if use_inv:
|
||||
if instructions_base.get_global_vector_size() == 1:
|
||||
PreMulC_with_inverses_and_vectors(p, pre_input)
|
||||
else:
|
||||
if do_precomp:
|
||||
PreMulC_with_inverses(p, pre_input)
|
||||
else:
|
||||
raise NotImplementedError('Vectors not compatible with -c sinv')
|
||||
else:
|
||||
PreMulC_without_inverses(p, pre_input)
|
||||
p.reverse()
|
||||
for i in range(k-1):
|
||||
subs(s[i], p[i], p[i+1])
|
||||
subsi(s[k-1], p[k-1], 1)
|
||||
subcfi(c[3][0], a_bits[0], 1)
|
||||
mulm(t[4][0], s[0], c[3][0])
|
||||
for i in range(1,k):
|
||||
subcfi(c[3][i], a_bits[i], 1)
|
||||
mulm(t[3][i], s[i], c[3][i])
|
||||
adds(t[4][i], t[4][i-1], t[3][i])
|
||||
Mod2(u, t[4][k-1], k, kappa, False)
|
||||
return p, a_bits, d, s, t, c, b, pre_input
|
||||
|
||||
def carry(b, a, compute_p):
|
||||
""" Carry propogation:
|
||||
return (p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
|
||||
"""
|
||||
if a is None:
|
||||
return b
|
||||
if b is None:
|
||||
return a
|
||||
t = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
if compute_p:
|
||||
muls(t[0], a[0], b[0])
|
||||
muls(t[1], a[0], b[1])
|
||||
adds(t[2], a[1], t[1])
|
||||
return t[0], t[2]
|
||||
|
||||
# from WP9 report
|
||||
# length of a is even
|
||||
def CarryOutAux(d, a, kappa):
|
||||
k = len(a)
|
||||
if k > 1 and k % 2 == 1:
|
||||
a.append(None)
|
||||
k += 1
|
||||
u = [None]*(k/2)
|
||||
a = a[::-1]
|
||||
if k > 1:
|
||||
for i in range(k/2):
|
||||
u[i] = carry(a[2*i+1], a[2*i], i != k/2-1)
|
||||
CarryOutAux(d, u[:k/2][::-1], kappa)
|
||||
else:
|
||||
movs(d, a[0][1])
|
||||
|
||||
# carry out with carry-in bit c
|
||||
def CarryOut(res, a, b, c, kappa):
|
||||
"""
|
||||
res = last carry bit in addition of a and b
|
||||
|
||||
a: array of clear bits
|
||||
b: array of secret bits (same length as a)
|
||||
c: initial carry-in bit
|
||||
"""
|
||||
k = len(a)
|
||||
d = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
for i in range(k):
|
||||
mulm(t[0][i], b[i], a[i])
|
||||
mulsi(t[1][i], t[0][i], 2)
|
||||
addm(t[2][i], b[i], a[i])
|
||||
subs(t[3][i], t[2][i], t[1][i])
|
||||
d[i] = [t[3][i], t[0][i]]
|
||||
mulsi(s[0], d[-1][0], c)
|
||||
adds(s[1], d[-1][1], s[0])
|
||||
d[-1][1] = s[1]
|
||||
|
||||
CarryOutAux(res, d[::-1], kappa)
|
||||
|
||||
def BitLTL(res, a, b, kappa):
|
||||
"""
|
||||
res = a <? b (logarithmic rounds version)
|
||||
|
||||
a: clear integer register
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
k = len(b)
|
||||
a_bits = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
c = [[program.curr_block.new_reg('c') for i in range(k)] for j in range(2)]
|
||||
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
|
||||
t = [program.curr_block.new_reg('s') for i in range(1)]
|
||||
modci(a_bits[0], a, 2)
|
||||
c[1][0] = a
|
||||
for i in range(1,k):
|
||||
subc(c[0][i], c[1][i-1], a_bits[i-1])
|
||||
divide_by_two(c[1][i], c[0][i])
|
||||
modci(a_bits[i], c[1][i], 2)
|
||||
for i in range(len(b)):
|
||||
subsfi(s[0][i], b[i], 1)
|
||||
CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa)
|
||||
subsfi(res, t[0], 1)
|
||||
return a_bits, s[0]
|
||||
|
||||
def PreMulC_with_inverses_and_vectors(p, a):
|
||||
"""
|
||||
p[i] = prod_{j=0}^{i-1} a[i]
|
||||
|
||||
Variant for vector registers using preprocessed inverses.
|
||||
"""
|
||||
k = len(p)
|
||||
a_vec = program.curr_block.new_reg('s', size=k)
|
||||
r = program.curr_block.new_reg('s', size=k)
|
||||
w = program.curr_block.new_reg('s', size=k)
|
||||
w_tmp = program.curr_block.new_reg('s', size=k)
|
||||
z = program.curr_block.new_reg('s', size=k)
|
||||
m = program.curr_block.new_reg('c', size=k)
|
||||
t = [program.curr_block.new_reg('s', size=k) for i in range(1)]
|
||||
c = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
# warning: computer scientists count from 0
|
||||
if do_precomp:
|
||||
vinverse(k, r, z)
|
||||
else:
|
||||
vprep(k, 'PreMulC', r, z, w_tmp)
|
||||
for i in range(1,k):
|
||||
if do_precomp:
|
||||
muls(w[i], r[i], z[i-1])
|
||||
else:
|
||||
movs(w[i], w_tmp[i])
|
||||
movs(a_vec[i], a[i])
|
||||
movs(w[0], r[0])
|
||||
movs(a_vec[0], a[0])
|
||||
vmuls(k, t[0], w, a_vec)
|
||||
vstartopen(k, t[0])
|
||||
vstopopen(k, m)
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_with_inverses(p, a):
|
||||
"""
|
||||
Variant using preprocessed inverses or special inverses.
|
||||
The latter are triples of the form (a_i, a_i^{-1}, a_i * a_{i-1}^{-1}).
|
||||
See also make_PreMulC() in Fake-Offline.cpp.
|
||||
"""
|
||||
k = len(a)
|
||||
r = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(3)]
|
||||
w = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
|
||||
z = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
m = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(1)]
|
||||
c = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
# warning: computer scientists count from 0
|
||||
for i in range(k):
|
||||
if do_precomp:
|
||||
inverse(r[0][i], z[i])
|
||||
else:
|
||||
prep('PreMulC', r[0][i], z[i], w[1][i])
|
||||
if do_precomp:
|
||||
for i in range(1,k):
|
||||
muls(w[1][i], r[0][i], z[i-1])
|
||||
w[1][0] = r[0][0]
|
||||
for i in range(k):
|
||||
muls(t[0][i], w[1][i], a[i])
|
||||
startopen(t[0][i])
|
||||
stopopen(m[i])
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_without_inverses(p, a):
|
||||
"""
|
||||
Plain variant with no extra preprocessing.
|
||||
"""
|
||||
k = len(a)
|
||||
r = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
u = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
v = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
w = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
z = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
m = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(2)]
|
||||
#tt = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
|
||||
u_inv = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
c = [program.curr_block.new_reg('c') for i in range(k)]
|
||||
# warning: computer scientists count from 0
|
||||
for i in range(k):
|
||||
triple(s[i], r[i], t[0][i])
|
||||
#adds(tt[0][i], t[0][i], a[i])
|
||||
#subs(tt[1][i], tt[0][i], a[i])
|
||||
#startopen(tt[1][i])
|
||||
startopen(t[0][i])
|
||||
stopopen(u[i])
|
||||
for i in range(k-1):
|
||||
muls(v[i], r[i+1], s[i])
|
||||
w[0] = r[0]
|
||||
one = program.curr_block.new_reg('c')
|
||||
ldi(one, 1)
|
||||
for i in range(k):
|
||||
divc(u_inv[i], one, u[i])
|
||||
# avoid division by zero, just for benchmarking
|
||||
#divc(u_inv[i], u[i], one)
|
||||
for i in range(1,k):
|
||||
mulm(w[i], v[i-1], u_inv[i-1])
|
||||
for i in range(1,k):
|
||||
mulm(z[i], s[i], u_inv[i])
|
||||
for i in range(k):
|
||||
muls(t[1][i], w[i], a[i])
|
||||
startopen(t[1][i])
|
||||
stopopen(m[i])
|
||||
PreMulC_end(p, a, c, m, z)
|
||||
|
||||
def PreMulC_end(p, a, c, m, z):
|
||||
"""
|
||||
Helper function for all PreMulC variants. Local operation.
|
||||
"""
|
||||
k = len(a)
|
||||
c[0] = m[0]
|
||||
for j in range(1,k):
|
||||
mulc(c[j], c[j-1], m[j])
|
||||
if isinstance(p, list):
|
||||
mulm(p[j], z[j], c[j])
|
||||
if isinstance(p, list):
|
||||
p[0] = a[0]
|
||||
else:
|
||||
mulm(p, z[-1], c[-1])
|
||||
|
||||
def PreMulC(a):
|
||||
p = [type(a[0])() for i in range(len(a))]
|
||||
instructions_base.set_global_instruction_type(a[0].instruction_type)
|
||||
if use_inv:
|
||||
PreMulC_with_inverses(p, a)
|
||||
else:
|
||||
PreMulC_without_inverses(p, a)
|
||||
instructions_base.reset_global_instruction_type()
|
||||
return p
|
||||
|
||||
def KMulC(a):
|
||||
"""
|
||||
Return just the product of all items in a
|
||||
"""
|
||||
from types import sint, cint
|
||||
p = sint()
|
||||
if use_inv:
|
||||
PreMulC_with_inverses(p, a)
|
||||
else:
|
||||
PreMulC_without_inverses(p, a)
|
||||
return p
|
||||
|
||||
def Mod2(a_0, a, k, kappa, signed):
|
||||
"""
|
||||
a_0 = a % 2
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
if k <= 1:
|
||||
movs(a_0, a)
|
||||
return
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
r_0 = program.curr_block.new_reg('s')
|
||||
c = program.curr_block.new_reg('c')
|
||||
c_0 = program.curr_block.new_reg('c')
|
||||
tc = program.curr_block.new_reg('c')
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
|
||||
mulsi(t[0], r_dprime, 2)
|
||||
if signed:
|
||||
ld2i(c2k1, k - 1)
|
||||
addm(t[1], a, c2k1)
|
||||
else:
|
||||
t[1] = a
|
||||
adds(t[2], t[0], t[1])
|
||||
adds(t[3], t[2], r_prime)
|
||||
startopen(t[3])
|
||||
stopopen(c)
|
||||
modci(c_0, c, 2)
|
||||
mulci(tc, c_0, 2)
|
||||
mulm(t[4], r_0, tc)
|
||||
addm(t[5], r_0, c_0)
|
||||
subs(a_0, t[5], t[4])
|
||||
|
||||
|
||||
# hack for circular dependency
|
||||
from instructions import *
|
||||
79
Compiler/compilerLib.py
Normal file
79
Compiler/compilerLib.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
from Compiler.program import Program
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
import instructions, instructions_base, types, comparison, library
|
||||
|
||||
import random
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
def run(filename, options, param=-1, merge_opens=True, emulate=True, \
|
||||
reallocate=True, assemblymode=False, debug=False):
|
||||
""" Compile a file and output a Program object.
|
||||
|
||||
If merge_opens is set to True, will attempt to merge any parallelisable open
|
||||
instructions. """
|
||||
|
||||
prog = Program(filename, options, param, assemblymode)
|
||||
instructions.program = prog
|
||||
instructions_base.program = prog
|
||||
types.program = prog
|
||||
comparison.program = prog
|
||||
prog.EMULATE = emulate
|
||||
prog.DEBUG = debug
|
||||
VARS['program'] = prog
|
||||
comparison.set_variant(options)
|
||||
|
||||
print 'Compiling file', prog.infile
|
||||
|
||||
# no longer needed, but may want to support assembly in future (?)
|
||||
if assemblymode:
|
||||
prog.restart_main_thread()
|
||||
for i in xrange(INIT_REG_MAX):
|
||||
VARS['c%d'%i] = prog.curr_block.new_reg('c')
|
||||
VARS['s%d'%i] = prog.curr_block.new_reg('s')
|
||||
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
|
||||
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
|
||||
if i % 10000000 == 0 and i > 0:
|
||||
print "Initialized %d register variables at" % i, time.asctime()
|
||||
|
||||
# first pass determines how many assembler registers are used
|
||||
prog.FIRST_PASS = True
|
||||
execfile(prog.infile, VARS)
|
||||
|
||||
if instructions_base.Instruction.count != 0:
|
||||
print 'instructions count', instructions_base.Instruction.count
|
||||
instructions_base.Instruction.count = 0
|
||||
prog.FIRST_PASS = False
|
||||
prog.reset_values()
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
execfile(prog.infile, VARS)
|
||||
|
||||
# optimize the tapes
|
||||
for tape in prog.tapes:
|
||||
tape.optimize(options)
|
||||
|
||||
# check program still does the same thing after optimizations
|
||||
if emulate:
|
||||
clearmem = list(prog.mem_c)
|
||||
sharedmem = list(prog.mem_s)
|
||||
prog.emulate()
|
||||
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
|
||||
print 'Warning: emulated memory values changed after compiler optimization'
|
||||
# raise CompilerError('Compiler optimization caused incorrect memory write.')
|
||||
|
||||
if prog.main_thread_running:
|
||||
prog.update_req(prog.curr_tape)
|
||||
print 'Program requires:', repr(prog.req_num)
|
||||
print 'Cost:', prog.req_num.cost()
|
||||
print 'Memory size:', prog.allocated_mem
|
||||
|
||||
# finalize the memory
|
||||
prog.finalize_memory()
|
||||
|
||||
return prog
|
||||
58
Compiler/config.py
Normal file
58
Compiler/config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
#INIT_REG_MAX = 655360
|
||||
INIT_REG_MAX = 1310720
|
||||
REG_MAX = 2 ** 32
|
||||
USER_MEM = 8192
|
||||
TMP_MEM = 8192
|
||||
TMP_MEM_BASE = USER_MEM
|
||||
TMP_REG = 3
|
||||
TMP_REG_BASE = REG_MAX - TMP_REG
|
||||
|
||||
P_VALUES = { -1: 2147483713, \
|
||||
32: 2147565569, \
|
||||
64: 9223372036855103489, \
|
||||
128: 172035116406933162231178957667602464769, \
|
||||
256: 57896044624266469032429686755131815517604980759976795324963608525438406557697, \
|
||||
512: 6703903964971298549787012499123814115273848577471136527425966013026501536706464354255445443244279389455058889493431223951165286470575994074291745908195329 }
|
||||
|
||||
BIT_LENGTHS = { -1: 24,
|
||||
32: 24,
|
||||
64: 32,
|
||||
128: 64,
|
||||
256: 64,
|
||||
512: 64 }
|
||||
|
||||
STAT_SEC = { -1: 6,
|
||||
32: 6,
|
||||
64: 30,
|
||||
128: 40,
|
||||
256: 40,
|
||||
512: 40 }
|
||||
|
||||
|
||||
COST = { 'modp': defaultdict(lambda: 0,
|
||||
{ 'triple': 0.00020652622883106154,
|
||||
'square': 0.00020652622883106154,
|
||||
'bit': 0.00020652622883106154,
|
||||
'inverse': 0.00020652622883106154,
|
||||
'PreMulC': 2 * 0.00020652622883106154,
|
||||
}),
|
||||
'gf2n': defaultdict(lambda: 0,
|
||||
{ 'triple': 0.00020716801325875284,
|
||||
'square': 0.00020716801325875284,
|
||||
'inverse': 0.00020716801325875284,
|
||||
'bit': 1.4492753623188405e-07,
|
||||
'bittriple': 0.00004828818388140422,
|
||||
'bitgf2ntriple': 0.00020716801325875284,
|
||||
'PreMulC': 2 * 0.00020716801325875284,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from config_mine import *
|
||||
except:
|
||||
pass
|
||||
17
Compiler/exceptions.py
Normal file
17
Compiler/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
class CompilerError(Exception):
|
||||
"""Base class for compiler exceptions."""
|
||||
pass
|
||||
|
||||
class RegisterOverflowError(CompilerError):
|
||||
pass
|
||||
|
||||
class MemoryOverflowError(CompilerError):
|
||||
pass
|
||||
|
||||
class ArgumentError(CompilerError):
|
||||
""" Exception raised for errors in instruction argument parsing. """
|
||||
def __init__(self, arg, msg):
|
||||
self.arg = arg
|
||||
self.msg = msg
|
||||
517
Compiler/floatingpoint.py
Normal file
517
Compiler/floatingpoint.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
from math import log, floor, ceil
|
||||
from Compiler.instructions import *
|
||||
import types
|
||||
import comparison
|
||||
import program
|
||||
|
||||
##
|
||||
## Helper functions for floating point arithmetic
|
||||
##
|
||||
|
||||
|
||||
def two_power(n):
|
||||
if isinstance(n, int) and n < 31:
|
||||
return 2**n
|
||||
else:
|
||||
max = types.cint(1) << 31
|
||||
res = 2**(n%31)
|
||||
for i in range(n / 31):
|
||||
res *= max
|
||||
return res
|
||||
|
||||
def EQZ(a, k, kappa):
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
d = [None]*k
|
||||
r = [types.sint() for i in range(k)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
|
||||
startopen(a + two_power(k) * r_dprime + r_prime)# + 2**(k-1))
|
||||
stopopen(c)
|
||||
for i,b in enumerate(bits(c, k)):
|
||||
d[i] = b + r[i] - 2*b*r[i]
|
||||
return 1 - KOR(d, kappa)
|
||||
|
||||
def bits(a,m):
|
||||
""" Get the bits of an int """
|
||||
if isinstance(a, int):
|
||||
res = [None]*m
|
||||
for i in range(m):
|
||||
res[i] = a & 1
|
||||
a >>= 1
|
||||
else:
|
||||
c = [[types.cint() for i in range(m)] for i in range(2)]
|
||||
res = [types.cint() for i in range(m)]
|
||||
modci(res[0], a, 2)
|
||||
c[1][0] = a
|
||||
for i in range(1,m):
|
||||
subc(c[0][i], c[1][i-1], res[i-1])
|
||||
divci(c[1][i], c[0][i], 2)
|
||||
modci(res[i], c[1][i], 2)
|
||||
return res
|
||||
|
||||
def carry(b, a, compute_p=True):
|
||||
""" Carry propogation:
|
||||
(p,g) = (p_2, g_2)o(p_1, g_1) -> (p_1 & p_2, g_2 | (p_2 & g_1))
|
||||
"""
|
||||
if compute_p:
|
||||
t1 = a[0]*b[0]
|
||||
else:
|
||||
t1 = None
|
||||
t2 = a[1] + a[0]*b[1]
|
||||
return (t1, t2)
|
||||
|
||||
def or_op(a, b, void=None):
|
||||
return a + b - a*b
|
||||
|
||||
def mul_op(a, b, void=None):
|
||||
return a * b
|
||||
|
||||
def PreORC(a, kappa=None, m=None, raw=False):
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return [a[0]]
|
||||
m = m or k
|
||||
if isinstance(a[0], types.sgf2n):
|
||||
max_k = program.Program.prog.galois_length - 1
|
||||
else:
|
||||
max_k = int(log(program.Program.prog.P) / log(2)) - kappa
|
||||
if k <= max_k:
|
||||
p = [None] * m
|
||||
if m == k:
|
||||
p[0] = a[0]
|
||||
if isinstance(a[0], types.sgf2n):
|
||||
b = comparison.PreMulC([3 - a[i] for i in range(k)])
|
||||
for i in range(m):
|
||||
tmp = b[k-1-i]
|
||||
if not raw:
|
||||
tmp = tmp.bit_decompose()[0]
|
||||
p[m-1-i] = 1 - tmp
|
||||
else:
|
||||
t = [types.sint() for i in range(m)]
|
||||
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
|
||||
for i in range(m):
|
||||
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
|
||||
p[m-1-i] = 1 - t[i]
|
||||
return p
|
||||
else:
|
||||
# not constant-round anymore
|
||||
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
|
||||
return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])
|
||||
|
||||
def PreOpL(op, items):
|
||||
"""
|
||||
Uses algorithm from SecureSCM WP9 deliverable.
|
||||
|
||||
op must be a binary function that outputs a new register
|
||||
"""
|
||||
k = len(items)
|
||||
logk = int(ceil(log(k,2)))
|
||||
kmax = 2**logk
|
||||
output = list(items)
|
||||
for i in range(logk):
|
||||
for j in range(kmax/(2**(i+1))):
|
||||
y = two_power(i) + j*two_power(i+1) - 1
|
||||
for z in range(1, 2**i+1):
|
||||
if y+z < k:
|
||||
output[y+z] = op(output[y], output[y+z], j != 0)
|
||||
return output
|
||||
|
||||
def PreOpN(op, items):
|
||||
""" Naive PreOp algorithm """
|
||||
k = len(items)
|
||||
output = [None]*k
|
||||
output[0] = items[0]
|
||||
for i in range(1, k):
|
||||
output[i] = op(output[i-1], items[i])
|
||||
return output
|
||||
|
||||
def PreOR(a, kappa=None, raw=False):
|
||||
if comparison.const_rounds:
|
||||
return PreORC(a, kappa, raw=raw)
|
||||
else:
|
||||
return PreOpL(or_op, a)
|
||||
|
||||
def KOpL(op, a):
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KOpL(op, a[:k/2])
|
||||
t2 = KOpL(op, a[k/2:])
|
||||
return op(t1, t2)
|
||||
|
||||
def KORL(a, kappa):
|
||||
""" log rounds k-ary OR """
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KORL(a[:k/2], kappa)
|
||||
t2 = KORL(a[k/2:], kappa)
|
||||
return t1 + t2 - t1*t2
|
||||
|
||||
def KORC(a, kappa):
|
||||
return PreORC(a, kappa, 1)[0]
|
||||
|
||||
def KOR(a, kappa):
|
||||
if comparison.const_rounds:
|
||||
return KORC(a, kappa)
|
||||
else:
|
||||
return KORL(a, None)
|
||||
|
||||
def KMul(a):
|
||||
if comparison.const_rounds:
|
||||
return comparison.KMulC(a)
|
||||
else:
|
||||
return KOpL(mul_op, a)
|
||||
|
||||
|
||||
def Inv(a):
|
||||
""" Invert a non-zero value """
|
||||
t = [types.sint() for i in range(3)]
|
||||
c = [types.cint() for i in range(2)]
|
||||
one = types.cint()
|
||||
ldi(one, 1)
|
||||
inverse(t[0], t[1])
|
||||
s = t[0]*a
|
||||
asm_open(c[0], s)
|
||||
# avoid division by zero for benchmarking
|
||||
divc(c[1], one, c[0])
|
||||
#divc(c[1], c[0], one)
|
||||
return c[1]*t[0]
|
||||
|
||||
def BitAdd(a, b, bits_to_compute=None):
|
||||
""" Add the bits a[k-1], ..., a[0] and b[k-1], ..., b[0], return k+1
|
||||
bits s[0], ... , s[k] """
|
||||
k = len(a)
|
||||
if not bits_to_compute:
|
||||
bits_to_compute = range(k)
|
||||
d = [None] * k
|
||||
for i in range(1,k):
|
||||
#assert(a[i].value == 0 or a[i].value == 1)
|
||||
#assert(b[i].value == 0 or b[i].value == 1)
|
||||
t = a[i]*b[i]
|
||||
d[i] = (a[i] + b[i] - 2*t, t)
|
||||
#assert(d[i][0].value == 0 or d[i][0].value == 1)
|
||||
d[0] = (None, a[0]*b[0])
|
||||
pg = PreOpL(carry, d)
|
||||
c = [pair[1] for pair in pg]
|
||||
|
||||
# (for testing)
|
||||
def print_state():
|
||||
print 'a: ',
|
||||
for i in range(k):
|
||||
print '%d ' % a[i].value,
|
||||
print '\nb: ',
|
||||
for i in range(k):
|
||||
print '%d ' % b[i].value,
|
||||
print '\nd: ',
|
||||
for i in range(k):
|
||||
print '%d ' % d[i][0].value,
|
||||
print '\n ',
|
||||
for i in range(k):
|
||||
print '%d ' % d[i][1].value,
|
||||
print '\n\npg:',
|
||||
for i in range(k):
|
||||
print '%d ' % pg[i][0].value,
|
||||
print '\n ',
|
||||
for i in range(k):
|
||||
print '%d ' % pg[i][1].value,
|
||||
print ''
|
||||
|
||||
for bit in c:
|
||||
pass#assert(bit.value == 0 or bit.value == 1)
|
||||
s = [None] * (k+1)
|
||||
if 0 in bits_to_compute:
|
||||
s[0] = a[0] + b[0] - 2*c[0]
|
||||
bits_to_compute.remove(0)
|
||||
#assert(c[0].value == a[0].value*b[0].value)
|
||||
#assert(s[0].value == 0 or s[0].value == 1)
|
||||
for i in bits_to_compute:
|
||||
s[i] = a[i] + b[i] + c[i-1] - 2*c[i]
|
||||
try:
|
||||
pass#assert(s[i].value == 0 or s[i].value == 1)
|
||||
except AssertionError:
|
||||
print '#assertion failed in BitAdd for s[%d]' % i
|
||||
print_state()
|
||||
s[k] = c[k-1]
|
||||
#print_state()
|
||||
return s
|
||||
|
||||
def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
r = [types.sint() for i in range(m)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
#assert(r_prime.value == sum(r[i].value*2**i for i in range(m)) % comparison.program.P)
|
||||
pow2 = two_power(k + kappa)
|
||||
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
#rval = 2**m*r_dprime.value + r_prime.value
|
||||
#assert(rval % 2**m == r_prime.value)
|
||||
#assert(rval == (2**m*r_dprime.value + sum(r[i].value*2**i for i in range(m)) % comparison.program.P ))
|
||||
try:
|
||||
pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P)
|
||||
except AssertionError:
|
||||
print 'BitDec assertion failed'
|
||||
print 'a =', a.value
|
||||
print 'a mod 2^%d =' % k, (a.value % 2**k)
|
||||
return BitAdd(list(bits(c,m)), r, bits_to_compute)[:-1]
|
||||
|
||||
|
||||
def Pow2(a, l, kappa):
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
x = [types.sint() for i in range(m)]
|
||||
pow2k = [types.cint() for i in range(m)]
|
||||
for i in range(m):
|
||||
pow2k[i] = two_power(2**i)
|
||||
t[i] = t[i]*pow2k[i] + 1 - t[i]
|
||||
return KMul(t)
|
||||
|
||||
def B2U(a, l, kappa):
|
||||
pow2a = Pow2(a, l, kappa)
|
||||
#assert(pow2a.value == 2**a.value)
|
||||
r = [types.sint() for i in range(l)]
|
||||
t = types.sint()
|
||||
c = types.cint()
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
comparison.PRandInt(t, kappa)
|
||||
asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l)))
|
||||
comparison.program.curr_tape.require_bit_length(l + kappa)
|
||||
c = list(bits(c, l))
|
||||
x = [c[i] + r[i] - 2*c[i]*r[i] for i in range(l)]
|
||||
#print ' '.join(str(b.value) for b in x)
|
||||
y = PreOR(x, kappa)
|
||||
#print ' '.join(str(b.value) for b in y)
|
||||
return [1 - y[i] for i in range(l)], pow2a
|
||||
|
||||
def Trunc(a, l, m, kappa, compute_modulo=False):
|
||||
""" Oblivious truncation by secret m """
|
||||
if l == 1:
|
||||
if compute_modulo:
|
||||
return a * m, 1 + m
|
||||
else:
|
||||
return a * (1 - m)
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
rk = types.sint()
|
||||
c = types.cint()
|
||||
ci = [types.cint() for i in range(l)]
|
||||
d = types.sint()
|
||||
x, pow2m = B2U(m, l, kappa)
|
||||
#assert(pow2m.value == 2**m.value)
|
||||
#assert(sum(b.value for b in x) == m.value)
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
t1 = two_power(i) * r[i]
|
||||
t2 = t1*x[i]
|
||||
r_prime += t2
|
||||
r_dprime += t1 - t2
|
||||
#assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
|
||||
comparison.PRandInt(rk, kappa)
|
||||
r_dprime += two_power(l) * rk
|
||||
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
|
||||
asm_open(c, a + r_dprime + r_prime)
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
#assert(ci[i].value == c.value % 2**i)
|
||||
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
|
||||
#assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
|
||||
lts(d, c_dprime, r_prime, l, kappa)
|
||||
if compute_modulo:
|
||||
b = c_dprime - r_prime + pow2m * d
|
||||
return b, pow2m
|
||||
else:
|
||||
pow2inv = Inv(pow2m)
|
||||
#assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
|
||||
b = (a - c_dprime + r_prime) * pow2inv - d
|
||||
return b
|
||||
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
|
||||
s = (1 - overflow) * t + overflow * t / 2
|
||||
return s, overflow
|
||||
|
||||
def Int2FL(a, gamma, l, kappa):
|
||||
lam = gamma - 1
|
||||
s = types.sint()
|
||||
comparison.LTZ(s, a, gamma, kappa)
|
||||
z = EQZ(a, gamma, kappa)
|
||||
a = (1 - 2 * s) * a
|
||||
a_bits = BitDec(a, lam, lam, kappa)
|
||||
a_bits.reverse()
|
||||
b = PreOR(a_bits, kappa)
|
||||
t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b)))
|
||||
p = - (lam - sum(b))
|
||||
if gamma - 1 > l:
|
||||
if types.sfloat.round_nearest:
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
|
||||
p = p + overflow
|
||||
else:
|
||||
v = types.sint()
|
||||
comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
|
||||
else:
|
||||
v = 2**(l-gamma+1) * t
|
||||
p = (p + gamma - 1 - l) * (1 -z)
|
||||
return v, p, z, s
|
||||
|
||||
def FLRound(x, mode):
|
||||
""" Rounding with floating point output.
|
||||
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
|
||||
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
|
||||
a = types.sint()
|
||||
comparison.LTZ(a, p1, k, x.kappa)
|
||||
b = p1.less_than(-l + 1, k, x.kappa)
|
||||
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
|
||||
c = EQZ(v2, l, x.kappa)
|
||||
if mode == -1:
|
||||
away_from_zero = 0
|
||||
mode = x.s
|
||||
else:
|
||||
away_from_zero = mode + s1 - 2 * mode * s1
|
||||
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
|
||||
d = v.equal(two_power(l), l + 1, x.kappa)
|
||||
v = d * two_power(l-1) + (1 - d) * v
|
||||
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
|
||||
s = (1 - b * mode) * s1
|
||||
z = or_op(EQZ(v, l, x.kappa), z1)
|
||||
v = v * (1 - z)
|
||||
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
|
||||
return v, p, z, s
|
||||
|
||||
def TruncPr(a, k, m, kappa=None):
|
||||
""" Probabilistic truncation [a/2^m + u]
|
||||
where Pr[u = 1] = (a % 2^m) / 2^m
|
||||
"""
|
||||
if kappa is None:
|
||||
kappa = 40
|
||||
|
||||
b = two_power(k-1) + a
|
||||
r_prime, r_dprime = types.sint(), types.sint()
|
||||
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
|
||||
k, m, kappa)
|
||||
two_to_m = two_power(m)
|
||||
r = two_to_m * r_dprime + r_prime
|
||||
c = (b + r).reveal()
|
||||
c_prime = c % two_to_m
|
||||
a_prime = c_prime - r_prime
|
||||
d = (a - a_prime) / two_to_m
|
||||
return d
|
||||
|
||||
def SDiv(a, b, l, kappa):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
beta = 1 / types.cint(two_power(l))
|
||||
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = TruncPr(y, 2 * l, l, kappa)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
|
||||
x1 = (x - x2) * beta
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
|
||||
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
|
||||
x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa)
|
||||
x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
|
||||
x1 = (x - x2) * beta
|
||||
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
|
||||
y = TruncPr(y, 2 * l + 1, l - 1, kappa)
|
||||
return y
|
||||
|
||||
def SDiv_mono(a, b, l, kappa):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
|
||||
for i in range(theta-1):
|
||||
y = y * (alpha + x)
|
||||
# keep y with l bits
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
x = x**2
|
||||
# keep x with 2l bits
|
||||
x = TruncPr(x, 4 * l, 2 * l, kappa)
|
||||
y = y * (alpha + x)
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
return y
|
||||
|
||||
|
||||
def FPDiv(a, b, k, f, kappa):
|
||||
theta = int(ceil(log(k/3.5)))
|
||||
alpha = types.cint(1 * two_power(2*f))
|
||||
|
||||
w = AppRcr(b, k, f, kappa)
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = TruncPr(y, 2*k, f, kappa)
|
||||
|
||||
for i in range(theta):
|
||||
y = y * (alpha + x)
|
||||
x = x * x
|
||||
y = TruncPr(y, 2*k, 2*f, kappa)
|
||||
x = TruncPr(x, 2*k, 2*f, kappa)
|
||||
|
||||
y = y * (alpha + x)
|
||||
y = TruncPr(y, 2*k, 2*f, kappa)
|
||||
return y
|
||||
|
||||
def AppRcr(b, k, f, kappa):
|
||||
"""
|
||||
Approximate reciprocal of [b]:
|
||||
Given [b], compute [1/b]
|
||||
"""
|
||||
alpha = types.cint(int(2.9142 * (2**k)))
|
||||
c, v = Norm(b, k, f, kappa)
|
||||
d = alpha - 2 * c
|
||||
w = d * v
|
||||
w = TruncPr(w, 2 * k, 2 * (k - f))
|
||||
return w
|
||||
|
||||
def Norm(b, k, f, kappa):
|
||||
"""
|
||||
Computes secret integer values [c] and [v_prime] st.
|
||||
2^{k-1} <= c < 2^k and c = b*v_prime
|
||||
"""
|
||||
temp = types.sint()
|
||||
comparison.LTZ(temp, b, k, kappa)
|
||||
sign = 1 - 2 * temp # 1 - 2 * [b < 0]
|
||||
|
||||
x = sign * b
|
||||
#x = |b|
|
||||
bits = x.bit_decompose(k)
|
||||
y = PreOR(bits)
|
||||
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
z[i] = y[i] - y[i + 1]
|
||||
|
||||
z[k - 1] = y[k - 1]
|
||||
# z[i] = 0 for all i except when bits[i + 1] = first one
|
||||
|
||||
#now reverse bits of z[i]
|
||||
v = types.sint()
|
||||
for i in range(k):
|
||||
v += two_power(k - i - 1) * z[i]
|
||||
c = x * v
|
||||
v_prime = sign * v
|
||||
return c, v_prime
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
220
Compiler/graph.py
Normal file
220
Compiler/graph.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import heapq
|
||||
from Compiler.exceptions import *
|
||||
|
||||
class GraphError(CompilerError):
|
||||
pass
|
||||
|
||||
class SparseDiGraph(object):
|
||||
""" Directed graph suitable when each node only has a small number of edges.
|
||||
|
||||
Edges are stored as a list instead of a dictionary to save memory, leading
|
||||
to slower searching for dense graphs.
|
||||
|
||||
Node attributes must be specified in advance, as these are stored in the
|
||||
same list as edges.
|
||||
"""
|
||||
def __init__(self, max_nodes, default_attributes=None):
|
||||
""" max_nodes: maximum no of nodes
|
||||
default_attributes: dict of node attributes and default values """
|
||||
if default_attributes is None:
|
||||
default_attributes = { 'merges': None, 'stop': -1, 'start': -1, 'is_source': True }
|
||||
self.default_attributes = default_attributes
|
||||
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
|
||||
self.n = max_nodes
|
||||
# each node contains list of default attributes, followed by outoing edges
|
||||
self.nodes = [self.default_attributes.values() for i in range(self.n)]
|
||||
self.pred = [[] for i in range(self.n)]
|
||||
self.weights = {}
|
||||
|
||||
def __len__(self):
|
||||
return self.n
|
||||
|
||||
def __getitem__(self, i):
|
||||
""" Get list of the neighbours of node i """
|
||||
return self.nodes[i][len(self.default_attributes):]
|
||||
|
||||
def __iter__(self):
|
||||
pass #return iter(self.nodes)
|
||||
|
||||
def __contains__(self, i):
|
||||
return i >= 0 and i < self.n
|
||||
|
||||
def add_node(self, i, **attr):
|
||||
if i >= self.n:
|
||||
raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n))
|
||||
node = self.nodes[i]
|
||||
|
||||
for a,value in attr.items():
|
||||
if a in self.default_attributes:
|
||||
node[self.attribute_pos[a]] = value
|
||||
else:
|
||||
raise CompilerError('Invalid attribute %s for graph node' % a)
|
||||
|
||||
def set_attr(self, i, attr, value):
|
||||
if attr in self.default_attributes:
|
||||
self.nodes[i][self.attribute_pos[attr]] = value
|
||||
else:
|
||||
raise CompilerError('Invalid attribute %s for graph node' % attr)
|
||||
|
||||
def get_attr(self, i, attr):
|
||||
return self.nodes[i][self.attribute_pos[attr]]
|
||||
|
||||
def remove_node(self, i):
|
||||
""" Remove node i and all its edges """
|
||||
succ = self[i]
|
||||
pred = self.pred[i]
|
||||
for v in succ:
|
||||
self.pred[v].remove(i)
|
||||
#del self.weights[(i,v)]
|
||||
for v in pred:
|
||||
# find index to ensure attribute isn't removed instead
|
||||
index = self[v].index(i) + len(self.default_attributes)
|
||||
del self.nodes[v][index]
|
||||
#del self.weights[(v,i)]
|
||||
#self.nodes[v].remove(i)
|
||||
self.pred[i] = []
|
||||
self.nodes[i] = self.default_attributes.values()
|
||||
|
||||
def add_edge(self, i, j, weight=1):
|
||||
if j not in self[i]:
|
||||
self.nodes[i].append(j)
|
||||
self.pred[j].append(i)
|
||||
self.weights[(i,j)] = weight
|
||||
|
||||
def add_edges_from(self, tuples):
|
||||
for edge in tuples:
|
||||
if len(edge) == 3:
|
||||
# use weight
|
||||
self.add_edge(edge[0], edge[1], edge[2])
|
||||
else:
|
||||
self.add_edge(edge[0], edge[1])
|
||||
|
||||
def remove_edge(self, i, j):
|
||||
jindex = self[i].index(j) + len(self.default_attributes)
|
||||
del self.nodes[i][jindex]
|
||||
self.pred[j].remove(i)
|
||||
del self.weights[(i,j)]
|
||||
|
||||
def remove_edges_from(self, pairs):
|
||||
for i,j in pairs:
|
||||
self.remove_edge(i, j)
|
||||
|
||||
def degree(self, i):
|
||||
return len(self.nodes[i]) - len(self.default_attributes)
|
||||
|
||||
|
||||
def topological_sort(G, nbunch=None, pref=None):
|
||||
seen={}
|
||||
order_explored=[] # provide order and
|
||||
explored={} # fast search without more general priorityDictionary
|
||||
|
||||
if pref is None:
|
||||
def get_children(node):
|
||||
return G[node]
|
||||
else:
|
||||
def get_children(node):
|
||||
if pref.has_key(node):
|
||||
pref_set = set(pref[node])
|
||||
for i in G[node]:
|
||||
if i not in pref_set:
|
||||
yield i
|
||||
for i in reversed(pref[node]):
|
||||
yield i
|
||||
else:
|
||||
for i in G[node]:
|
||||
yield i
|
||||
|
||||
if nbunch is None:
|
||||
nbunch = range(len(G))
|
||||
for v in nbunch: # process all vertices in G
|
||||
if v in explored:
|
||||
continue
|
||||
fringe=[v] # nodes yet to look at
|
||||
while fringe:
|
||||
w=fringe[-1] # depth first search
|
||||
if w in explored: # already looked down this branch
|
||||
fringe.pop()
|
||||
continue
|
||||
seen[w]=1 # mark as seen
|
||||
# Check successors for cycles and for new nodes
|
||||
new_nodes=[]
|
||||
for n in get_children(w):
|
||||
if n not in explored:
|
||||
if n in seen: #CYCLE !!
|
||||
raise GraphError("Graph contains a cycle at %d (%s,%s)." % \
|
||||
(n, G[n], G.pred[n]))
|
||||
new_nodes.append(n)
|
||||
if new_nodes: # Add new_nodes to fringe
|
||||
fringe.extend(new_nodes)
|
||||
else: # No new nodes so w is fully explored
|
||||
explored[w]=1
|
||||
order_explored.append(w)
|
||||
fringe.pop() # done considering this node
|
||||
|
||||
order_explored.reverse() # reverse order explored
|
||||
return order_explored
|
||||
|
||||
def dag_shortest_paths(G, source):
|
||||
top_order = topological_sort(G)
|
||||
dist = [None] * len(G)
|
||||
dist[source] = 0
|
||||
for u in top_order:
|
||||
if dist[u] is None:
|
||||
continue
|
||||
for v in G[u]:
|
||||
if dist[v] is None or dist[v] > dist[u] + G.weights[(u,v)]:
|
||||
dist[v] = dist[u] + G.weights[(u,v)]
|
||||
return dist
|
||||
|
||||
def reverse_dag_shortest_paths(G, source):
|
||||
top_order = reversed(topological_sort(G))
|
||||
dist = [None] * len(G)
|
||||
dist[source] = 0
|
||||
for u in top_order:
|
||||
if u ==68273:
|
||||
print 'dist[68273]', dist[u]
|
||||
print 'pred[u]', G.pred[u]
|
||||
if dist[u] is None:
|
||||
continue
|
||||
for v in G.pred[u]:
|
||||
if dist[v] is None or dist[v] > dist[u] + G.weights[(v,u)]:
|
||||
dist[v] = dist[u] + G.weights[(v,u)]
|
||||
return dist
|
||||
|
||||
def single_source_longest_paths(G, source, reverse=False):
|
||||
# make weights negative, then do shortest paths
|
||||
for edge in G.weights:
|
||||
G.weights[edge] = -G.weights[edge]
|
||||
if reverse:
|
||||
dist = reverse_dag_shortest_paths(G, source)
|
||||
else:
|
||||
dist = dag_shortest_paths(G, source)
|
||||
#dist = johnson(G, sources)
|
||||
# reset weights
|
||||
for edge in G.weights:
|
||||
G.weights[edge] = -G.weights[edge]
|
||||
for i,n in enumerate(dist):
|
||||
if n is None:
|
||||
dist[i] = 0
|
||||
else:
|
||||
dist[i] = -dist[i]
|
||||
#for k, v in dist.iteritems():
|
||||
# dist[k] = -v
|
||||
return dist
|
||||
|
||||
|
||||
def longest_paths(G, sources=None):
|
||||
# make weights negative, then do shortest paths
|
||||
for edge in G.weights:
|
||||
G.weights[edge] = -G.weights[edge]
|
||||
dist = {}
|
||||
for source in sources:
|
||||
print ('%s, ' % source),
|
||||
dist[source] = dag_shortest_paths(G, source)
|
||||
#dist = johnson(G, sources)
|
||||
# reset weights
|
||||
for edge in G.weights:
|
||||
G.weights[edge] = -G.weights[edge]
|
||||
return dist
|
||||
1310
Compiler/instructions.py
Normal file
1310
Compiler/instructions.py
Normal file
File diff suppressed because it is too large
Load Diff
742
Compiler/instructions_base.py
Normal file
742
Compiler/instructions_base.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import itertools
|
||||
from random import randint
|
||||
import time
|
||||
import inspect
|
||||
import functools
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.config import *
|
||||
from Compiler import util
|
||||
|
||||
|
||||
###
|
||||
### Opcode constants
|
||||
###
|
||||
### Whenever these are changed the corresponding enums in Processor/instruction.h
|
||||
### MUST also be changed. (+ the documentation)
|
||||
###
|
||||
opcodes = dict(
|
||||
# Load/store
|
||||
LDI = 0x1,
|
||||
LDSI = 0x2,
|
||||
LDMC = 0x3,
|
||||
LDMS = 0x4,
|
||||
STMC = 0x5,
|
||||
STMS = 0x6,
|
||||
LDMCI = 0x7,
|
||||
LDMSI = 0x8,
|
||||
STMCI = 0x9,
|
||||
STMSI = 0xA,
|
||||
MOVC = 0xB,
|
||||
MOVS = 0xC,
|
||||
PROTECTMEMS = 0xD,
|
||||
PROTECTMEMC = 0xE,
|
||||
PROTECTMEMINT = 0xF,
|
||||
LDMINT = 0xCA,
|
||||
STMINT = 0xCB,
|
||||
LDMINTI = 0xCC,
|
||||
STMINTI = 0xCD,
|
||||
PUSHINT = 0xCE,
|
||||
POPINT = 0xCF,
|
||||
MOVINT = 0xD0,
|
||||
# Machine
|
||||
LDTN = 0x10,
|
||||
LDARG = 0x11,
|
||||
REQBL = 0x12,
|
||||
STARG = 0x13,
|
||||
TIME = 0x14,
|
||||
START = 0x15,
|
||||
STOP = 0x16,
|
||||
USE = 0x17,
|
||||
USE_INP = 0x18,
|
||||
RUN_TAPE = 0x19,
|
||||
JOIN_TAPE = 0x1A,
|
||||
CRASH = 0x1B,
|
||||
USE_PREP = 0x1C,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
ADDM = 0x22,
|
||||
ADDCI = 0x23,
|
||||
ADDSI = 0x24,
|
||||
SUBC = 0x25,
|
||||
SUBS = 0x26,
|
||||
SUBML = 0x27,
|
||||
SUBMR = 0x28,
|
||||
SUBCI = 0x29,
|
||||
SUBSI = 0x2A,
|
||||
SUBCFI = 0x2B,
|
||||
SUBSFI = 0x2C,
|
||||
# Multiplication/division
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
MULCI = 0x32,
|
||||
MULSI = 0x33,
|
||||
DIVC = 0x34,
|
||||
DIVCI = 0x35,
|
||||
MODC = 0x36,
|
||||
MODCI = 0x37,
|
||||
LEGENDREC = 0x38,
|
||||
GMULBITC = 0x136,
|
||||
GMULBITM = 0x137,
|
||||
# Open
|
||||
STARTOPEN = 0xA0,
|
||||
STOPOPEN = 0xA1,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
SQUARE = 0x52,
|
||||
INV = 0x53,
|
||||
GBITTRIPLE = 0x154,
|
||||
GBITGF2NTRIPLE = 0x155,
|
||||
INPUTMASK = 0x56,
|
||||
PREP = 0x57,
|
||||
# Input
|
||||
INPUT = 0x60,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
READSOCKETS = 0x64,
|
||||
WRITESOCKETC = 0x65,
|
||||
WRITESOCKETS = 0x66,
|
||||
OPENSOCKET = 0x67,
|
||||
CLOSESOCKET = 0x68,
|
||||
# Bitwise logic
|
||||
ANDC = 0x70,
|
||||
XORC = 0x71,
|
||||
ORC = 0x72,
|
||||
ANDCI = 0x73,
|
||||
XORCI = 0x74,
|
||||
ORCI = 0x75,
|
||||
NOTC = 0x76,
|
||||
# Bitwise shifts
|
||||
SHLC = 0x80,
|
||||
SHRC = 0x81,
|
||||
SHLCI = 0x82,
|
||||
SHRCI = 0x83,
|
||||
# Branching and comparison
|
||||
JMP = 0x90,
|
||||
JMPNZ = 0x91,
|
||||
JMPEQZ = 0x92,
|
||||
EQZC = 0x93,
|
||||
LTZC = 0x94,
|
||||
LTC = 0x95,
|
||||
GTC = 0x96,
|
||||
EQC = 0x97,
|
||||
JMPI = 0x98,
|
||||
# Integers
|
||||
LDINT = 0x9A,
|
||||
ADDINT = 0x9B,
|
||||
SUBINT = 0x9C,
|
||||
MULINT = 0x9D,
|
||||
DIVINT = 0x9E,
|
||||
# Conversion
|
||||
CONVINT = 0xC0,
|
||||
CONVMODP = 0xC1,
|
||||
GCONVGF2N = 0x1C1,
|
||||
# IO
|
||||
PRINTMEM = 0xB0,
|
||||
PRINTREG = 0XB1,
|
||||
RAND = 0xB2,
|
||||
PRINTREGPLAIN = 0xB3,
|
||||
PRINTCHR = 0xB4,
|
||||
PRINTSTR = 0xB5,
|
||||
PUBINPUT = 0xB6,
|
||||
RAWOUTPUT = 0xB7,
|
||||
STARTPRIVATEOUTPUT = 0xB8,
|
||||
STOPPRIVATEOUTPUT = 0xB9,
|
||||
PRINTCHRINT = 0xBA,
|
||||
PRINTSTRINT = 0xBB,
|
||||
GBITDEC = 0x184,
|
||||
GBITCOM = 0x185,
|
||||
)
|
||||
|
||||
|
||||
def int_to_bytes(x):
|
||||
""" 32 bit int to big-endian 4 byte conversion. """
|
||||
return [(x >> 8*i) % 256 for i in (3,2,1,0)]
|
||||
|
||||
|
||||
global_vector_size = 1
|
||||
global_vector_size_depth = 0
|
||||
global_instruction_type_stack = ['modp']
|
||||
|
||||
def set_global_vector_size(size):
|
||||
global global_vector_size, global_vector_size_depth
|
||||
if size == 1:
|
||||
return
|
||||
if global_vector_size == 1 or global_vector_size == size:
|
||||
global_vector_size = size
|
||||
global_vector_size_depth += 1
|
||||
else:
|
||||
raise CompilerError('Cannot set global vector size when already set')
|
||||
|
||||
def set_global_instruction_type(t):
|
||||
if t == 'modp' or t == 'gf2n':
|
||||
global_instruction_type_stack.append(t)
|
||||
else:
|
||||
raise CompilerError('Invalid type %s for setting global instruction type')
|
||||
|
||||
def reset_global_vector_size():
|
||||
global global_vector_size, global_vector_size_depth
|
||||
if global_vector_size_depth > 0:
|
||||
global_vector_size_depth -= 1
|
||||
if global_vector_size_depth == 0:
|
||||
global_vector_size = 1
|
||||
|
||||
def reset_global_instruction_type():
|
||||
global_instruction_type_stack.pop()
|
||||
|
||||
def get_global_vector_size():
|
||||
return global_vector_size
|
||||
|
||||
def get_global_instruction_type():
|
||||
return global_instruction_type_stack[-1]
|
||||
|
||||
|
||||
def vectorize(instruction, global_dict=None):
|
||||
""" Decorator to vectorize instructions. """
|
||||
|
||||
if global_dict is None:
|
||||
global_dict = inspect.getmodule(instruction).__dict__
|
||||
|
||||
class Vectorized_Instruction(instruction):
|
||||
__slots__ = ['size']
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
self.size = size
|
||||
super(Vectorized_Instruction, self).__init__(*args, **kwargs)
|
||||
for arg,f in zip(self.args, self.arg_format):
|
||||
if issubclass(ArgFormats[f], RegisterArgFormat):
|
||||
arg.set_size(size)
|
||||
def get_code(self):
|
||||
return (self.size << 9) + self.code
|
||||
def get_pre_arg(self):
|
||||
return "%d, " % self.size
|
||||
def is_vec(self):
|
||||
return self.size > 1
|
||||
def get_size(self):
|
||||
return self.size
|
||||
def expand(self):
|
||||
set_global_vector_size(self.size)
|
||||
super(Vectorized_Instruction, self).expand()
|
||||
reset_global_vector_size()
|
||||
|
||||
@functools.wraps(instruction)
|
||||
def maybe_vectorized_instruction(*args, **kwargs):
|
||||
if global_vector_size == 1:
|
||||
return instruction(*args, **kwargs)
|
||||
else:
|
||||
return Vectorized_Instruction(global_vector_size, *args, **kwargs)
|
||||
maybe_vectorized_instruction.vec_ins = Vectorized_Instruction
|
||||
maybe_vectorized_instruction.std_ins = instruction
|
||||
|
||||
vectorized_name = 'v' + instruction.__name__
|
||||
Vectorized_Instruction.__name__ = vectorized_name
|
||||
global_dict[vectorized_name] = Vectorized_Instruction
|
||||
global_dict[instruction.__name__ + '_class'] = instruction
|
||||
return maybe_vectorized_instruction
|
||||
|
||||
|
||||
def gf2n(instruction):
|
||||
""" Decorator to create GF_2^n instruction corresponding to a given
|
||||
modp instruction.
|
||||
|
||||
Adds the new GF_2^n instruction to the globals dictionary. Also adds a
|
||||
vectorized GF_2^n instruction if a modp version exists. """
|
||||
global_dict = inspect.getmodule(instruction).__dict__
|
||||
|
||||
if global_dict.has_key('v' + instruction.__name__):
|
||||
vectorized = True
|
||||
else:
|
||||
vectorized = False
|
||||
|
||||
if isinstance(instruction, type) and issubclass(instruction, Instruction):
|
||||
instruction_cls = instruction
|
||||
else:
|
||||
try:
|
||||
instruction_cls = global_dict[instruction.__name__ + '_class']
|
||||
except KeyError:
|
||||
raise CompilerError('Cannot decorate instruction %s' % instruction)
|
||||
|
||||
class GF2N_Instruction(instruction_cls):
|
||||
__doc__ = instruction_cls.__doc__.replace('c_', 'c^g_').replace('s_', 's^g_')
|
||||
__slots__ = []
|
||||
field_type = 'gf2n'
|
||||
if isinstance(instruction_cls.code, int):
|
||||
code = (1 << 8) + instruction_cls.code
|
||||
|
||||
# set modp registers in arg_format to GF2N registers
|
||||
if 'gf2n_arg_format' in instruction_cls.__dict__:
|
||||
arg_format = instruction_cls.gf2n_arg_format
|
||||
elif isinstance(instruction_cls.arg_format, itertools.repeat):
|
||||
__f = instruction_cls.arg_format.next()
|
||||
if __f != 'int' and __f != 'p':
|
||||
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
|
||||
else:
|
||||
__format = []
|
||||
for __f in instruction_cls.arg_format:
|
||||
if __f in ('int', 'p', 'ci', 'str'):
|
||||
__format.append(__f)
|
||||
else:
|
||||
__format.append(__f[0] + 'g' + __f[1:])
|
||||
arg_format = __format
|
||||
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
def expand(self):
|
||||
set_global_instruction_type('gf2n')
|
||||
super(GF2N_Instruction, self).expand()
|
||||
reset_global_instruction_type()
|
||||
|
||||
GF2N_Instruction.__name__ = 'g' + instruction_cls.__name__
|
||||
if vectorized:
|
||||
vec_GF2N = vectorize(GF2N_Instruction, global_dict)
|
||||
|
||||
@functools.wraps(instruction)
|
||||
def maybe_gf2n_instruction(*args, **kwargs):
|
||||
if get_global_instruction_type() == 'gf2n':
|
||||
if vectorized:
|
||||
return vec_GF2N(*args, **kwargs)
|
||||
else:
|
||||
return GF2N_Instruction(*args, **kwargs)
|
||||
else:
|
||||
return instruction(*args, **kwargs)
|
||||
|
||||
# If instruction is vectorized, new GF2N instruction must also be
|
||||
if vectorized:
|
||||
global_dict[GF2N_Instruction.__name__] = vec_GF2N
|
||||
else:
|
||||
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
|
||||
|
||||
global_dict[instruction.__name__ + '_class'] = instruction_cls
|
||||
return maybe_gf2n_instruction
|
||||
#return instruction
|
||||
|
||||
|
||||
class RegType(object):
|
||||
""" enum-like static class for Register types """
|
||||
ClearModp = 'c'
|
||||
SecretModp = 's'
|
||||
ClearGF2N = 'cg'
|
||||
SecretGF2N = 'sg'
|
||||
ClearInt = 'ci'
|
||||
|
||||
Types = [ClearModp, SecretModp, ClearGF2N, SecretGF2N, ClearInt]
|
||||
|
||||
@staticmethod
|
||||
def create_dict(init_value_fn):
|
||||
""" Create a dictionary with all the RegTypes as keys """
|
||||
return {
|
||||
RegType.ClearModp : init_value_fn(),
|
||||
RegType.SecretModp : init_value_fn(),
|
||||
RegType.ClearGF2N : init_value_fn(),
|
||||
RegType.SecretGF2N : init_value_fn(),
|
||||
RegType.ClearInt : init_value_fn(),
|
||||
}
|
||||
|
||||
class ArgFormat(object):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return NotImplemented
|
||||
|
||||
class RegisterArgFormat(ArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, program.curr_tape.Register):
|
||||
raise ArgumentError(arg, 'Invalid register argument')
|
||||
if arg.i > REG_MAX:
|
||||
raise ArgumentError(arg, 'Register index too large')
|
||||
if arg.program != program.curr_tape:
|
||||
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
|
||||
util.format_trace(arg.caller))
|
||||
if arg.reg_type != cls.reg_type:
|
||||
raise ArgumentError(arg, "Wrong register type '%s', expected '%s'" % \
|
||||
(arg.reg_type, cls.reg_type))
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return int_to_bytes(arg.i)
|
||||
|
||||
class ClearModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearModp
|
||||
|
||||
class SecretModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretModp
|
||||
|
||||
class ClearGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearGF2N
|
||||
|
||||
class SecretGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretGF2N
|
||||
|
||||
class ClearIntAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearInt
|
||||
|
||||
class IntArgFormat(ArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, (int, long)):
|
||||
raise ArgumentError(arg, 'Expected an integer-valued argument')
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return int_to_bytes(arg)
|
||||
|
||||
class ImmediateModpAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
super(ImmediateModpAF, cls).check(arg)
|
||||
if arg >= 2**31 or arg < -2**31:
|
||||
raise ArgumentError(arg, 'Immediate value outside of 32-bit range')
|
||||
|
||||
class ImmediateGF2NAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
# bounds checking for GF(2^n)???
|
||||
super(ImmediateGF2NAF, cls).check(arg)
|
||||
|
||||
class PlayerNoAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
super(PlayerNoAF, cls).check(arg)
|
||||
if arg > 256:
|
||||
raise ArgumentError(arg, 'Player number > 256')
|
||||
|
||||
class String(ArgFormat):
|
||||
length = 12
|
||||
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, str):
|
||||
raise ArgumentError(arg, 'Argument is not string')
|
||||
if len(arg) > cls.length:
|
||||
raise ArgumentError(arg, 'String longer than ' + cls.length)
|
||||
if '\0' in arg:
|
||||
raise ArgumentError(arg, 'String contains zero-byte')
|
||||
|
||||
@classmethod
|
||||
def encode(cls, arg):
|
||||
return arg + '\0' * (cls.length - len(arg))
|
||||
|
||||
ArgFormats = {
|
||||
'c': ClearModpAF,
|
||||
's': SecretModpAF,
|
||||
'cw': ClearModpAF,
|
||||
'sw': SecretModpAF,
|
||||
'cg': ClearGF2NAF,
|
||||
'sg': SecretGF2NAF,
|
||||
'cgw': ClearGF2NAF,
|
||||
'sgw': SecretGF2NAF,
|
||||
'ci': ClearIntAF,
|
||||
'ciw': ClearIntAF,
|
||||
'i': ImmediateModpAF,
|
||||
'ig': ImmediateGF2NAF,
|
||||
'int': IntArgFormat,
|
||||
'p': PlayerNoAF,
|
||||
'str': String,
|
||||
}
|
||||
|
||||
def format_str_is_reg(format_str):
|
||||
return issubclass(ArgFormats[format_str], RegisterArgFormat)
|
||||
|
||||
def format_str_is_writeable(format_str):
|
||||
return format_str_is_reg(format_str) and format_str[-1] == 'w'
|
||||
|
||||
|
||||
class Instruction(object):
|
||||
"""
|
||||
Base class for a RISC-type instruction. Has methods for checking arguments,
|
||||
getting byte encoding, emulating the instruction, etc.
|
||||
"""
|
||||
__slots__ = ['args', 'arg_format', 'code', 'caller']
|
||||
count = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
""" Create an instruction and append it to the program list. """
|
||||
self.args = list(args)
|
||||
self.check_args()
|
||||
if not program.FIRST_PASS:
|
||||
if kwargs.get('add_to_prog', True):
|
||||
program.curr_block.instructions.append(self)
|
||||
if program.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
else:
|
||||
self.caller = None
|
||||
if program.EMULATE:
|
||||
self.execute()
|
||||
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
print "Compiled %d lines at" % self.__class__.count, time.asctime()
|
||||
|
||||
def get_code(self):
|
||||
return self.code
|
||||
|
||||
def get_encoding(self):
|
||||
enc = int_to_bytes(self.get_code())
|
||||
# add the number of registers to a start/stop open instruction
|
||||
if self.has_var_args():
|
||||
enc += int_to_bytes(len(self.args))
|
||||
for arg,format in zip(self.args, self.arg_format):
|
||||
enc += ArgFormats[format].encode(arg)
|
||||
return enc
|
||||
|
||||
def get_bytes(self):
|
||||
return bytearray(self.get_encoding())
|
||||
|
||||
def execute(self):
|
||||
""" Emulate execution of this instruction """
|
||||
raise NotImplementedError('execute method must be implemented')
|
||||
|
||||
def check_args(self):
|
||||
""" Check the args match up with that specified in arg_format """
|
||||
for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)):
|
||||
if arg is None:
|
||||
if not isinstance(self.arg_format, (list, tuple)):
|
||||
break # end of optional arguments
|
||||
else:
|
||||
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
|
||||
try:
|
||||
ArgFormats[f].check(arg)
|
||||
except ArgumentError as e:
|
||||
raise CompilerError('Invalid argument "%s" to instruction: %s'
|
||||
% (e.arg, self) + '\n' + e.msg)
|
||||
|
||||
def get_used(self):
|
||||
""" Return the set of registers that are read in this instruction. """
|
||||
return set(arg for arg,w in zip(self.args, self.arg_format) if \
|
||||
format_str_is_reg(w) and not format_str_is_writeable(w))
|
||||
|
||||
def get_def(self):
|
||||
""" Return the set of registers that are written to in this instruction. """
|
||||
return set(arg for arg,w in zip(self.args, self.arg_format) if \
|
||||
format_str_is_writeable(w))
|
||||
|
||||
def get_pre_arg(self):
|
||||
return ""
|
||||
|
||||
def has_var_args(self):
|
||||
return False
|
||||
|
||||
def is_vec(self):
|
||||
return False
|
||||
|
||||
def is_gf2n(self):
|
||||
return False
|
||||
|
||||
def get_size(self):
|
||||
return 1
|
||||
|
||||
def add_usage(self, req_node):
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
|
||||
|
||||
###
|
||||
### Basic arithmetic
|
||||
###
|
||||
|
||||
class AddBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value + self.args[2].value) % program.P
|
||||
|
||||
class SubBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value - self.args[2].value) % program.P
|
||||
|
||||
class MulBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value * self.args[2].value) % program.P
|
||||
|
||||
###
|
||||
### Basic arithmetic with immediate values
|
||||
###
|
||||
|
||||
class ImmediateBase(Instruction):
|
||||
__slots__ = ['op']
|
||||
|
||||
def execute(self):
|
||||
exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op)
|
||||
|
||||
class SharedImmediate(ImmediateBase):
|
||||
__slots__ = []
|
||||
arg_format = ['sw', 's', 'i']
|
||||
|
||||
class ClearImmediate(ImmediateBase):
|
||||
__slots__ = []
|
||||
arg_format = ['cw', 'c', 'i']
|
||||
|
||||
|
||||
###
|
||||
### Memory access instructions
|
||||
###
|
||||
|
||||
class DirectMemoryInstruction(Instruction):
|
||||
__slots__ = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DirectMemoryInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
class ReadMemoryInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
class WriteMemoryInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
class DirectMemoryWriteInstruction(DirectMemoryInstruction, \
|
||||
WriteMemoryInstruction):
|
||||
__slots__ = []
|
||||
def __init__(self, *args, **kwargs):
|
||||
if program.curr_tape.prevent_direct_memory_write:
|
||||
raise CompilerError('Direct memory writing prevented')
|
||||
super(DirectMemoryWriteInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
###
|
||||
### I/O instructions
|
||||
###
|
||||
|
||||
class DoNotEliminateInstruction(Instruction):
|
||||
""" What do you think? """
|
||||
__slots__ = []
|
||||
|
||||
class IOInstruction(DoNotEliminateInstruction):
|
||||
""" Instruction that uses stdin/stdout during runtime. These are linked
|
||||
to prevent instruction reordering during optimization. """
|
||||
__slots__ = []
|
||||
|
||||
@classmethod
|
||||
def str_to_int(cls, s):
|
||||
""" Convert a 4 character string to an integer. """
|
||||
if len(s) > 4:
|
||||
raise CompilerError('String longer than 4 characters')
|
||||
n = 0
|
||||
for c in reversed(s.ljust(4)):
|
||||
n <<= 8
|
||||
n += ord(c)
|
||||
return n
|
||||
|
||||
class AsymmetricCommunicationInstruction(DoNotEliminateInstruction):
|
||||
""" Instructions involving sending from or to only one party. """
|
||||
__slots__ = []
|
||||
|
||||
class RawInputInstruction(AsymmetricCommunicationInstruction):
|
||||
""" Raw input instructions. """
|
||||
__slots__ = []
|
||||
|
||||
class PublicFileIOInstruction(DoNotEliminateInstruction):
|
||||
""" Instruction to reads/writes public information from/to files. """
|
||||
__slots__ = []
|
||||
|
||||
###
|
||||
### Data access instructions
|
||||
###
|
||||
|
||||
class DataInstruction(Instruction):
|
||||
__slots__ = []
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, self.data_type), self.get_size())
|
||||
|
||||
###
|
||||
### Integer operations
|
||||
###
|
||||
|
||||
class IntegerInstruction(Instruction):
|
||||
""" Base class for integer operations. """
|
||||
__slots__ = []
|
||||
arg_format = ['ciw', 'ci', 'ci']
|
||||
|
||||
###
|
||||
### Clear comparison instructions
|
||||
###
|
||||
|
||||
class UnaryComparisonInstruction(Instruction):
|
||||
""" Base class for unary comparisons. """
|
||||
__slots__ = []
|
||||
arg_format = ['ciw', 'ci']
|
||||
|
||||
###
|
||||
### Clear shift instructions
|
||||
###
|
||||
|
||||
class ClearShiftInstruction(ClearImmediate):
|
||||
__slots__ = []
|
||||
|
||||
def check_args(self):
|
||||
super(ClearShiftInstruction, self).check_args()
|
||||
if program.galois_length > 64:
|
||||
bits = 127
|
||||
else:
|
||||
# assume 64-bit machine
|
||||
bits = 63
|
||||
if self.args[2] > bits:
|
||||
raise CompilerError('Shifting by more than %d bits '
|
||||
'not implemented' % bits)
|
||||
|
||||
###
|
||||
### Jumps etc
|
||||
###
|
||||
|
||||
class dummywrite(Instruction):
|
||||
""" Dummy instruction to create source node in the dependency graph,
|
||||
preventing read-before-write warnings. """
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.arg_format = [arg.reg_type + 'w' for arg in args]
|
||||
super(dummywrite, self).__init__(*args, **kwargs)
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
def get_encoding(self):
|
||||
return []
|
||||
|
||||
class JumpInstruction(Instruction):
|
||||
__slots__ = ['jump_arg']
|
||||
|
||||
def set_relative_jump(self, value):
|
||||
if value == -1:
|
||||
raise CompilerException('Jump by -1 would cause infinite loop')
|
||||
self.args[self.jump_arg] = value
|
||||
|
||||
def get_relative_jump(self):
|
||||
return self.args[self.jump_arg]
|
||||
|
||||
|
||||
class CISC(Instruction):
|
||||
"""
|
||||
Base class for a CISC instruction.
|
||||
|
||||
Children must implement expand(self) to process the instruction.
|
||||
"""
|
||||
__slots__ = []
|
||||
code = None
|
||||
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
self.check_args()
|
||||
#if EMULATE:
|
||||
# self.expand()
|
||||
if not program.FIRST_PASS:
|
||||
self.expand()
|
||||
|
||||
def expand(self):
|
||||
""" Expand this into a sequence of RISC instructions. """
|
||||
raise NotImplementedError('expand method must be implemented')
|
||||
1115
Compiler/library.py
Normal file
1115
Compiler/library.py
Normal file
File diff suppressed because it is too large
Load Diff
902
Compiler/program.py
Normal file
902
Compiler/program.py
Normal file
@@ -0,0 +1,902 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.instructions_base import RegType
|
||||
import Compiler.instructions
|
||||
import Compiler.instructions_base
|
||||
import compilerLib
|
||||
import allocator as al
|
||||
import random
|
||||
import time
|
||||
import sys, os, errno
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
import itertools
|
||||
import math
|
||||
|
||||
|
||||
data_types = dict(
|
||||
triple = 0,
|
||||
square = 1,
|
||||
bit = 2,
|
||||
inverse = 3,
|
||||
bittriple = 4,
|
||||
bitgf2ntriple = 5
|
||||
)
|
||||
|
||||
field_types = dict(
|
||||
modp = 0,
|
||||
gf2n = 1,
|
||||
)
|
||||
|
||||
|
||||
class Program(object):
|
||||
""" A program consists of a list of tapes and a scheduled order
|
||||
of execution for these tapes.
|
||||
|
||||
These are created by executing a file containing appropriate instructions
|
||||
and threads. """
|
||||
def __init__(self, name, options, param=-1, assemblymode=False):
|
||||
self.options = options
|
||||
self.init_names(name, assemblymode)
|
||||
self.P = P_VALUES[param]
|
||||
self.param = param
|
||||
self.bit_length = BIT_LENGTHS[param]
|
||||
print 'Default bit length:', self.bit_length
|
||||
self.security = STAT_SEC[param]
|
||||
print 'Default security parameter:', self.security
|
||||
self.galois_length = int(options.galois)
|
||||
print 'Galois length:', self.galois_length
|
||||
self.schedule = [('start', [])]
|
||||
self.main_ctr = 0
|
||||
self.tapes = []
|
||||
self._curr_tape = None
|
||||
self.EMULATE = True # defaults
|
||||
self.FIRST_PASS = False
|
||||
self.DEBUG = False
|
||||
self.main_thread_running = False
|
||||
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
|
||||
self.free_mem_blocks = defaultdict(set)
|
||||
self.allocated_mem_blocks = {}
|
||||
self.req_num = None
|
||||
self.tape_stack = []
|
||||
self.n_threads = 1
|
||||
self.free_threads = set()
|
||||
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % name, 'w')
|
||||
Program.prog = self
|
||||
|
||||
self.reset_values()
|
||||
|
||||
def max_par_tapes(self):
|
||||
""" Upper bound on number of tapes that will be run in parallel.
|
||||
(Excludes empty tapes) """
|
||||
if self.n_threads > 1:
|
||||
if len(self.schedule) > 1:
|
||||
raise CompilerError('Static and dynamic parallelism not compatible')
|
||||
return self.n_threads
|
||||
res = 1
|
||||
running = defaultdict(lambda: 0)
|
||||
for action,tapes in self.schedule:
|
||||
tapes = [t[0] for t in tapes if not t[0].is_empty()]
|
||||
if action == 'start':
|
||||
for tape in tapes:
|
||||
running[tape] += 1
|
||||
elif action == 'stop':
|
||||
for tape in tapes:
|
||||
running[tape] -= 1
|
||||
else:
|
||||
raise CompilerError('Invalid schedule action')
|
||||
res = max(res, sum(running.itervalues()))
|
||||
return res
|
||||
|
||||
def init_names(self, name, assemblymode):
|
||||
# ignore path to file - source must be in Programs/Source
|
||||
if 'Programs' in os.listdir(os.getcwd()):
|
||||
# compile prog in ./Programs/Source directory
|
||||
self.programs_dir = os.getcwd() + '/Programs'
|
||||
else:
|
||||
# assume source is in main SPDZ directory
|
||||
self.programs_dir = sys.path[0] + '/Programs'
|
||||
print 'Compiling program in', self.programs_dir
|
||||
|
||||
# create extra directories if needed
|
||||
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
|
||||
if not os.path.exists(self.programs_dir + '/' + dirname):
|
||||
os.mkdir(self.programs_dir + '/' + dirname)
|
||||
|
||||
name = name.split('/')[-1]
|
||||
if name.endswith('.mpc'):
|
||||
self.name = name[:-4]
|
||||
else:
|
||||
self.name = name
|
||||
|
||||
if assemblymode:
|
||||
self.infile = self.programs_dir + '/Source/' + self.name + '.asm'
|
||||
else:
|
||||
self.infile = self.programs_dir + '/Source/' + self.name + '.mpc'
|
||||
|
||||
def new_tape(self, function, args=[], name=None):
|
||||
if name is None:
|
||||
name = function.__name__
|
||||
name = "%s-%s" % (self.name, name)
|
||||
# make sure there is a current tape
|
||||
self.curr_tape
|
||||
tape_index = len(self.tapes)
|
||||
name += "-%d" % tape_index
|
||||
self.tape_stack.append(self.curr_tape)
|
||||
self.curr_tape = Tape(name, self)
|
||||
self.curr_tape.prevent_direct_memory_write = True
|
||||
self.tapes.append(self.curr_tape)
|
||||
function(*args)
|
||||
self.finalize_tape(self.curr_tape)
|
||||
if self.tape_stack:
|
||||
self.curr_tape = self.tape_stack.pop()
|
||||
return tape_index
|
||||
|
||||
def run_tape(self, tape_index, arg):
|
||||
if self.curr_tape is not self.tapes[0]:
|
||||
raise CompilerError('Compiler does not support ' \
|
||||
'recursive spawning of threads')
|
||||
if self.free_threads:
|
||||
thread_number = self.free_threads.pop()
|
||||
else:
|
||||
thread_number = self.n_threads
|
||||
self.n_threads += 1
|
||||
self.curr_tape.start_new_basicblock(name='pre-run_tape')
|
||||
Compiler.instructions.run_tape(thread_number, arg, tape_index)
|
||||
self.curr_tape.start_new_basicblock(name='post-run_tape')
|
||||
self.curr_tape.req_node.children.append(self.tapes[tape_index].req_tree)
|
||||
return thread_number
|
||||
|
||||
def join_tape(self, thread_number):
|
||||
self.curr_tape.start_new_basicblock(name='pre-join_tape')
|
||||
Compiler.instructions.join_tape(thread_number)
|
||||
self.curr_tape.start_new_basicblock(name='post-join_tape')
|
||||
self.free_threads.add(thread_number)
|
||||
|
||||
def start_thread(self, thread, arg):
|
||||
if self.main_thread_running:
|
||||
# wait for main thread to finish
|
||||
self.schedule_wait(self.curr_tape)
|
||||
self.main_thread_running = False
|
||||
|
||||
# compile thread if not been used already
|
||||
if thread.tape not in self.tapes:
|
||||
self.curr_tape = thread.tape
|
||||
self.tapes.append(thread.tape)
|
||||
thread.target(*thread.args)
|
||||
|
||||
# add thread to schedule
|
||||
self.schedule_start(thread.tape, arg)
|
||||
self.curr_tape = None
|
||||
|
||||
def stop_thread(self, thread):
|
||||
tape = thread.tape
|
||||
self.schedule_wait(tape)
|
||||
|
||||
def update_req(self, tape):
|
||||
if self.req_num is None:
|
||||
self.req_num = tape.req_num
|
||||
else:
|
||||
self.req_num += tape.req_num
|
||||
|
||||
def read_memory(self, filename):
|
||||
""" Read the clear and shared memory from a file """
|
||||
f = open(filename)
|
||||
n = int(f.next())
|
||||
self.mem_c = [0]*n
|
||||
self.mem_s = [0]*n
|
||||
mem = self.mem_c
|
||||
done_c = False
|
||||
for line in f:
|
||||
line = line.split(' ')
|
||||
a = int(line[0])
|
||||
b = int(line[1])
|
||||
if a != -1:
|
||||
mem[a] = b
|
||||
elif done_c:
|
||||
break
|
||||
else:
|
||||
mem = self.mem_s
|
||||
done_c = True
|
||||
|
||||
def get_memory(self, mem_type, i):
|
||||
if mem_type == 'c':
|
||||
return self.mem_c[i]
|
||||
elif mem_type == 's':
|
||||
return self.mem_s[i]
|
||||
raise CompilerError('Invalid memory type')
|
||||
|
||||
def reset_values(self):
|
||||
""" Reset register and memory values. """
|
||||
for tape in self.tapes:
|
||||
tape.reset_registers()
|
||||
self.mem_c = range(USER_MEM + TMP_MEM)
|
||||
self.mem_s = range(USER_MEM + TMP_MEM)
|
||||
random.seed(0)
|
||||
|
||||
def write_bytes(self, outfile=None):
|
||||
""" Write all non-empty threads and schedule to files. """
|
||||
# runtime doesn't support 'new-style' parallelism yet
|
||||
old_style = True
|
||||
|
||||
nonempty_tapes = [t for t in self.tapes if not t.is_empty()]
|
||||
|
||||
sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name
|
||||
sch_file = open(sch_filename, 'w')
|
||||
print 'Writing to', sch_filename
|
||||
sch_file.write(str(self.max_par_tapes()) + '\n')
|
||||
sch_file.write(str(len(nonempty_tapes)) + '\n')
|
||||
sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n')
|
||||
|
||||
# assign tapes indices (needed for scheduler)
|
||||
for i,tape in enumerate(nonempty_tapes):
|
||||
tape.index = i
|
||||
|
||||
for sch in self.schedule:
|
||||
# schedule may still contain empty tapes: ignore these
|
||||
tapes = filter(lambda x: not x[0].is_empty(), sch[1])
|
||||
# no empty line
|
||||
if not tapes:
|
||||
continue
|
||||
line = ' '.join(str(t[0].index) +
|
||||
(':' + str(t[1]) if t[1] is not None else '') for t in tapes)
|
||||
if old_style:
|
||||
if sch[0] == 'start':
|
||||
sch_file.write('%d %s\n' % (len(tapes), line))
|
||||
else:
|
||||
sch_file.write('%s %d %s\n' % (tapes[0], len(tapes), line))
|
||||
|
||||
sch_file.write('0\n')
|
||||
sch_file.write(' '.join(sys.argv) + '\n')
|
||||
for tape in self.tapes:
|
||||
tape.write_bytes()
|
||||
|
||||
def schedule_start(self, tape, arg=None):
|
||||
""" Schedule the start of a thread. """
|
||||
if self.schedule[-1][0] == 'start':
|
||||
self.schedule[-1][1].append((tape, arg))
|
||||
else:
|
||||
self.schedule.append(('start', [(tape, arg)]))
|
||||
|
||||
def schedule_wait(self, tape):
|
||||
""" Schedule the end of a thread. """
|
||||
if self.schedule[-1][0] == 'stop':
|
||||
self.schedule[-1][1].append((tape, None))
|
||||
else:
|
||||
self.schedule.append(('stop', [(tape, None)]))
|
||||
self.finalize_tape(tape)
|
||||
self.update_req(tape)
|
||||
|
||||
def finalize_tape(self, tape):
|
||||
if not tape.purged:
|
||||
tape.optimize(self.options)
|
||||
tape.write_bytes()
|
||||
if self.options.asmoutfile:
|
||||
tape.write_str(self.options.asmoutfile + '-' + tape.name)
|
||||
tape.purge()
|
||||
|
||||
def emulate(self):
|
||||
""" Emulate execution of entire program. """
|
||||
self.reset_values()
|
||||
for sch in self.schedule:
|
||||
if sch[0] == 'start':
|
||||
for tape in sch[1]:
|
||||
self._curr_tape = tape
|
||||
for block in tape.basicblocks:
|
||||
for line in block.instructions:
|
||||
line.execute()
|
||||
|
||||
def restart_main_thread(self):
|
||||
if self.main_thread_running:
|
||||
# wait for main thread to finish
|
||||
self.schedule_wait(self._curr_tape)
|
||||
self.main_thread_running = False
|
||||
name = '%s-%d' % (self.name, self.main_ctr)
|
||||
self._curr_tape = Tape(name, self)
|
||||
self.tapes.append(self._curr_tape)
|
||||
self.main_ctr += 1
|
||||
# add to schedule
|
||||
self.schedule_start(self._curr_tape)
|
||||
self.main_thread_running = True
|
||||
|
||||
@property
|
||||
def curr_tape(self):
|
||||
""" The tape that is currently running."""
|
||||
if self._curr_tape is None:
|
||||
# Create a new main thread if necessary
|
||||
self.restart_main_thread()
|
||||
return self._curr_tape
|
||||
|
||||
@curr_tape.setter
|
||||
def curr_tape(self, value):
|
||||
self._curr_tape = value
|
||||
|
||||
@property
|
||||
def curr_block(self):
|
||||
""" The basic block that is currently being created. """
|
||||
return self.curr_tape.active_basicblock
|
||||
|
||||
def malloc(self, size, mem_type):
|
||||
""" Allocate memory from the top """
|
||||
if size == 0:
|
||||
return
|
||||
if isinstance(mem_type, type):
|
||||
mem_type = mem_type.reg_type
|
||||
key = size, mem_type
|
||||
if self.free_mem_blocks[key]:
|
||||
addr = self.free_mem_blocks[key].pop()
|
||||
else:
|
||||
addr = self.allocated_mem[mem_type]
|
||||
self.allocated_mem[mem_type] += size
|
||||
if len(str(addr)) != len(str(addr + size)):
|
||||
print "Memory of type '%s' now of size %d" % (mem_type, addr + size)
|
||||
self.allocated_mem_blocks[addr,mem_type] = size
|
||||
return addr
|
||||
|
||||
def free(self, addr, mem_type):
|
||||
""" Free memory """
|
||||
if self.curr_block.persistent_allocation:
|
||||
raise CompilerError('Cannot free memory within function block')
|
||||
size = self.allocated_mem_blocks.pop((addr,mem_type))
|
||||
self.free_mem_blocks[size,mem_type].add(addr)
|
||||
|
||||
def finalize_memory(self):
|
||||
import library
|
||||
self.curr_tape.start_new_basicblock(None, 'memory-usage')
|
||||
for mem_type,size in self.allocated_mem.items():
|
||||
if size:
|
||||
#print "Memory of type '%s' of size %d" % (mem_type, size)
|
||||
library.load_mem(size - 1, mem_type)
|
||||
|
||||
def public_input(self, x):
|
||||
self.public_input_file.write('%s\n' % str(x))
|
||||
|
||||
def set_bit_length(self, bit_length):
|
||||
self.bit_length = bit_length
|
||||
print 'Changed bit length for comparisons etc. to', bit_length
|
||||
|
||||
def set_security(self, security):
|
||||
self.security = security
|
||||
print 'Changed statistical security for comparison etc. to', security
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
def __init__(self, name, program, param=-1):
|
||||
""" Set prime p and the initial instructions and registers. """
|
||||
self.program = program
|
||||
self.init_names(name)
|
||||
self.P = P_VALUES[param]
|
||||
self.init_registers()
|
||||
self.req_tree = self.ReqNode(name)
|
||||
self.req_node = self.req_tree
|
||||
self.basicblocks = []
|
||||
self.purged = False
|
||||
self.active_basicblock = None
|
||||
self.start_new_basicblock()
|
||||
self._is_empty = False
|
||||
self.merge_opens = True
|
||||
self.if_states = []
|
||||
self.req_bit_length = defaultdict(lambda: 0)
|
||||
self.function_basicblocks = {}
|
||||
self.functions = []
|
||||
self.prevent_direct_memory_write = False
|
||||
|
||||
class BasicBlock(object):
|
||||
def __init__(self, parent, name, scope, exit_condition=None):
|
||||
self.parent = parent
|
||||
self.P = parent.P
|
||||
self.instructions = []
|
||||
self.name = name
|
||||
self.index = len(parent.basicblocks)
|
||||
self.open_queue = []
|
||||
self.exit_condition = exit_condition
|
||||
self.exit_block = None
|
||||
self.previous_block = None
|
||||
self.scope = scope
|
||||
self.children = []
|
||||
if scope is not None:
|
||||
scope.children.append(self)
|
||||
self.persistent_allocation = scope.persistent_allocation
|
||||
else:
|
||||
self.persistent_allocation = False
|
||||
|
||||
def new_reg(self, reg_type, size=None):
|
||||
return self.parent.new_reg(reg_type, size=size)
|
||||
|
||||
def set_return(self, previous_block, sub_block):
|
||||
self.previous_block = previous_block
|
||||
self.sub_block = sub_block
|
||||
|
||||
def adjust_return(self):
|
||||
offset = self.sub_block.get_offset(self)
|
||||
self.previous_block.return_address_store.args[1] = offset
|
||||
|
||||
def set_exit(self, condition, exit_true=None):
|
||||
""" Sets the block which we start from next, depending on the condition.
|
||||
|
||||
(Default is to go to next block in the list)
|
||||
"""
|
||||
self.exit_condition = condition
|
||||
self.exit_block = exit_true
|
||||
for reg in condition.get_used():
|
||||
reg.can_eliminate = False
|
||||
|
||||
def add_jump(self):
|
||||
""" Add the jump for this block's exit condition to list of
|
||||
instructions (must be done after merging) """
|
||||
self.instructions.append(self.exit_condition)
|
||||
|
||||
def get_offset(self, next_block):
|
||||
return next_block.offset - (self.offset + len(self.instructions))
|
||||
|
||||
def adjust_jump(self):
|
||||
""" Set the correct relative jump offset """
|
||||
offset = self.get_offset(self.exit_block)
|
||||
self.exit_condition.set_relative_jump(offset)
|
||||
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def is_empty(self):
|
||||
""" Returns True if the list of basic blocks is empty.
|
||||
|
||||
Note: False is returned even when tape only contains basic
|
||||
blocks with no instructions. However, these are removed when
|
||||
optimize is called. """
|
||||
if not self.purged:
|
||||
self._is_empty = (len(self.basicblocks) == 0)
|
||||
return self._is_empty
|
||||
|
||||
def start_new_basicblock(self, scope=False, name=''):
|
||||
# use False because None means no scope
|
||||
if scope is False:
|
||||
scope = self.active_basicblock
|
||||
suffix = '%s-%d' % (name, len(self.basicblocks))
|
||||
sub = self.BasicBlock(self, self.name + '-' + suffix, scope)
|
||||
self.basicblocks.append(sub)
|
||||
self.active_basicblock = sub
|
||||
self.req_node.add_block(sub)
|
||||
print 'Compiling basic block', sub.name
|
||||
|
||||
def init_registers(self):
|
||||
self.reset_registers()
|
||||
self.reg_counter = RegType.create_dict(lambda: 0)
|
||||
|
||||
def init_names(self, name):
|
||||
# ignore path to file - source must be in Programs/Source
|
||||
name = name.split('/')[-1]
|
||||
if name.endswith('.asm'):
|
||||
self.name = name[:-4]
|
||||
else:
|
||||
self.name = name
|
||||
self.infile = self.program.programs_dir + '/Source/' + self.name + '.asm'
|
||||
self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc'
|
||||
|
||||
def purge(self):
|
||||
self._is_empty = (len(self.basicblocks) == 0)
|
||||
del self.reg_values
|
||||
del self.basicblocks
|
||||
del self.active_basicblock
|
||||
self.purged = True
|
||||
|
||||
def unpurged(function):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.purged:
|
||||
print '%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name)
|
||||
return
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
@unpurged
|
||||
def optimize(self, options):
|
||||
if len(self.basicblocks) == 0:
|
||||
print 'Tape %s is empty' % self.name
|
||||
return
|
||||
|
||||
if self.if_states:
|
||||
raise CompilerError('Unclosed if/else blocks')
|
||||
|
||||
print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)
|
||||
|
||||
for block in self.basicblocks:
|
||||
al.determine_scope(block)
|
||||
|
||||
# merge open instructions
|
||||
# need to do this if there are several blocks
|
||||
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
|
||||
for i,block in enumerate(self.basicblocks):
|
||||
if len(block.instructions) > 0:
|
||||
print 'Processing basic block %s, %d/%d, %d instructions' % \
|
||||
(block.name, i, len(self.basicblocks), \
|
||||
len(block.instructions))
|
||||
# the next call is necessary for allocation later even without merging
|
||||
merger = al.Merger(block, options)
|
||||
if options.dead_code_elimination:
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Eliminate dead code...'
|
||||
merger.eliminate_dead_code()
|
||||
if options.merge_opens and self.merge_opens:
|
||||
if len(block.instructions) == 0:
|
||||
block.used_from_scope = set()
|
||||
block.defined_registers = set()
|
||||
continue
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Merging open instructions...'
|
||||
numrounds = merger.longest_paths_merge()
|
||||
if numrounds > 0:
|
||||
print 'Program requires %d rounds of communication' % numrounds
|
||||
numinv = sum(len(i.args) for i in block.instructions if isinstance(i, Compiler.instructions.startopen_class))
|
||||
if numinv > 0:
|
||||
print 'Program requires %d invocations' % numinv
|
||||
if options.dead_code_elimination:
|
||||
block.instructions = filter(lambda x: x is not None, block.instructions)
|
||||
if not (options.merge_opens and self.merge_opens):
|
||||
print 'Not merging open instructions in tape %s' % self.name
|
||||
|
||||
# add jumps
|
||||
offset = 0
|
||||
for block in self.basicblocks:
|
||||
if block.exit_condition is not None:
|
||||
block.add_jump()
|
||||
block.offset = offset
|
||||
offset += len(block.instructions)
|
||||
for block in self.basicblocks:
|
||||
if block.exit_block is not None:
|
||||
block.adjust_jump()
|
||||
if block.previous_block is not None:
|
||||
block.adjust_return()
|
||||
|
||||
# now remove any empty blocks (must be done after setting jumps)
|
||||
self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks)
|
||||
|
||||
# allocate registers
|
||||
reg_counts = self.count_regs()
|
||||
if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate:
|
||||
print 'Tape register usage:'
|
||||
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
|
||||
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
|
||||
print 'Re-allocating...'
|
||||
allocator = al.StraightlineAllocator(REG_MAX)
|
||||
def alloc_loop(block):
|
||||
for reg in block.used_from_scope:
|
||||
allocator.alloc_reg(reg, block.persistent_allocation)
|
||||
for child in block.children:
|
||||
if child.instructions:
|
||||
alloc_loop(child)
|
||||
for i,block in enumerate(reversed(self.basicblocks)):
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Allocating %s, %d/%d' % \
|
||||
(block.name, i, len(self.basicblocks))
|
||||
if block.exit_condition is not None:
|
||||
jump = block.exit_condition.get_relative_jump()
|
||||
if isinstance(jump, (int,long)) and jump < 0 and \
|
||||
block.exit_block.scope is not None:
|
||||
alloc_loop(block.exit_block.scope)
|
||||
allocator.process(block.instructions, block.persistent_allocation)
|
||||
|
||||
# offline data requirements
|
||||
print 'Compile offline data requirements...'
|
||||
self.req_num = self.req_tree.aggregate()
|
||||
print 'Tape requires', self.req_num
|
||||
for req,num in self.req_num.items():
|
||||
if num == float('inf'):
|
||||
num = -1
|
||||
if req[1] in data_types:
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use(field_types[req[0]], \
|
||||
data_types[req[1]], num, \
|
||||
add_to_prog=False))
|
||||
elif req[1] == 'input':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_inp(field_types[req[0]], \
|
||||
req[2], num, \
|
||||
add_to_prog=False))
|
||||
elif req[0] == 'modp':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_prep(req[1], num, \
|
||||
add_to_prog=False))
|
||||
elif req[0] == 'gf2n':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.guse_prep(req[1], num, \
|
||||
add_to_prog=False))
|
||||
|
||||
if not self.is_empty():
|
||||
# bit length requirement
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False))
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False))
|
||||
print 'Tape requires prime bit length', self.req_bit_length['p']
|
||||
print 'Tape requires galois bit length', self.req_bit_length['2']
|
||||
|
||||
@unpurged
|
||||
def _get_instructions(self):
|
||||
return itertools.chain.\
|
||||
from_iterable(b.instructions for b in self.basicblocks)
|
||||
|
||||
@unpurged
|
||||
def get_encoding(self):
|
||||
""" Get the encoding of the program, in human-readable format. """
|
||||
return [i.get_encoding() for i in self._get_instructions() if i is not None]
|
||||
|
||||
@unpurged
|
||||
def get_bytes(self):
|
||||
""" Get the byte encoding of the program as an actual string of bytes. """
|
||||
return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None)
|
||||
|
||||
@unpurged
|
||||
def write_encoding(self, filename):
|
||||
""" Write the readable encoding to a file. """
|
||||
print 'Writing to', filename
|
||||
f = open(filename, 'w')
|
||||
for line in self.get_encoding():
|
||||
f.write(str(line) + '\n')
|
||||
f.close()
|
||||
|
||||
@unpurged
|
||||
def write_str(self, filename):
|
||||
""" Write the sequence of instructions to a file. """
|
||||
print 'Writing to', filename
|
||||
f = open(filename, 'w')
|
||||
n = 0
|
||||
for block in self.basicblocks:
|
||||
if block.instructions:
|
||||
f.write('# %s\n' % block.name)
|
||||
for line in block.instructions:
|
||||
f.write('%s # %d\n' % (line, n))
|
||||
n += 1
|
||||
f.close()
|
||||
|
||||
@unpurged
|
||||
def write_bytes(self, filename=None):
|
||||
""" Write the program's byte encoding to a file. """
|
||||
if filename is None:
|
||||
filename = self.outfile
|
||||
if not filename.endswith('.bc'):
|
||||
filename += '.bc'
|
||||
if not 'Bytecode' in filename:
|
||||
filename = self.program.programs_dir + '/Bytecode/' + filename
|
||||
print 'Writing to', filename
|
||||
f = open(filename, 'w')
|
||||
f.write(self.get_bytes())
|
||||
f.close()
|
||||
|
||||
def new_reg(self, reg_type, size=None):
|
||||
return self.Register(reg_type, self, size=size)
|
||||
|
||||
def count_regs(self, reg_type=None):
|
||||
if reg_type is None:
|
||||
return self.reg_counter
|
||||
else:
|
||||
return self.reg_counter[reg_type]
|
||||
|
||||
def reset_registers(self):
|
||||
""" Reset register values to zero. """
|
||||
self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX)
|
||||
|
||||
def get_value(self, reg_type, i):
|
||||
return self.reg_values[reg_type][i]
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
class ReqNum(defaultdict):
|
||||
def __init__(self, init={}):
|
||||
super(Tape.ReqNum, self).__init__(lambda: 0, init)
|
||||
def __add__(self, other):
|
||||
res = Tape.ReqNum()
|
||||
for i,count in self.items():
|
||||
res[i] += count
|
||||
for i,count in other.items():
|
||||
res[i] += count
|
||||
return res
|
||||
def __mul__(self, other):
|
||||
res = Tape.ReqNum()
|
||||
for i in self:
|
||||
res[i] = other * self[i]
|
||||
return res
|
||||
__rmul__ = __mul__
|
||||
def set_all(self, value):
|
||||
res = Tape.ReqNum()
|
||||
for i in self:
|
||||
res[i] = value
|
||||
return res
|
||||
def max(self, other):
|
||||
res = Tape.ReqNum()
|
||||
for i in self:
|
||||
res[i] = max(self[i], other[i])
|
||||
for i in other:
|
||||
res[i] = max(self[i], other[i])
|
||||
return res
|
||||
def cost(self):
|
||||
return sum(num * COST[req[0]][req[1]] for req,num in self.items() \
|
||||
if req[1] != 'input')
|
||||
def __str__(self):
|
||||
return ", ".join('%s inputs in %s from player %d' \
|
||||
% (num, req[0], req[2]) \
|
||||
if req[1] == 'input' \
|
||||
else '%s %ss in %s' % (num, req[1], req[0]) \
|
||||
for req,num in self.items())
|
||||
def __repr__(self):
|
||||
return repr(dict(self))
|
||||
|
||||
class ReqNode(object):
|
||||
__slots__ = ['num', 'children', 'name', 'blocks']
|
||||
def __init__(self, name):
|
||||
self.children = []
|
||||
self.name = name
|
||||
self.blocks = []
|
||||
def aggregate(self, *args):
|
||||
self.num = Tape.ReqNum()
|
||||
for block in self.blocks:
|
||||
for inst in block.instructions:
|
||||
inst.add_usage(self)
|
||||
res = reduce(lambda x,y: x + y.aggregate(self.name),
|
||||
self.children, self.num)
|
||||
return res
|
||||
def increment(self, data_type, num=1):
|
||||
self.num[data_type] += num
|
||||
def add_block(self, block):
|
||||
self.blocks.append(block)
|
||||
|
||||
class ReqChild(object):
|
||||
__slots__ = ['aggregator', 'nodes', 'parent']
|
||||
def __init__(self, aggregator, parent):
|
||||
self.aggregator = aggregator
|
||||
self.nodes = []
|
||||
self.parent = parent
|
||||
def aggregate(self, name):
|
||||
res = self.aggregator([node.aggregate() for node in self.nodes])
|
||||
return res
|
||||
def add_node(self, tape, name):
|
||||
new_node = Tape.ReqNode(name)
|
||||
self.nodes.append(new_node)
|
||||
tape.req_node = new_node
|
||||
|
||||
def open_scope(self, aggregator, scope=False, name=''):
|
||||
child = self.ReqChild(aggregator, self.req_node)
|
||||
self.req_node.children.append(child)
|
||||
child.add_node(self, '%s-%d' % (name, len(self.basicblocks)))
|
||||
self.start_new_basicblock(name=name)
|
||||
return child
|
||||
|
||||
def close_scope(self, outer_scope, parent_req_node, name):
|
||||
self.req_node = parent_req_node
|
||||
self.start_new_basicblock(outer_scope, name)
|
||||
|
||||
def require_bit_length(self, bit_length, t='p'):
|
||||
if t == 'p':
|
||||
self.req_bit_length[t] = max(bit_length + 1, \
|
||||
self.req_bit_length[t])
|
||||
if self.program.param != -1 and bit_length >= self.program.param:
|
||||
raise CompilerError('Inadequate bit length %d for prime, ' \
|
||||
'program requires %d bits' % \
|
||||
(self.program.param, self.req_bit_length['p']))
|
||||
else:
|
||||
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
|
||||
|
||||
class Register(object):
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
based on the block's reg_counter dictionary.
|
||||
|
||||
The 'value' property is for emulation.
|
||||
"""
|
||||
__slots__ = ["reg_type", "program", "i", "value", "_is_active", \
|
||||
"size", "vector", "vectorbase", "caller", \
|
||||
"can_eliminate"]
|
||||
|
||||
def __init__(self, reg_type, program, value=None, size=None, i=None):
|
||||
""" Creates a new register.
|
||||
reg_type must be one of those defined in RegType. """
|
||||
if Compiler.instructions_base.get_global_instruction_type() == 'gf2n':
|
||||
if reg_type == RegType.ClearModp:
|
||||
reg_type = RegType.ClearGF2N
|
||||
elif reg_type == RegType.SecretModp:
|
||||
reg_type = RegType.SecretGF2N
|
||||
self.reg_type = reg_type
|
||||
self.program = program
|
||||
if size is None:
|
||||
size = Compiler.instructions_base.get_global_vector_size()
|
||||
self.size = size
|
||||
if i:
|
||||
self.i = i
|
||||
else:
|
||||
self.i = program.reg_counter[reg_type]
|
||||
program.reg_counter[reg_type] += size
|
||||
self.vector = []
|
||||
self.vectorbase = self
|
||||
if value is not None:
|
||||
self.value = value
|
||||
self._is_active = False
|
||||
self.can_eliminate = True
|
||||
if Program.prog.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
else:
|
||||
self.caller = None
|
||||
if self.i % 1000000 == 0 and self.i > 0:
|
||||
print "Initialized %d registers at" % self.i, time.asctime()
|
||||
|
||||
def set_size(self, size):
|
||||
if self.size == size:
|
||||
return
|
||||
elif self.size == 1 and self.vectorbase is self:
|
||||
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
|
||||
# create vector register in assembly mode
|
||||
self.size = size
|
||||
self.vector = [self]
|
||||
for i in range(1,size):
|
||||
reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)]
|
||||
reg.set_vectorbase(self)
|
||||
self.vector.append(reg)
|
||||
else:
|
||||
raise CompilerError('Cannot find %s in VARS' % str(self))
|
||||
else:
|
||||
raise CompilerError('Cannot reset size of vector register')
|
||||
|
||||
def set_vectorbase(self, vectorbase):
|
||||
if self.vectorbase != self:
|
||||
raise CompilerError('Cannot assign one register' \
|
||||
'to several vectors')
|
||||
self.vectorbase = vectorbase
|
||||
|
||||
def create_vector_elements(self):
|
||||
if self.vector:
|
||||
return
|
||||
elif self.size == 1:
|
||||
self.vector = [self]
|
||||
return
|
||||
self.vector = [self]
|
||||
for i in range(1,self.size):
|
||||
reg = Tape.Register(self.reg_type, self.program, size=1, i=self.i+i)
|
||||
reg.set_vectorbase(self)
|
||||
self.vector.append(reg)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if not self.vector:
|
||||
self.create_vector_elements()
|
||||
return self.vector[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def activate(self):
|
||||
""" Activating a register signals that it will at some point be used
|
||||
in the program.
|
||||
|
||||
Inactive registers are reserved for temporaries for CISC instructions. """
|
||||
if not self._is_active:
|
||||
self._is_active = True
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.program.reg_values[self.reg_type][self.i]
|
||||
|
||||
@value.setter
|
||||
def value(self, val):
|
||||
while (len(self.program.reg_values[self.reg_type]) <= self.i):
|
||||
self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX
|
||||
self.program.reg_values[self.reg_type][self.i] = val
|
||||
|
||||
@property
|
||||
def is_active(self):
|
||||
return self._is_active
|
||||
|
||||
@property
|
||||
def is_gf2n(self):
|
||||
return self.reg_type == RegType.ClearGF2N or \
|
||||
self.reg_type == RegType.SecretGF2N
|
||||
|
||||
@property
|
||||
def is_clear(self):
|
||||
return self.reg_type == RegType.ClearModp or \
|
||||
self.reg_type == RegType.ClearGF2N or \
|
||||
self.reg_type == RegType.ClearInt
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
__repr__ = __str__
|
||||
9
Compiler/tools.py
Normal file
9
Compiler/tools.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import itertools
|
||||
|
||||
class chain(object):
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
def __iter__(self):
|
||||
return itertools.chain(*self.args)
|
||||
2013
Compiler/types.py
Normal file
2013
Compiler/types.py
Normal file
File diff suppressed because it is too large
Load Diff
105
Compiler/util.py
Normal file
105
Compiler/util.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
def format_trace(trace, prefix=' '):
|
||||
if trace is None:
|
||||
return '<omitted>'
|
||||
else:
|
||||
return ''.join('\n%sFile "%s", line %s, in %s\n%s %s' %
|
||||
(prefix,i[0],i[1],i[2],prefix,i[3][0].strip()) \
|
||||
for i in reversed(trace))
|
||||
|
||||
def tuplify(x):
|
||||
if isinstance(x, (list, tuple)):
|
||||
return tuple(x)
|
||||
else:
|
||||
return (x,)
|
||||
|
||||
def untuplify(x):
|
||||
if len(x) == 1:
|
||||
return x[0]
|
||||
else:
|
||||
return x
|
||||
|
||||
def greater_than(a, b, bits):
|
||||
if isinstance(a, int) and isinstance(b, int):
|
||||
return a > b
|
||||
else:
|
||||
return a.greater_than(b, bits)
|
||||
|
||||
def pow2(a, bits):
|
||||
if isinstance(a, int):
|
||||
return 2**a
|
||||
else:
|
||||
return a.pow2(bits)
|
||||
|
||||
def mod2m(a, b, bits, signed):
|
||||
if isinstance(a, int):
|
||||
return a % 2**b
|
||||
else:
|
||||
return a.mod2m(b, bits, signed=signed)
|
||||
|
||||
def right_shift(a, b, bits):
|
||||
if isinstance(a, int):
|
||||
return a >> b
|
||||
else:
|
||||
return a.right_shift(b, bits)
|
||||
|
||||
def bit_decompose(a, bits):
|
||||
if isinstance(a, (int,long)):
|
||||
return [int((a >> i) & 1) for i in range(bits)]
|
||||
else:
|
||||
return a.bit_decompose(bits)
|
||||
|
||||
def bit_compose(bits):
|
||||
return sum(b << i for i,b in enumerate(bits))
|
||||
|
||||
def series(a):
|
||||
sum = 0
|
||||
for i in a:
|
||||
yield sum
|
||||
sum += i
|
||||
yield sum
|
||||
|
||||
def if_else(cond, a, b):
|
||||
try:
|
||||
if isinstance(cond, (bool, int)):
|
||||
if cond:
|
||||
return a
|
||||
else:
|
||||
return b
|
||||
return cond.if_else(a, b)
|
||||
except:
|
||||
print cond, a, b
|
||||
raise
|
||||
|
||||
def cond_swap(cond, a, b):
|
||||
if isinstance(cond, (bool, int)):
|
||||
if cond:
|
||||
return a, b
|
||||
else:
|
||||
return b, a
|
||||
return cond.cond_swap(a, b)
|
||||
|
||||
def log2(x):
|
||||
#print 'Compute log2 of', x
|
||||
return int(math.ceil(math.log(x, 2)))
|
||||
|
||||
def tree_reduce(function, sequence):
|
||||
n = len(sequence)
|
||||
if n == 1:
|
||||
return sequence[0]
|
||||
else:
|
||||
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)]
|
||||
return tree_reduce(function, reduced + sequence[n/2*2:])
|
||||
|
||||
def or_op(a, b):
|
||||
return a + b - a * b
|
||||
|
||||
OR = or_op
|
||||
|
||||
def pow2(bits):
|
||||
powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)]
|
||||
return tree_reduce(operator.mul, powers)
|
||||
162
Exceptions/Exceptions.h
Normal file
162
Exceptions/Exceptions.h
Normal file
@@ -0,0 +1,162 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Exceptions
|
||||
#define _Exceptions
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
using namespace std;
|
||||
|
||||
class not_implemented: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Case not implemented"; }
|
||||
};
|
||||
class division_by_zero: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Division by zero"; }
|
||||
};
|
||||
class invalid_plaintext: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Inconsistent plaintext space"; }
|
||||
};
|
||||
class rep_mismatch: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Representation mismatch"; }
|
||||
};
|
||||
class pr_mismatch: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Prime mismatch"; }
|
||||
};
|
||||
class params_mismatch: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "FHE params mismatch"; }
|
||||
};
|
||||
class field_mismatch: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Plaintext Field mismatch"; }
|
||||
};
|
||||
class level_mismatch: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Level mismatch"; }
|
||||
};
|
||||
class invalid_length: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Invalid length"; }
|
||||
};
|
||||
class invalid_commitment: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Invalid Commitment"; }
|
||||
};
|
||||
class IO_Error: public exception
|
||||
{ string msg, ans;
|
||||
public:
|
||||
IO_Error(string m) : msg(m)
|
||||
{ ans="IO-Error : ";
|
||||
ans+=msg;
|
||||
}
|
||||
~IO_Error()throw() { }
|
||||
virtual const char* what() const throw()
|
||||
{
|
||||
return ans.c_str();
|
||||
}
|
||||
};
|
||||
class broadcast_invalid: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Inconsistent broadcast at some point"; }
|
||||
};
|
||||
class bad_keygen: public exception
|
||||
{ string msg;
|
||||
public:
|
||||
bad_keygen(string m) : msg(m) {}
|
||||
~bad_keygen()throw() { }
|
||||
virtual const char* what() const throw()
|
||||
{ string ans="KeyGen has gone wrong: "+msg;
|
||||
return ans.c_str();
|
||||
}
|
||||
};
|
||||
class bad_enccommit: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Error in EncCommit"; }
|
||||
};
|
||||
class invalid_params: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Invalid Params"; }
|
||||
};
|
||||
class bad_value: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Some value is wrong somewhere"; }
|
||||
};
|
||||
class Offline_Check_Error: public exception
|
||||
{ string msg;
|
||||
public:
|
||||
Offline_Check_Error(string m) : msg(m) {}
|
||||
~Offline_Check_Error()throw() { }
|
||||
virtual const char* what() const throw()
|
||||
{ string ans="Offline-Check-Error : ";
|
||||
ans+=msg;
|
||||
return ans.c_str();
|
||||
}
|
||||
};
|
||||
class mac_fail: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "MacCheck Failure"; }
|
||||
};
|
||||
class invalid_program: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Invalid Program"; }
|
||||
};
|
||||
class file_error: public exception
|
||||
{ string filename, ans;
|
||||
public:
|
||||
file_error(string m="") : filename(m)
|
||||
{
|
||||
ans="File Error : ";
|
||||
ans+=filename;
|
||||
}
|
||||
~file_error()throw() { }
|
||||
virtual const char* what() const throw()
|
||||
{
|
||||
return ans.c_str();
|
||||
}
|
||||
};
|
||||
class end_of_file: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "End of file reached"; }
|
||||
};
|
||||
class Processor_Error: public exception
|
||||
{ string msg;
|
||||
public:
|
||||
Processor_Error(string m)
|
||||
{
|
||||
msg = "Processor-Error : " + m;
|
||||
}
|
||||
~Processor_Error()throw() { }
|
||||
virtual const char* what() const throw()
|
||||
{
|
||||
return msg.c_str();
|
||||
}
|
||||
};
|
||||
class max_mod_sz_too_small : public exception
|
||||
{ int len;
|
||||
public:
|
||||
max_mod_sz_too_small(int len) : len(len) {}
|
||||
~max_mod_sz_too_small() throw() {}
|
||||
virtual const char* what() const throw()
|
||||
{ stringstream out;
|
||||
out << "MAX_MOD_SZ too small for desired bit length of p, "
|
||||
<< "must be at least ceil(len(p)/len(word))+1, "
|
||||
<< "in this case: " << len;
|
||||
return out.str().c_str();
|
||||
}
|
||||
};
|
||||
class crash_requested: public exception
|
||||
{ virtual const char* what() const throw()
|
||||
{ return "Crash requested by program"; }
|
||||
};
|
||||
class memory_exception : public exception {};
|
||||
class how_would_that_work : public exception {};
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
554
Fake-Offline.cpp
Normal file
554
Fake-Offline.cpp
Normal file
@@ -0,0 +1,554 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Share.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Auth/fake-stuff.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include "Math/Setup.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Tools/mkpath.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
|
||||
string prep_data_prefix;
|
||||
|
||||
/* N = Number players
|
||||
* ntrip = Number triples needed
|
||||
* str = "2" or "p"
|
||||
*/
|
||||
template<class T>
|
||||
void make_mult_triples(const T& key,int N,int ntrip,const string& str,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
T a,b,c;
|
||||
vector<Share<T> > Sa(N),Sb(N),Sc(N);
|
||||
/* Generate Triples */
|
||||
for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << "Triples-" << str << "-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
if (!zero)
|
||||
a.randomize(G);
|
||||
make_share(Sa,a,N,key,G);
|
||||
if (!zero)
|
||||
b.randomize(G);
|
||||
make_share(Sb,b,N,key,G);
|
||||
c.mul(a,b);
|
||||
make_share(Sc,c,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false);
|
||||
Sb[j].output(outf[j],false);
|
||||
Sc[j].output(outf[j],false);
|
||||
}
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
gf2n a,b,c, one;
|
||||
one.assign_one();
|
||||
vector<Share<gf2n> > Sa(N),Sb(N),Sc(N);
|
||||
/* Generate Triples */
|
||||
for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << Data_Files::dtype_names[dtype] << "-2-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
if (!zero)
|
||||
a.randomize(G);
|
||||
a.AND(a, one);
|
||||
make_share(Sa,a,N,key,G);
|
||||
if (!zero)
|
||||
b.randomize(G);
|
||||
if (dtype == DATA_BITTRIPLE)
|
||||
b.AND(b, one);
|
||||
make_share(Sb,b,N,key,G);
|
||||
c.mul(a,b);
|
||||
make_share(Sc,c,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false);
|
||||
Sb[j].output(outf[j],false);
|
||||
Sc[j].output(outf[j],false);
|
||||
}
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
|
||||
/* N = Number players
|
||||
* ntrip = Number tuples needed
|
||||
* str = "2" or "p"
|
||||
*/
|
||||
template<class T>
|
||||
void make_square_tuples(const T& key,int N,int ntrip,const string& str,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
T a,c;
|
||||
vector<Share<T> > Sa(N),Sc(N);
|
||||
/* Generate Squares */
|
||||
for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << "Squares-" << str << "-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
if (!zero)
|
||||
a.randomize(G);
|
||||
make_share(Sa,a,N,key,G);
|
||||
c.mul(a,a);
|
||||
make_share(Sc,c,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false);
|
||||
Sc[j].output(outf[j],false);
|
||||
}
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
/* N = Number players
|
||||
* ntrip = Number bits needed
|
||||
* str = "2" or "p"
|
||||
*/
|
||||
template<class T>
|
||||
void make_bits(const T& key,int N,int ntrip,const string& str,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
T a;
|
||||
vector<Share<T> > Sa(N);
|
||||
/* Generate Bits */
|
||||
for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << "Bits-" << str << "-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{ if ((G.get_uchar()&1)==0 || zero) { a.assign_zero(); }
|
||||
else { a.assign_one(); }
|
||||
make_share(Sa,a,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false); }
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
|
||||
/* N = Number players
|
||||
* ntrip = Number inputs needed
|
||||
* str = "2" or "p"
|
||||
*
|
||||
*/
|
||||
template<class T>
|
||||
void make_inputs(const T& key,int N,int ntrip,const string& str,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
T a;
|
||||
vector<Share<T> > Sa(N);
|
||||
/* Generate Inputs */
|
||||
for (int player=0; player<N; player++)
|
||||
{ for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << "Inputs-" << str << "-P" << i << "-" << player;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
if (!zero)
|
||||
a.randomize(G);
|
||||
make_share(Sa,a,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false);
|
||||
if (j==player)
|
||||
{ a.output(outf[j],false); }
|
||||
}
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
}
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
|
||||
/* N = Number players
|
||||
* ntrip = Number inverses needed
|
||||
* str = "2" or "p"
|
||||
*/
|
||||
template<class T>
|
||||
void make_inverse(const T& key,int N,int ntrip,bool zero)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
|
||||
ofstream* outf=new ofstream[N];
|
||||
T a,b;
|
||||
vector<Share<T> > Sa(N),Sb(N);
|
||||
/* Generate Triples */
|
||||
for (int i=0; i<N; i++)
|
||||
{ stringstream filename;
|
||||
filename << prep_data_prefix << "Inverses-" << T::type_char() << "-P" << i;
|
||||
cout << "Opening " << filename.str() << endl;
|
||||
outf[i].open(filename.str().c_str(),ios::out | ios::binary);
|
||||
if (outf[i].fail()) { throw file_error(filename.str().c_str()); }
|
||||
}
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
if (zero)
|
||||
// ironic?
|
||||
a.assign_one();
|
||||
else
|
||||
do
|
||||
a.randomize(G);
|
||||
while (a.is_zero());
|
||||
make_share(Sa,a,N,key,G);
|
||||
b=a; b.invert();
|
||||
make_share(Sb,b,N,key,G);
|
||||
for (int j=0; j<N; j++)
|
||||
{ Sa[j].output(outf[j],false);
|
||||
Sb[j].output(outf[j],false);
|
||||
}
|
||||
}
|
||||
for (int i=0; i<N; i++)
|
||||
{ outf[i].close(); }
|
||||
delete[] outf;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void make_PreMulC(const T& key, int N, int ntrip, bool zero)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << prep_data_prefix << "PreMulC-" << T::type_char();
|
||||
Files<T> files(N, key, ss.str());
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
T a, b, c;
|
||||
c = 1;
|
||||
for (int i=0; i<ntrip; i++)
|
||||
{
|
||||
// close the circle
|
||||
if (i == ntrip - 1 || zero)
|
||||
a.assign_one();
|
||||
else
|
||||
do
|
||||
a.randomize(G);
|
||||
while (a.is_zero());
|
||||
files.output_shares(a);
|
||||
b = a;
|
||||
b.invert();
|
||||
files.output_shares(b);
|
||||
files.output_shares(a * c);
|
||||
c = b;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
|
||||
opt.syntax = "./Fake-Offline.x <nplayers> [OPTIONS]\n\nOptions with 2 arguments take the form '-X <#gf2n tuples>,<#modp tuples>'";
|
||||
opt.example = "./Fake-Offline.x 2 -lgp 128 -lg2 128 --default 10000\n./Fake-Offline.x 3 -trip 50000,10000 -btrip 100000\n";
|
||||
|
||||
opt.add(
|
||||
"128", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(p) field (default: 128)", // Help description.
|
||||
"-lgp", // Flag token.
|
||||
"--lgp" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"40", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(2^n) field (default: 40)", // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"1000", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Default number of tuples to generate for ALL data types (default: 1000)", // Help description.
|
||||
"-d", // Flag token.
|
||||
"--default" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
2, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Number of triples, for gf2n / modp types", // Help description.
|
||||
"-trip", // Flag token.
|
||||
"--ntriples" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
2, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Number of random bits, for gf2n / modp types", // Help description.
|
||||
"-bit", // Flag token.
|
||||
"--nbits" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
2, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Number of input tuples, for gf2n / modp types", // Help description.
|
||||
"-inp", // Flag token.
|
||||
"--ninputs" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
2, // Number of args expected.
|
||||
',', // Delimiter if expecting multiple args.
|
||||
"Number of square tuples, for gf2n / modp types", // Help description.
|
||||
"-sq", // Flag token.
|
||||
"--nsquares" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of inverse tuples (modp only)", // Help description.
|
||||
"-inv", // Flag token.
|
||||
"--ninverses" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of GF(2) triples", // Help description.
|
||||
"-btrip", // Flag token.
|
||||
"--nbittriples" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of GF(2) x GF(2^n) triples", // Help description.
|
||||
"-mixed", // Flag token.
|
||||
"--nbitgf2ntriples" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Set all values to zero, but not the shares", // Help description.
|
||||
"-z", // Flag token.
|
||||
"--zero" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
|
||||
vector<string> badOptions;
|
||||
string usage;
|
||||
unsigned int i;
|
||||
if(!opt.gotRequired(badOptions))
|
||||
{
|
||||
for (i=0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if(!opt.gotExpected(badOptions))
|
||||
{
|
||||
for(i=0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int nplayers;
|
||||
if (opt.firstArgs.size() == 2)
|
||||
{
|
||||
nplayers = atoi(opt.firstArgs[1]->c_str());
|
||||
}
|
||||
else if (opt.lastArgs.size() == 1)
|
||||
{
|
||||
nplayers = atoi(opt.lastArgs[0]->c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "ERROR: invalid number of arguments\n";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int default_num = 0;
|
||||
int ntrip2=0, ntripp=0, nbits2=0,nbitsp=0,nsqr2=0,nsqrp=0,ninp2=0,ninpp=0,ninv=0, nbittrip=0, nbitgf2ntrip=0;
|
||||
vector<int> list_options;
|
||||
int lg2, lgp;
|
||||
|
||||
opt.get("--lgp")->getInt(lgp);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
|
||||
opt.get("--default")->getInt(default_num);
|
||||
ntrip2 = ntripp = nbits2 = nbitsp = nsqr2 = nsqrp = ninp2 = ninpp = ninv =
|
||||
nbittrip = nbitgf2ntrip = default_num;
|
||||
|
||||
if (opt.isSet("--ntriples"))
|
||||
{
|
||||
opt.get("--ntriples")->getInts(list_options);
|
||||
ntrip2 = list_options[0];
|
||||
ntripp = list_options[1];
|
||||
}
|
||||
if (opt.isSet("--nbits"))
|
||||
{
|
||||
opt.get("--nbits")->getInts(list_options);
|
||||
nbits2 = list_options[0];
|
||||
nbitsp = list_options[1];
|
||||
}
|
||||
if (opt.isSet("--ninputs"))
|
||||
{
|
||||
opt.get("--ninputs")->getInts(list_options);
|
||||
ninp2 = list_options[0];
|
||||
ninpp = list_options[1];
|
||||
}
|
||||
if (opt.isSet("--nsquares"))
|
||||
{
|
||||
opt.get("--nsquares")->getInts(list_options);
|
||||
nsqr2 = list_options[0];
|
||||
nsqrp = list_options[1];
|
||||
}
|
||||
if (opt.isSet("--ninverses"))
|
||||
opt.get("--ninverses")->getInt(ninv);
|
||||
if (opt.isSet("--nbittriples"))
|
||||
opt.get("--nbittriples")->getInt(nbittrip);
|
||||
if (opt.isSet("--nbitgf2ntriples"))
|
||||
opt.get("--nbitgf2ntriples")->getInt(nbitgf2ntrip);
|
||||
|
||||
bool zero = opt.isSet("--zero");
|
||||
if (zero)
|
||||
cout << "Set all values to zero" << endl;
|
||||
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
prep_data_prefix = get_prep_dir(nplayers, lgp, lg2);
|
||||
// Set up the fields
|
||||
ofstream outf;
|
||||
bigint p;
|
||||
generate_online_setup(outf, prep_data_prefix, p, lgp, lg2);
|
||||
|
||||
generate_keys(prep_data_prefix, nplayers);
|
||||
|
||||
/* Find number players and MAC keys etc*/
|
||||
gfp keyp,pp; keyp.assign_zero();
|
||||
gf2n key2,p2; key2.assign_zero();
|
||||
int tmpN = 0;
|
||||
ifstream inpf;
|
||||
|
||||
// create Player-Data if not there
|
||||
if (mkdir_p("Player-Data") == -1)
|
||||
{
|
||||
cerr << "mkdir_p(Player-Data) failed\n";
|
||||
throw file_error();
|
||||
}
|
||||
|
||||
for (i = 0; i < (unsigned int)nplayers; i++)
|
||||
{
|
||||
stringstream filename;
|
||||
filename << prep_data_prefix << "Player-MAC-Keys-P" << i;
|
||||
inpf.open(filename.str().c_str());
|
||||
if (inpf.fail())
|
||||
{
|
||||
inpf.close();
|
||||
cout << "No MAC key share for player " << i << ", generating a fresh one\n";
|
||||
pp.randomize(G);
|
||||
p2.randomize(G);
|
||||
ofstream outf(filename.str().c_str());
|
||||
if (outf.fail())
|
||||
throw file_error(filename.str().c_str());
|
||||
outf << nplayers << " " << pp << " " << p2;
|
||||
outf.close();
|
||||
cout << "Written new MAC key share to " << filename.str() << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
inpf >> tmpN; // not needed here
|
||||
pp.input(inpf,true);
|
||||
p2.input(inpf,true);
|
||||
inpf.close();
|
||||
}
|
||||
cout << " Key " << i << "\t p: " << pp << "\n\t 2: " << p2 << endl;
|
||||
keyp.add(pp);
|
||||
key2.add(p2);
|
||||
}
|
||||
cout << "--------------\n";
|
||||
cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl;
|
||||
|
||||
make_mult_triples(key2,nplayers,ntrip2,"2",zero);
|
||||
make_mult_triples(keyp,nplayers,ntripp,"p",zero);
|
||||
make_bits(key2,nplayers,nbits2,"2",zero);
|
||||
make_bits(keyp,nplayers,nbitsp,"p",zero);
|
||||
make_square_tuples(key2,nplayers,nsqr2,"2",zero);
|
||||
make_square_tuples(keyp,nplayers,nsqrp,"p",zero);
|
||||
make_inputs(key2,nplayers,ninp2,"2",zero);
|
||||
make_inputs(keyp,nplayers,ninpp,"p",zero);
|
||||
make_inverse(key2,nplayers,ninv,zero);
|
||||
make_inverse(keyp,nplayers,ninv,zero);
|
||||
make_bit_triples(key2,nplayers,nbittrip,DATA_BITTRIPLE,zero);
|
||||
make_bit_triples(key2,nplayers,nbitgf2ntrip,DATA_BITGF2NTRIPLE,zero);
|
||||
make_PreMulC(key2,nplayers,ninv,zero);
|
||||
make_PreMulC(keyp,nplayers,ninv,zero);
|
||||
}
|
||||
5
HOSTS.example
Normal file
5
HOSTS.example
Normal file
@@ -0,0 +1,5 @@
|
||||
192.168.0.1
|
||||
192.168.0.2
|
||||
192.168.0.3
|
||||
192.168.0.4
|
||||
192.168.0.5
|
||||
19
License.txt
Normal file
19
License.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
University of Bristol : Open Access Software Licence
|
||||
|
||||
Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
|
||||
|
||||
All rights reserved
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Any use of the software for scientific publications or commercial purposes should be reported to the University of Bristol (OSI-notifications@bristol.ac.uk and quote reference 1914). This is for impact and usage monitoring purposes only.
|
||||
|
||||
Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk
|
||||
|
||||
75
Makefile
Normal file
75
Makefile
Normal file
@@ -0,0 +1,75 @@
|
||||
# (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
include CONFIG
|
||||
|
||||
MATH = $(patsubst %.cpp,%.o,$(wildcard Math/*.cpp))
|
||||
|
||||
TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp))
|
||||
|
||||
NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
|
||||
|
||||
AUTH = $(patsubst %.cpp,%.o,$(wildcard Auth/*.cpp))
|
||||
|
||||
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp))
|
||||
|
||||
# OT stuff needs GF2N_LONG, so only compile if this is enabled
|
||||
ifeq ($(USE_GF2N_LONG),1)
|
||||
OT = $(patsubst %.cpp,%.o,$(filter-out OT/OText_main.cpp,$(wildcard OT/*.cpp)))
|
||||
OT_EXE = ot.x ot-offline.x
|
||||
endif
|
||||
|
||||
COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH)
|
||||
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE)
|
||||
|
||||
LIB = libSPDZ.a
|
||||
LIBSIMPLEOT = SimpleOT/libsimpleot.a
|
||||
|
||||
|
||||
all: gen_input online offline
|
||||
|
||||
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x
|
||||
|
||||
offline: $(OT_EXE) Check-Offline.x
|
||||
|
||||
gen_input: gen_input_f2n.x gen_input_fp.x
|
||||
|
||||
Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR)
|
||||
$(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS)
|
||||
|
||||
Server.x: Server.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)
|
||||
|
||||
Player-Online.x: Player-Online.cpp $(COMMON) $(PROCESSOR)
|
||||
$(CXX) $(CFLAGS) Player-Online.cpp -o Player-Online.x $(COMMON) $(PROCESSOR) $(LDLIBS)
|
||||
|
||||
ifeq ($(USE_GF2N_LONG),1)
|
||||
ot.x: $(OT) $(COMMON) OT/OText_main.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT)
|
||||
|
||||
ot-check.x: $(OT) $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o ot-check.x OT/BitVector.o OT/OutputCheck.cpp $(COMMON) $(LDLIBS)
|
||||
|
||||
ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp
|
||||
$(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o OT/BitVector.o $(COMMON) $(LDLIBS)
|
||||
|
||||
ot-offline.x: $(OT) $(COMMON) ot-offline.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(LIBSIMPLEOT)
|
||||
endif
|
||||
|
||||
check-passive.x: $(COMMON) check-passive.cpp
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS)
|
||||
|
||||
gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS)
|
||||
|
||||
|
||||
clean:
|
||||
-rm */*.o *.o *.x core.* *.a gmon.out
|
||||
|
||||
24
Math/Integer.cpp
Normal file
24
Math/Integer.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Integer.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Integer.h"
|
||||
|
||||
void Integer::output(ostream& s,bool human) const
|
||||
{
|
||||
if (human)
|
||||
s << a;
|
||||
else
|
||||
s.write((char*)&a, sizeof(a));
|
||||
}
|
||||
|
||||
void Integer::input(istream& s,bool human)
|
||||
{
|
||||
if (human)
|
||||
s >> a;
|
||||
else
|
||||
s.read((char*)&a, sizeof(a));
|
||||
}
|
||||
34
Math/Integer.h
Normal file
34
Math/Integer.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Integer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef INTEGER_H_
|
||||
#define INTEGER_H_
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
// Wrapper class for integer, used for Memory
|
||||
|
||||
class Integer
|
||||
{
|
||||
long a;
|
||||
|
||||
public:
|
||||
|
||||
Integer() { a = 0; }
|
||||
Integer(long a) : a(a) {}
|
||||
|
||||
long get() const { return a; }
|
||||
|
||||
void assign_zero() { a = 0; }
|
||||
|
||||
void output(ostream& s,bool human) const;
|
||||
void input(istream& s,bool human);
|
||||
|
||||
};
|
||||
|
||||
#endif /* INTEGER_H_ */
|
||||
148
Math/Setup.cpp
Normal file
148
Math/Setup.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Math/Setup.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
#include "Tools/mkpath.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
|
||||
/*
|
||||
* Just setup the primes, doesn't need NTL.
|
||||
* Sets idx and m to be used by SHE setup if necessary
|
||||
*/
|
||||
void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m)
|
||||
{
|
||||
cout << "Setting up parameters" << endl;
|
||||
|
||||
switch (lgp)
|
||||
{ case -1:
|
||||
m=16;
|
||||
idx=1; // Any old figures will do, but need to be for lgp at last
|
||||
lgp=32; // Switch to bigger prime to get parameters
|
||||
break;
|
||||
case 32:
|
||||
m=8192;
|
||||
idx=0;
|
||||
break;
|
||||
case 64:
|
||||
m=16384;
|
||||
idx=1;
|
||||
break;
|
||||
case 128:
|
||||
m=32768;
|
||||
idx=2;
|
||||
break;
|
||||
case 256:
|
||||
m=32768;
|
||||
idx=3;
|
||||
break;
|
||||
case 512:
|
||||
m=65536;
|
||||
idx=4;
|
||||
break;
|
||||
default:
|
||||
throw invalid_params();
|
||||
break;
|
||||
}
|
||||
cout << "m = " << m << endl;
|
||||
|
||||
// Here we choose a prime which is the order of a BN curve
|
||||
// - Reason is that there are some applications where this
|
||||
// would be a good idea. So I have hard coded it in here
|
||||
// - This is pointless/impossible for lgp=32, 64 so for
|
||||
// these do something naive
|
||||
// - Have not tested 256 and 512
|
||||
bigint u;
|
||||
int ex;
|
||||
if (lgp!=32 && lgp!=64)
|
||||
{ u=1; u=u<<(lgp-1); u=sqrt(sqrt(u/36))/m;
|
||||
u=u*m;
|
||||
bigint q;
|
||||
// cout << ex << " " << u << " " << numBits(u) << endl;
|
||||
p=(((36*u+36)*u+18)*u+6)*u+1; // The group order of a BN curve
|
||||
q=(((36*u+36)*u+24)*u+6)*u+1; // The base field size of a BN curve
|
||||
while (!probPrime(p) || !probPrime(q) || numBits(p)<lgp)
|
||||
{ u=u+m;
|
||||
p=(((36*u+36)*u+18)*u+6)*u+1;
|
||||
q=(((36*u+36)*u+24)*u+6)*u+1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{ ex=lgp-numBits(m);
|
||||
u=1; u=(u<<ex)*m; p=u+1;
|
||||
while (!probPrime(p) || numBits(p)<lgp)
|
||||
{ u=u+m; p=u+1; }
|
||||
}
|
||||
cout << "\t p = " << p << " u = " << u << " : ";
|
||||
cout << lgp << " <= " << numBits(p) << endl;
|
||||
}
|
||||
|
||||
|
||||
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2)
|
||||
{
|
||||
int idx, m;
|
||||
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
|
||||
|
||||
stringstream ss;
|
||||
ss << dirname;
|
||||
cout << "Writing to file in " << ss.str() << endl;
|
||||
// create preprocessing dir. if necessary
|
||||
if (mkdir_p(ss.str().c_str()) == -1)
|
||||
{
|
||||
cerr << "mkdir_p(" << ss.str() << ") failed\n";
|
||||
throw file_error();
|
||||
}
|
||||
|
||||
// Output the data
|
||||
ss << "/Params-Data";
|
||||
outf.open(ss.str().c_str());
|
||||
// Only need p and lg2 for online phase
|
||||
outf << p << endl;
|
||||
// Fix as a negative lg2 is a ``signal'' to choose slightly weaker
|
||||
// LWE parameters
|
||||
outf << abs(lg2) << endl;
|
||||
|
||||
gfp::init_field(p, true);
|
||||
gf2n::init_field(lg2);
|
||||
}
|
||||
|
||||
string get_prep_dir(int nparties, int lg2p, int gf2ndegree)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << PREP_DIR << nparties << "-" << lg2p << "-" << gf2ndegree << "/";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Only read enough to initialize the fields (i.e. for OT offline or online phase only)
|
||||
void read_setup(const string& dir_prefix)
|
||||
{
|
||||
int lg2;
|
||||
bigint p;
|
||||
|
||||
string filename = dir_prefix + "Params-Data";
|
||||
cerr << "loading params from: " << filename << endl;
|
||||
|
||||
// backwards compatibility hack
|
||||
if (dir_prefix.compare("") == 0)
|
||||
filename = string("Player-Data/Params-Data");
|
||||
|
||||
ifstream inpf(filename.c_str());
|
||||
if (inpf.fail()) { throw file_error(filename.c_str()); }
|
||||
inpf >> p;
|
||||
inpf >> lg2;
|
||||
|
||||
inpf.close();
|
||||
|
||||
gfp::init_field(p);
|
||||
gf2n::init_field(lg2);
|
||||
}
|
||||
|
||||
void read_setup(int nparties, int lg2p, int gf2ndegree)
|
||||
{
|
||||
string dir = get_prep_dir(nparties, lg2p, gf2ndegree);
|
||||
read_setup(dir);
|
||||
}
|
||||
35
Math/Setup.h
Normal file
35
Math/Setup.h
Normal file
@@ -0,0 +1,35 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Setup.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MATH_SETUP_H_
|
||||
#define MATH_SETUP_H_
|
||||
|
||||
#include "Math/bigint.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
/*
|
||||
* Routines to create and read setup files for the finite fields
|
||||
*/
|
||||
|
||||
// Create setup file for gfp and gf2n
|
||||
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2);
|
||||
|
||||
// Setup primes only
|
||||
// Chooses a p of at least lgp bits
|
||||
void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m);
|
||||
|
||||
// get main directory for prep. data
|
||||
string get_prep_dir(int nparties, int lg2p, int gf2ndegree);
|
||||
|
||||
// Read online setup file for gfp and gf2n
|
||||
void read_setup(const string& dir_prefix);
|
||||
void read_setup(int nparties, int lg2p, int gf2ndegree);
|
||||
|
||||
|
||||
#endif /* MATH_SETUP_H_ */
|
||||
126
Math/Share.cpp
Normal file
126
Math/Share.cpp
Normal file
@@ -0,0 +1,126 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Share.h"
|
||||
//#include "Tools/random.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/operators.h"
|
||||
|
||||
|
||||
template<class T>
|
||||
Share<T>::Share(const T& aa, int my_num, const T& alphai)
|
||||
{
|
||||
if (my_num == 0)
|
||||
a = aa;
|
||||
else
|
||||
a.assign_zero();
|
||||
mac = aa * alphai;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Share<T>::mul_by_bit(const Share<T>& S,const T& aa)
|
||||
{
|
||||
a.mul(S.a,aa);
|
||||
mac.mul(S.mac,aa);
|
||||
}
|
||||
|
||||
template<>
|
||||
void Share<gf2n>::mul_by_bit(const Share<gf2n>& S, const gf2n& aa)
|
||||
{
|
||||
a.mul_by_bit(S.a,aa);
|
||||
mac.mul_by_bit(S.mac,aa);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Share<T>::add(const Share<T>& S,const T& aa,bool playerone,const T& alphai)
|
||||
{
|
||||
if (playerone)
|
||||
{ a.add(S.a,aa); }
|
||||
else
|
||||
{ a=S.a; }
|
||||
|
||||
T tmp;
|
||||
tmp.mul(alphai,aa);
|
||||
mac.add(S.mac,tmp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
void Share<T>::sub(const Share<T>& S,const T& aa,bool playerone,const T& alphai)
|
||||
{
|
||||
if (playerone)
|
||||
{ a.sub(S.a,aa); }
|
||||
else
|
||||
{ a=S.a; }
|
||||
|
||||
T tmp;
|
||||
tmp.mul(alphai,aa);
|
||||
mac.sub(S.mac,tmp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
void Share<T>::sub(const T& aa,const Share<T>& S,bool playerone,const T& alphai)
|
||||
{
|
||||
if (playerone)
|
||||
{ a.sub(aa,S.a); }
|
||||
else
|
||||
{ a=S.a;
|
||||
a.negate();
|
||||
}
|
||||
|
||||
T tmp;
|
||||
tmp.mul(alphai,aa);
|
||||
mac.sub(tmp,S.mac);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
void Share<T>::sub(const Share<T>& S1,const Share<T>& S2)
|
||||
{
|
||||
a.sub(S1.a,S2.a);
|
||||
mac.sub(S1.mac,S2.mac);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T>
|
||||
T combine(const vector< Share<T> >& S)
|
||||
{
|
||||
T ans=S[0].a;
|
||||
for (unsigned int i=1; i<S.size(); i++)
|
||||
{ ans.add(ans,S[i].a); }
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
bool check_macs(const vector< Share<T> >& S,const T& key)
|
||||
{
|
||||
T val=combine(S);
|
||||
|
||||
// Now check the MAC is valid
|
||||
val.mul(val,key);
|
||||
for (unsigned i=0; i<S.size(); i++)
|
||||
{ val.sub(val,S[i].mac); }
|
||||
if (!val.is_zero()) { return false; }
|
||||
return true;
|
||||
}
|
||||
|
||||
template class Share<gf2n>;
|
||||
template class Share<gfp>;
|
||||
template gf2n combine(const vector< Share<gf2n> >& S);
|
||||
template gfp combine(const vector< Share<gfp> >& S);
|
||||
template bool check_macs(const vector< Share<gf2n> >& S,const gf2n& key);
|
||||
template bool check_macs(const vector< Share<gfp> >& S,const gfp& key);
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template class Share<gf2n_short>;
|
||||
template gf2n_short combine(const vector< Share<gf2n_short> >& S);
|
||||
template bool check_macs(const vector< Share<gf2n_short> >& S,const gf2n_short& key);
|
||||
#endif
|
||||
117
Math/Share.h
Normal file
117
Math/Share.h
Normal file
@@ -0,0 +1,117 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#ifndef _Share
|
||||
#define _Share
|
||||
|
||||
/* Class for holding a share of either a T or gfp element */
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
// Forward declaration as apparently this is needed for friends in templates
|
||||
template<class T> class Share;
|
||||
template<class T> T combine(const vector< Share<T> >& S);
|
||||
template<class T> bool check_macs(const vector< Share<T> >& S,const T& key);
|
||||
|
||||
|
||||
template<class T>
|
||||
class Share
|
||||
{
|
||||
T a; // The share
|
||||
T mac; // Shares of the mac
|
||||
|
||||
public:
|
||||
|
||||
typedef T value_type;
|
||||
|
||||
static int size()
|
||||
{ return 2 * T::size(); }
|
||||
|
||||
static string type_string()
|
||||
{ return T::type_string(); }
|
||||
|
||||
void assign(const Share<T>& S)
|
||||
{ a=S.a; mac=S.mac; }
|
||||
void assign(const char* buffer)
|
||||
{ a.assign(buffer); mac.assign(buffer + T::size()); }
|
||||
void assign_zero()
|
||||
{ a.assign_zero();
|
||||
mac.assign_zero();
|
||||
}
|
||||
|
||||
Share() { assign_zero(); }
|
||||
Share(const Share<T>& S) { assign(S); }
|
||||
Share(const T& aa, int my_num, const T& alphai);
|
||||
~Share() { ; }
|
||||
Share& operator=(const Share<T>& S)
|
||||
{ if (this!=&S) { assign(S); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
const T& get_share() const { return a; }
|
||||
const T& get_mac() const { return mac; }
|
||||
void set_share(const T& aa) { a=aa; }
|
||||
void set_mac(const T& aa) { mac=aa; }
|
||||
|
||||
/* Arithmetic Routines */
|
||||
void mul(const Share<T>& S,const T& aa);
|
||||
void mul_by_bit(const Share<T>& S,const T& aa);
|
||||
void add(const Share<T>& S,const T& aa,bool playerone,const T& alphai);
|
||||
void negate() { a.negate(); mac.negate(); }
|
||||
void sub(const Share<T>& S,const T& aa,bool playerone,const T& alphai);
|
||||
void sub(const T& aa,const Share<T>& S,bool playerone,const T& alphai);
|
||||
void add(const Share<T>& S1,const Share<T>& S2);
|
||||
void sub(const Share<T>& S1,const Share<T>& S2);
|
||||
void add(const Share<T>& S1) { add(*this,S1); }
|
||||
|
||||
// Input and output from a stream
|
||||
// - Can do in human or machine only format (later should be faster)
|
||||
void output(ostream& s,bool human) const
|
||||
{ a.output(s,human); if (human) { s << " "; }
|
||||
mac.output(s,human);
|
||||
}
|
||||
void input(istream& s,bool human)
|
||||
{ a.input(s,human);
|
||||
mac.input(s,human);
|
||||
}
|
||||
|
||||
/* Takes a vector of shares, one from each player and
|
||||
* determines the shared value
|
||||
* - i.e. Partially open the shares
|
||||
*/
|
||||
friend T combine<T>(const vector< Share<T> >& S);
|
||||
|
||||
/* Given a set of shares, one from each player and
|
||||
* the global key, determines if the sharing is valid
|
||||
* - Mainly for test purposes
|
||||
*/
|
||||
friend bool check_macs<T>(const vector< Share<T> >& S,const T& key);
|
||||
};
|
||||
|
||||
// specialized mul by bit for gf2n
|
||||
template <>
|
||||
void Share<gf2n>::mul_by_bit(const Share<gf2n>& S,const gf2n& aa);
|
||||
|
||||
template <class T>
|
||||
Share<T> operator*(const T& y, const Share<T>& x) { Share<T> res; res.mul(x, y); return res; }
|
||||
|
||||
template<class T>
|
||||
inline void Share<T>::add(const Share<T>& S1,const Share<T>& S2)
|
||||
{
|
||||
a.add(S1.a,S2.a);
|
||||
mac.add(S1.mac,S2.mac);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void Share<T>::mul(const Share<T>& S,const T& aa)
|
||||
{
|
||||
a.mul(S.a,aa);
|
||||
mac.mul(S.mac,aa);
|
||||
}
|
||||
|
||||
#endif
|
||||
138
Math/Zp_Data.cpp
Normal file
138
Math/Zp_Data.cpp
Normal file
@@ -0,0 +1,138 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Zp_Data.h"
|
||||
|
||||
|
||||
void Zp_Data::init(const bigint& p,bool mont)
|
||||
{ pr=p;
|
||||
mask=(1<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1;
|
||||
|
||||
montgomery=mont;
|
||||
t=mpz_size(pr.get_mpz_t());
|
||||
if (t>=MAX_MOD_SZ)
|
||||
throw max_mod_sz_too_small(t+1);
|
||||
if (montgomery)
|
||||
{ mpn_zero(R,MAX_MOD_SZ);
|
||||
mpn_zero(R2,MAX_MOD_SZ);
|
||||
mpn_zero(R3,MAX_MOD_SZ);
|
||||
bigint r=2,pp=pr;
|
||||
mpz_pow_ui(r.get_mpz_t(),r.get_mpz_t(),t*8*sizeof(mp_limb_t));
|
||||
mpz_invert(pp.get_mpz_t(),pr.get_mpz_t(),r.get_mpz_t());
|
||||
pp=r-pp; // pi=-1/p mod R
|
||||
pi=(pp.get_mpz_t()->_mp_d)[0];
|
||||
|
||||
r=r%pr;
|
||||
mpn_copyi(R,r.get_mpz_t()->_mp_d,mpz_size(r.get_mpz_t()));
|
||||
|
||||
bigint r2=(r*r)%pr;
|
||||
mpn_copyi(R2,r2.get_mpz_t()->_mp_d,mpz_size(r2.get_mpz_t()));
|
||||
|
||||
bigint r3=(r2*r)%pr;
|
||||
mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t()));
|
||||
|
||||
if (sizeof(unsigned long)!=sizeof(mp_limb_t))
|
||||
{ cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl;
|
||||
throw not_implemented();
|
||||
}
|
||||
}
|
||||
mpn_zero(prA,MAX_MOD_SZ);
|
||||
mpn_copyi(prA,pr.get_mpz_t()->_mp_d,t);
|
||||
}
|
||||
|
||||
|
||||
void Zp_Data::assign(const Zp_Data& Zp)
|
||||
{ pr=Zp.pr;
|
||||
mask=Zp.mask;
|
||||
|
||||
montgomery=Zp.montgomery;
|
||||
t=Zp.t;
|
||||
mpn_copyi(R,Zp.R,t+1);
|
||||
mpn_copyi(R2,Zp.R2,t+1);
|
||||
mpn_copyi(R3,Zp.R3,t+1);
|
||||
pi=Zp.pi;
|
||||
|
||||
mpn_copyi(prA,Zp.prA,t+1);
|
||||
}
|
||||
|
||||
|
||||
void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
mp_limb_t borrow = mpn_sub_n(ans,x,y,t);
|
||||
if (borrow!=0)
|
||||
mpn_add_n(ans,ans,prA,t);
|
||||
}
|
||||
|
||||
__m128i Zp_Data::get_random128(PRNG& G)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
__m128i res = G.get_doubleword();
|
||||
if (mpn_cmp((mp_limb_t*)&res, prA, t) < 0)
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
if (x[t]!=0 || y[t]!=0) { cout << "Mont_Mult Bug" << endl; abort(); }
|
||||
mp_limb_t ans[2*MAX_MOD_SZ],u;
|
||||
// First loop
|
||||
u=x[0]*y[0]*pi;
|
||||
ans[t] = mpn_mul_1(ans,y,t,x[0]);
|
||||
ans[t+1] = mpn_addmul_1(ans,prA,t+1,u);
|
||||
for (int i=1; i<t; i++)
|
||||
{ // u=(ans0+xi*y0)*pd
|
||||
u=(ans[i]+x[i]*y[0])*pi;
|
||||
// ans=ans+xi*y+u*pr
|
||||
ans[t+i+1]=mpn_addmul_1(ans+i,y,t+1,x[i]);
|
||||
ans[t+i+1]+=mpn_addmul_1(ans+i,prA,t+1,u);
|
||||
}
|
||||
// if (ans>=pr) { ans=z-pr; }
|
||||
// else { z=ans; }
|
||||
if (mpn_cmp(ans+t,prA,t+1)>=0)
|
||||
{ mpn_sub_n(z,ans+t,prA,t); }
|
||||
else
|
||||
{ mpn_copyi(z,ans+t,t); }
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<<(ostream& s,const Zp_Data& ZpD)
|
||||
{
|
||||
s << ZpD.pr << " " << ZpD.montgomery << endl;
|
||||
if (ZpD.montgomery)
|
||||
{ s << ZpD.t << " " << ZpD.pi << endl;
|
||||
for (int i=0; i<ZpD.t; i++) { s << ZpD.R[i] << " "; }
|
||||
s << endl;
|
||||
for (int i=0; i<ZpD.t; i++) { s << ZpD.R2[i] << " "; }
|
||||
s << endl;
|
||||
for (int i=0; i<ZpD.t; i++) { s << ZpD.R3[i] << " "; }
|
||||
s << endl;
|
||||
for (int i=0; i<ZpD.t; i++) { s << ZpD.prA[i] << " "; }
|
||||
s << endl;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
istream& operator>>(istream& s,Zp_Data& ZpD)
|
||||
{
|
||||
s >> ZpD.pr >> ZpD.montgomery;
|
||||
if (ZpD.montgomery)
|
||||
{ s >> ZpD.t >> ZpD.pi;
|
||||
if (ZpD.t>=MAX_MOD_SZ)
|
||||
throw max_mod_sz_too_small(ZpD.t+1);
|
||||
mpn_zero(ZpD.R,MAX_MOD_SZ);
|
||||
mpn_zero(ZpD.R2,MAX_MOD_SZ);
|
||||
mpn_zero(ZpD.R3,MAX_MOD_SZ);
|
||||
mpn_zero(ZpD.prA,MAX_MOD_SZ);
|
||||
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R[i]; }
|
||||
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R2[i]; }
|
||||
for (int i=0; i<ZpD.t; i++) { s >> ZpD.R3[i]; }
|
||||
for (int i=0; i<ZpD.t; i++) { s >> ZpD.prA[i]; }
|
||||
}
|
||||
return s;
|
||||
}
|
||||
129
Math/Zp_Data.h
Normal file
129
Math/Zp_Data.h
Normal file
@@ -0,0 +1,129 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Zp_Data
|
||||
#define _Zp_Data
|
||||
|
||||
/* Class to define helper information for a Zp element
|
||||
*
|
||||
* Basically the data needed for Montgomery operations
|
||||
*
|
||||
* Almost all data is public as this is basically a container class
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Math/bigint.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
#include <smmintrin.h>
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#ifndef MAX_MOD_SZ
|
||||
#ifdef LargeM
|
||||
#define MAX_MOD_SZ 20
|
||||
#else
|
||||
#define MAX_MOD_SZ 3
|
||||
#endif
|
||||
#endif
|
||||
|
||||
class modp;
|
||||
|
||||
class Zp_Data
|
||||
{
|
||||
bool montgomery; // True if we are using Montgomery arithmetic
|
||||
mp_limb_t R[MAX_MOD_SZ],R2[MAX_MOD_SZ],R3[MAX_MOD_SZ],pi;
|
||||
mp_limb_t prA[MAX_MOD_SZ];
|
||||
int t; // More Montgomery data
|
||||
|
||||
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
|
||||
public:
|
||||
|
||||
bigint pr;
|
||||
mp_limb_t mask;
|
||||
|
||||
void assign(const Zp_Data& Zp);
|
||||
void init(const bigint& p,bool mont=true);
|
||||
int get_t() const { return t; }
|
||||
const mp_limb_t* get_prA() const { return prA; }
|
||||
|
||||
// This one does nothing, needed so as to make vectors of Zp_Data
|
||||
Zp_Data() : montgomery(0), pi(0), mask(0) { t=1; }
|
||||
|
||||
// The main init funciton
|
||||
Zp_Data(const bigint& p,bool mont=true)
|
||||
{ init(p,mont); }
|
||||
|
||||
Zp_Data(const Zp_Data& Zp) { assign(Zp); }
|
||||
Zp_Data& operator=(const Zp_Data& Zp)
|
||||
{ if (this!=&Zp) { assign(Zp); }
|
||||
return *this;
|
||||
}
|
||||
~Zp_Data() { ; }
|
||||
|
||||
template <int T>
|
||||
void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
|
||||
__m128i get_random128(PRNG& G);
|
||||
|
||||
friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce);
|
||||
|
||||
friend void to_modp(modp& ans,int x,const Zp_Data& ZpD);
|
||||
friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD);
|
||||
|
||||
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
|
||||
friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD);
|
||||
friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD);
|
||||
|
||||
friend void assignOne(modp& x,const Zp_Data& ZpD);
|
||||
friend void assignZero(modp& x,const Zp_Data& ZpD);
|
||||
friend bool isZero(const modp& x,const Zp_Data& ZpD);
|
||||
friend bool isOne(const modp& x,const Zp_Data& ZpD);
|
||||
friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
|
||||
friend class modp;
|
||||
|
||||
friend ostream& operator<<(ostream& s,const Zp_Data& ZpD);
|
||||
friend istream& operator>>(istream& s,Zp_Data& ZpD);
|
||||
};
|
||||
|
||||
template<>
|
||||
inline void Zp_Data::Add<2>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
__uint128_t a, b, p;
|
||||
memcpy(&a, x, sizeof(__uint128_t));
|
||||
memcpy(&b, y, sizeof(__uint128_t));
|
||||
memcpy(&p, prA, sizeof(__uint128_t));
|
||||
__uint128_t c = a + b;
|
||||
asm goto ("jc %l[sub]" :::: sub);
|
||||
if (c >= p)
|
||||
sub:
|
||||
c -= p;
|
||||
memcpy(ans, &c, sizeof(__uint128_t));
|
||||
}
|
||||
|
||||
template<>
|
||||
inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
mp_limb_t carry = mpn_add_n(ans,x,y,t);
|
||||
if (carry!=0 || mpn_cmp(ans,prA,t)>=0)
|
||||
{ mpn_sub_n(ans,ans,prA,t); }
|
||||
}
|
||||
|
||||
inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{
|
||||
if (t == 2)
|
||||
return Add<2>(ans, x, y);
|
||||
else
|
||||
return Add<0>(ans, x, y);
|
||||
}
|
||||
|
||||
#endif
|
||||
95
Math/bigint.cpp
Normal file
95
Math/bigint.cpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "bigint.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
|
||||
bigint sqrRootMod(const bigint& a,const bigint& p)
|
||||
{
|
||||
bigint ans;
|
||||
if (a==0) { ans=0; return ans; }
|
||||
if (mpz_tstbit(p.get_mpz_t(),1)==1)
|
||||
{ // First do case with p=3 mod 4
|
||||
bigint exp=(p+1)/4;
|
||||
mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
else
|
||||
{ // Shanks algorithm
|
||||
gmp_randclass Gen(gmp_randinit_default);
|
||||
Gen.seed(0);
|
||||
bigint x,y,n,q,t,b,temp;
|
||||
// Find n such that (n/p)=-1
|
||||
int leg=1;
|
||||
while (leg!=-1)
|
||||
{ n=Gen.get_z_range(p);
|
||||
leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
// Split p-1 = 2^e q
|
||||
q=p-1;
|
||||
int e=0;
|
||||
while (mpz_even_p(q.get_mpz_t()))
|
||||
{ e++; q=q/2; }
|
||||
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
|
||||
int r=e;
|
||||
mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t());
|
||||
temp=(q-1)/2;
|
||||
mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t());
|
||||
// b=a*x^2 mod p, x=a*x mod p
|
||||
b=(a*x*x)%p;
|
||||
x=(a*x)%p;
|
||||
// While b!=1 do
|
||||
while (b!=1)
|
||||
{ // Find smallest m such that b^(2^m)=1 mod p
|
||||
int m=1;
|
||||
temp=(b*b)%p;
|
||||
while (temp!=1)
|
||||
{ temp=(temp*temp)%p; m++; }
|
||||
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
|
||||
t=y;
|
||||
for (int i=0; i<r-m-1; i++)
|
||||
{ t=(t*t)%p; }
|
||||
y=(t*t)%p;
|
||||
r=m;
|
||||
// x=x*t mod p, b=b*y mod p
|
||||
x=(x*t)%p;
|
||||
b=(b*y)%p;
|
||||
}
|
||||
ans=x;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bigint powerMod(const bigint& x,const bigint& e,const bigint& p)
|
||||
{
|
||||
bigint ans;
|
||||
if (e>=0)
|
||||
{ mpz_powm(ans.get_mpz_t(),x.get_mpz_t(),e.get_mpz_t(),p.get_mpz_t()); }
|
||||
else
|
||||
{ bigint xi,ei=-e;
|
||||
invMod(xi,x,p);
|
||||
mpz_powm(ans.get_mpz_t(),xi.get_mpz_t(),ei.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
int powerMod(int x,int e,int p)
|
||||
{
|
||||
if (e==1) { return x; }
|
||||
if (e==0) { return 1; }
|
||||
if (e<0)
|
||||
{ throw not_implemented(); }
|
||||
int t=x,ans=1;
|
||||
while (e!=0)
|
||||
{ if ((e&1)==1) { ans=(ans*t)%p; }
|
||||
e>>=1;
|
||||
t=(t*t)%p;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
116
Math/bigint.h
Normal file
116
Math/bigint.h
Normal file
@@ -0,0 +1,116 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _bigint
|
||||
#define _bigint
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include <stddef.h>
|
||||
#include <mpirxx.h>
|
||||
|
||||
typedef mpz_class bigint;
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Tools/int.h"
|
||||
|
||||
/**********************************
|
||||
* Utility Functions *
|
||||
**********************************/
|
||||
|
||||
inline int gcd(const int x,const int y)
|
||||
{
|
||||
bigint xx=x;
|
||||
return mpz_gcd_ui(NULL,xx.get_mpz_t(),y);
|
||||
}
|
||||
|
||||
|
||||
inline bigint gcd(const bigint& x,const bigint& y)
|
||||
{
|
||||
bigint g;
|
||||
mpz_gcd(g.get_mpz_t(),x.get_mpz_t(),y.get_mpz_t());
|
||||
return g;
|
||||
}
|
||||
|
||||
|
||||
inline void invMod(bigint& ans,const bigint& x,const bigint& p)
|
||||
{
|
||||
mpz_invert(ans.get_mpz_t(),x.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
|
||||
inline int numBits(const bigint& m)
|
||||
{
|
||||
return mpz_sizeinbase(m.get_mpz_t(),2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline int numBits(int m)
|
||||
{
|
||||
bigint te=m;
|
||||
return mpz_sizeinbase(te.get_mpz_t(),2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline int numBytes(const bigint& m)
|
||||
{
|
||||
return mpz_sizeinbase(m.get_mpz_t(),256);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
inline int probPrime(const bigint& x)
|
||||
{
|
||||
gmp_randstate_t rand_state;
|
||||
gmp_randinit_default(rand_state);
|
||||
int ans=mpz_probable_prime_p(x.get_mpz_t(),rand_state,40,0);
|
||||
gmp_randclear(rand_state);
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
inline void bigintFromBytes(bigint& x,octet* bytes,int len)
|
||||
{
|
||||
mpz_import(x.get_mpz_t(),len,1,sizeof(octet),0,0,bytes);
|
||||
}
|
||||
|
||||
|
||||
inline void bytesFromBigint(octet* bytes,const bigint& x,unsigned int len)
|
||||
{
|
||||
size_t ll;
|
||||
mpz_export(bytes,&ll,1,sizeof(octet),0,0,x.get_mpz_t());
|
||||
if (ll>len)
|
||||
{ throw invalid_length(); }
|
||||
for (unsigned int i=ll; i<len; i++)
|
||||
{ bytes[i]=0; }
|
||||
}
|
||||
|
||||
|
||||
inline int isOdd(const bigint& x)
|
||||
{
|
||||
return mpz_odd_p(x.get_mpz_t());
|
||||
}
|
||||
|
||||
|
||||
bigint sqrRootMod(const bigint& x,const bigint& p);
|
||||
|
||||
bigint powerMod(const bigint& x,const bigint& e,const bigint& p);
|
||||
|
||||
// Assume e>=0
|
||||
int powerMod(int x,int e,int p);
|
||||
|
||||
inline int Hwt(int N)
|
||||
{
|
||||
int result=0;
|
||||
while(N)
|
||||
{ result++;
|
||||
N&=(N-1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
15
Math/field_types.h
Normal file
15
Math/field_types.h
Normal file
@@ -0,0 +1,15 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* types.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MATH_FIELD_TYPES_H_
|
||||
#define MATH_FIELD_TYPES_H_
|
||||
|
||||
|
||||
enum DataFieldType { DATA_MODP, DATA_GF2N, N_DATA_FIELD_TYPE };
|
||||
|
||||
|
||||
#endif /* MATH_FIELD_TYPES_H_ */
|
||||
345
Math/gf2n.cpp
Normal file
345
Math/gf2n.cpp
Normal file
@@ -0,0 +1,345 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <stdint.h>
|
||||
#include <wmmintrin.h>
|
||||
#include <xmmintrin.h>
|
||||
#include <emmintrin.h>
|
||||
|
||||
int gf2n_short::n;
|
||||
int gf2n_short::t1;
|
||||
int gf2n_short::t2;
|
||||
int gf2n_short::t3;
|
||||
int gf2n_short::l0;
|
||||
int gf2n_short::l1;
|
||||
int gf2n_short::l2;
|
||||
int gf2n_short::l3;
|
||||
int gf2n_short::nterms;
|
||||
word gf2n_short::mask;
|
||||
bool gf2n_short::useC;
|
||||
bool gf2n_short::rewind = false;
|
||||
|
||||
word gf2n_short_table[256][256];
|
||||
|
||||
#define num_2_fields 4
|
||||
|
||||
/* Require
|
||||
* 2*(n-1)-64+t1<64
|
||||
*/
|
||||
int fields_2[num_2_fields][4] = {
|
||||
{4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10}
|
||||
};
|
||||
|
||||
|
||||
void gf2n_short::init_tables()
|
||||
{
|
||||
if (sizeof(word)!=8)
|
||||
{ cout << "Word size is wrong" << endl;
|
||||
throw not_implemented();
|
||||
}
|
||||
int i,j;
|
||||
for (i=0; i<256; i++)
|
||||
{ for (j=0; j<256; j++)
|
||||
{ word ii=i,jj=j;
|
||||
gf2n_short_table[i][j]=0;
|
||||
while (ii!=0)
|
||||
{ if ((ii&1)==1) { gf2n_short_table[i][j]^=jj; }
|
||||
jj<<=1;
|
||||
ii>>=1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gf2n_short::init_field(int nn)
|
||||
{
|
||||
gf2n_short::init_tables();
|
||||
int i,j=-1;
|
||||
for (i=0; i<num_2_fields && j==-1; i++)
|
||||
{ if (nn==fields_2[i][0]) { j=i; } }
|
||||
if (j==-1) { throw invalid_params(); }
|
||||
|
||||
n=nn;
|
||||
nterms=1;
|
||||
l0=64-n;
|
||||
t1=fields_2[j][1];
|
||||
l1=64+t1-n;
|
||||
if (fields_2[j][2]!=0)
|
||||
{ nterms=3;
|
||||
t2=fields_2[j][2];
|
||||
l2=64+t2-n;
|
||||
t3=fields_2[j][3];
|
||||
l3=64+t3-n;
|
||||
}
|
||||
if (2*(n-1)-64+t1>=64) { throw invalid_params(); }
|
||||
|
||||
mask=(1ULL<<n)-1;
|
||||
|
||||
useC=(Check_CPU_support_AES()==0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Takes 16bit x and y and returns the 32 bit product in c1 and c0
|
||||
ans = (c1<<16)^c0
|
||||
where c1 and c0 are 16 bit
|
||||
*/
|
||||
inline void mul16(word x,word y,word& c0,word& c1)
|
||||
{
|
||||
word a1=x&(0xFF), b1=y&(0xFF);
|
||||
word a2=x>>8, b2=y>>8;
|
||||
|
||||
c0=gf2n_short_table[a1][b1];
|
||||
c1=gf2n_short_table[a2][b2];
|
||||
word te=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1];
|
||||
c0^=(te&0xFF)<<8;
|
||||
c1^=te>>8;
|
||||
}
|
||||
|
||||
/* Takes 16 bit x and y and returns the 32 bit product */
|
||||
inline word mul16(word x,word y)
|
||||
{
|
||||
word a1=x&(0xFF), b1=y&(0xFF);
|
||||
word a2=x>>8, b2=y>>8;
|
||||
|
||||
word ans=gf2n_short_table[a2][b2]<<8;
|
||||
ans^=gf2n_short_table[a1][b2]^gf2n_short_table[a2][b1];
|
||||
ans<<=8;
|
||||
ans^=gf2n_short_table[a1][b1];
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Takes 16 bit x the 32 bit square */
|
||||
inline word sqr16(word x)
|
||||
{
|
||||
word a1=x&(0xFF),a2=x>>8;
|
||||
|
||||
word ans=gf2n_short_table[a2][a2]<<16;
|
||||
ans^=gf2n_short_table[a1][a1];
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void gf2n_short::reduce_trinomial(word xh,word xl)
|
||||
{
|
||||
// Deal with xh first
|
||||
a=xl;
|
||||
a^=(xh<<l0);
|
||||
a^=(xh<<l1);
|
||||
|
||||
// Now deal with last word
|
||||
word hi=a>>n;
|
||||
while (hi!=0)
|
||||
{ a&=mask;
|
||||
|
||||
a^=hi;
|
||||
a^=(hi<<t1);
|
||||
hi=a>>n;
|
||||
}
|
||||
}
|
||||
|
||||
void gf2n_short::reduce_pentanomial(word xh,word xl)
|
||||
{
|
||||
// Deal with xh first
|
||||
a=xl;
|
||||
a^=(xh<<l0);
|
||||
a^=(xh<<l1);
|
||||
a^=(xh<<l2);
|
||||
a^=(xh<<l3);
|
||||
|
||||
// Now deal with last word
|
||||
word hi=a>>n;
|
||||
while (hi!=0)
|
||||
{ a&=mask;
|
||||
|
||||
a^=hi;
|
||||
a^=(hi<<t1);
|
||||
a^=(hi<<t2);
|
||||
a^=(hi<<t3);
|
||||
|
||||
hi=a>>n;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void mul32(word x,word y,word& ans)
|
||||
{
|
||||
word a1=x&(0xFFFF),b1=y&(0xFFFF);
|
||||
word a2=x>>16, b2=y>>16;
|
||||
|
||||
word c0,c1;
|
||||
|
||||
ans=mul16(a1,b1);
|
||||
word upp=mul16(a2,b2);
|
||||
|
||||
mul16(a1,b2,c0,c1);
|
||||
ans^=c0<<16; upp^=c1;
|
||||
|
||||
mul16(a2,b1,c0,c1);
|
||||
ans^=c0<<16; upp^=c1;
|
||||
|
||||
ans^=(upp<<32);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y)
|
||||
{
|
||||
word hi,lo;
|
||||
|
||||
if (gf2n_short::useC)
|
||||
{ /* Uses Karatsuba */
|
||||
word c,d,e,t;
|
||||
word xl=x.a&0xFFFFFFFF,yl=y.a&0xFFFFFFFF;
|
||||
word xh=x.a>>32,yh=y.a>>32;
|
||||
mul32(xl,yl,c);
|
||||
mul32(xh,yh,d);
|
||||
mul32((xl^xh),(yl^yh),e);
|
||||
t=c^e^d;
|
||||
lo=c^(t<<32);
|
||||
hi=d^(t>>32);
|
||||
}
|
||||
else
|
||||
{ /* Use Intel Instructions */
|
||||
__m128i xx,yy,zz;
|
||||
uint64_t c[] __attribute__((aligned (16))) = { 0,0 };
|
||||
xx=_mm_set1_epi64x(x.a);
|
||||
yy=_mm_set1_epi64x(y.a);
|
||||
zz=_mm_clmulepi64_si128(xx,yy,0);
|
||||
_mm_store_si128((__m128i*)c,zz);
|
||||
lo=c[0];
|
||||
hi=c[1];
|
||||
}
|
||||
|
||||
reduce(hi,lo);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
inline void sqr32(word x,word& ans)
|
||||
{
|
||||
word a1=x&(0xFFFF),a2=x>>16;
|
||||
ans=sqr16(a1)^(sqr16(a2)<<32);
|
||||
}
|
||||
|
||||
void gf2n_short::square()
|
||||
{
|
||||
word xh,xl;
|
||||
sqr32(a&0xFFFFFFFF,xl);
|
||||
sqr32(a>>32,xh);
|
||||
reduce(xh,xl);
|
||||
}
|
||||
|
||||
|
||||
void gf2n_short::square(const gf2n_short& bb)
|
||||
{
|
||||
word xh,xl;
|
||||
sqr32(bb.a&0xFFFFFFFF,xl);
|
||||
sqr32(bb.a>>32,xh);
|
||||
reduce(xh,xl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
void gf2n_short::invert()
|
||||
{
|
||||
if (is_one()) { return; }
|
||||
if (is_zero()) { throw division_by_zero(); }
|
||||
|
||||
word u,v=a,B=0,D=1,mod=1;
|
||||
|
||||
mod^=(1ULL<<n);
|
||||
mod^=(1ULL<<t1);
|
||||
if (nterms==3)
|
||||
{ mod^=(1ULL<<t2);
|
||||
mod^=(1ULL<<t3);
|
||||
}
|
||||
u=mod; v=a;
|
||||
|
||||
while (u!=0)
|
||||
{ while ((u&1)==0)
|
||||
{ u>>=1;
|
||||
if ((B&1)!=0) { B^=mod; }
|
||||
B>>=1;
|
||||
}
|
||||
while ((v&1)==0 && v!=0)
|
||||
{ v>>=1;
|
||||
if ((D&1)!=0) { D^=mod; }
|
||||
D>>=1;
|
||||
}
|
||||
|
||||
if (u>=v) { u=u^v; B=B^D; }
|
||||
else { v=v^u; D=D^B; }
|
||||
}
|
||||
|
||||
a=D;
|
||||
}
|
||||
|
||||
|
||||
void gf2n_short::power(long i)
|
||||
{
|
||||
long n=i;
|
||||
if (n<0) { invert(); n=-n; }
|
||||
|
||||
gf2n_short T=*this;
|
||||
assign_one();
|
||||
while (n!=0)
|
||||
{ if ((n&1)!=0) { mul(*this,T); }
|
||||
n>>=1;
|
||||
T.square();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gf2n_short::randomize(PRNG& G)
|
||||
{
|
||||
a=G.get_uint();
|
||||
a=(a<<32)^G.get_uint();
|
||||
a&=mask;
|
||||
}
|
||||
|
||||
|
||||
void gf2n_short::output(ostream& s,bool human) const
|
||||
{
|
||||
if (human)
|
||||
{ s << hex << a << dec << " "; }
|
||||
else
|
||||
{ s.write((char*) &a,sizeof(word)); }
|
||||
}
|
||||
|
||||
void gf2n_short::input(istream& s,bool human)
|
||||
{
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
}
|
||||
//throw end_of_file();
|
||||
s.clear(); // unset EOF flag
|
||||
s.seekg(0);
|
||||
if (!rewind)
|
||||
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
|
||||
rewind = true;
|
||||
}
|
||||
|
||||
if (human)
|
||||
{ s >> hex >> a >> dec; }
|
||||
else
|
||||
{ s.read((char*) &a,sizeof(word)); }
|
||||
|
||||
a &= mask;
|
||||
}
|
||||
191
Math/gf2n.h
Normal file
191
Math/gf2n.h
Normal file
@@ -0,0 +1,191 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _gf2n
|
||||
#define _gf2n
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/random.h"
|
||||
|
||||
#include "Math/gf2nlong.h"
|
||||
#include "Math/field_types.h"
|
||||
|
||||
/* This interface compatible with the gfp interface
|
||||
* which then allows us to template the Share
|
||||
* data type.
|
||||
*/
|
||||
|
||||
|
||||
/*
|
||||
Arithmetic in Gf_{2^n} with n<64
|
||||
*/
|
||||
|
||||
class gf2n_short
|
||||
{
|
||||
word a;
|
||||
|
||||
static int n,t1,t2,t3,nterms;
|
||||
static int l0,l1,l2,l3;
|
||||
static word mask;
|
||||
static bool useC;
|
||||
static bool rewind;
|
||||
|
||||
/* Assign x[0..2*nwords] to a and reduce it... */
|
||||
void reduce_trinomial(word xh,word xl);
|
||||
void reduce_pentanomial(word xh,word xl);
|
||||
|
||||
void reduce(word xh,word xl)
|
||||
{ if (nterms==3)
|
||||
{ reduce_pentanomial(xh,xl); }
|
||||
else
|
||||
{ reduce_trinomial(xh,xl); }
|
||||
}
|
||||
|
||||
static void init_tables();
|
||||
|
||||
public:
|
||||
|
||||
typedef gf2n_short value_type;
|
||||
typedef word internal_type;
|
||||
|
||||
static void init_field(int nn);
|
||||
static int degree() { return n; }
|
||||
static int get_nterms() { return nterms; }
|
||||
static int get_t(int i)
|
||||
{ if (i==0) { return t1; }
|
||||
else if (i==1) { return t2; }
|
||||
else if (i==2) { return t3; }
|
||||
return -1;
|
||||
}
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2N; }
|
||||
static char type_char() { return '2'; }
|
||||
static string type_string() { return "gf2n"; }
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
static int t() { return 0; }
|
||||
|
||||
word get() const { return a; }
|
||||
word get_word() const { return a; }
|
||||
|
||||
void assign(const gf2n_short& g) { a=g.a; }
|
||||
|
||||
void assign_zero() { a=0; }
|
||||
void assign_one() { a=1; }
|
||||
void assign_x() { a=2; }
|
||||
void assign(word aa) { a=aa&mask; }
|
||||
void assign(int aa) { a=static_cast<unsigned int>(aa)&mask; }
|
||||
void assign(const char* buffer) { a = *(word*)buffer; }
|
||||
|
||||
int get_bit(int i) const
|
||||
{ return (a>>i)&1; }
|
||||
void set_bit(int i,unsigned int b)
|
||||
{ if (b==1)
|
||||
{ a |= (1UL<<i); }
|
||||
else
|
||||
{ a &= ~(1UL<<i); }
|
||||
}
|
||||
|
||||
gf2n_short() { a=0; }
|
||||
gf2n_short(const gf2n_short& g) { assign(g); }
|
||||
gf2n_short(int g) { assign(g); }
|
||||
~gf2n_short() { ; }
|
||||
|
||||
gf2n_short& operator=(const gf2n_short& g)
|
||||
{ assign(g);
|
||||
return *this;
|
||||
}
|
||||
|
||||
int is_zero() const { return (a==0); }
|
||||
int is_one() const { return (a==1); }
|
||||
int equal(const gf2n_short& y) const { return (a==y.a); }
|
||||
bool operator==(const gf2n_short& y) const { return a==y.a; }
|
||||
bool operator!=(const gf2n_short& y) const { return a!=y.a; }
|
||||
|
||||
// x+y
|
||||
void add(const gf2n_short& x,const gf2n_short& y)
|
||||
{ a=x.a^y.a; }
|
||||
void add(const gf2n_short& x)
|
||||
{ a^=x.a; }
|
||||
template<int T>
|
||||
void add(octet* x)
|
||||
{ a^=*(word*)(x); }
|
||||
void add(octet* x)
|
||||
{ add<0>(x); }
|
||||
void sub(const gf2n_short& x,const gf2n_short& y)
|
||||
{ a=x.a^y.a; }
|
||||
void sub(const gf2n_short& x)
|
||||
{ a^=x.a; }
|
||||
// = x * y
|
||||
void mul(const gf2n_short& x,const gf2n_short& y);
|
||||
void mul(const gf2n_short& x) { mul(*this,x); }
|
||||
// x * y when one of x,y is a bit
|
||||
void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; }
|
||||
|
||||
gf2n_short operator+(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; }
|
||||
gf2n_short operator*(const gf2n_short& x) { gf2n_short res; res.mul(*this, x); return res; }
|
||||
gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; }
|
||||
gf2n_short& operator*=(const gf2n_short& x) { mul(x); return *this; }
|
||||
gf2n_short operator-(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; }
|
||||
gf2n_short& operator-=(const gf2n_short& x) { sub(x); return *this; }
|
||||
|
||||
void square();
|
||||
void square(const gf2n_short& aa);
|
||||
void invert();
|
||||
void invert(const gf2n_short& aa)
|
||||
{ *this=aa; invert(); }
|
||||
void negate() { return; }
|
||||
void power(long i);
|
||||
|
||||
/* Bitwise Ops */
|
||||
void AND(const gf2n_short& x,const gf2n_short& y) { a=x.a&y.a; }
|
||||
void XOR(const gf2n_short& x,const gf2n_short& y) { a=x.a^y.a; }
|
||||
void OR(const gf2n_short& x,const gf2n_short& y) { a=x.a|y.a; }
|
||||
void NOT(const gf2n_short& x) { a=(~x.a)&mask; }
|
||||
void SHL(const gf2n_short& x,int n) { a=(x.a<<n)&mask; }
|
||||
void SHR(const gf2n_short& x,int n) { a=x.a>>n; }
|
||||
|
||||
gf2n_short operator&(const gf2n_short& x) { gf2n_short res; res.AND(*this, x); return res; }
|
||||
gf2n_short operator^(const gf2n_short& x) { gf2n_short res; res.XOR(*this, x); return res; }
|
||||
gf2n_short operator|(const gf2n_short& x) { gf2n_short res; res.OR(*this, x); return res; }
|
||||
gf2n_short operator!() { gf2n_short res; res.NOT(*this); return res; }
|
||||
gf2n_short operator<<(int i) { gf2n_short res; res.SHL(*this, i); return res; }
|
||||
gf2n_short operator>>(int i) { gf2n_short res; res.SHR(*this, i); return res; }
|
||||
|
||||
/* Crap RNG */
|
||||
void randomize(PRNG& G);
|
||||
// compatibility with gfp
|
||||
void almost_randomize(PRNG& G) { randomize(G); }
|
||||
|
||||
void output(ostream& s,bool human) const;
|
||||
void input(istream& s,bool human);
|
||||
|
||||
friend ostream& operator<<(ostream& s,const gf2n_short& x)
|
||||
{ s << hex << "0x" << x.a << dec;
|
||||
return s;
|
||||
}
|
||||
friend istream& operator>>(istream& s,gf2n_short& x)
|
||||
{ s >> hex >> x.a >> dec;
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const
|
||||
{ o.append((octet*) &a,sizeof(word)); }
|
||||
void unpack(octetStream& o)
|
||||
{ o.consume((octet*) &a,sizeof(word)); }
|
||||
};
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
typedef gf2n_long gf2n;
|
||||
#else
|
||||
typedef gf2n_short gf2n;
|
||||
#endif
|
||||
|
||||
#endif
|
||||
277
Math/gf2nlong.cpp
Normal file
277
Math/gf2nlong.cpp
Normal file
@@ -0,0 +1,277 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* gf2n_longlong.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "gf2nlong.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <stdint.h>
|
||||
#include <wmmintrin.h>
|
||||
#include <xmmintrin.h>
|
||||
#include <emmintrin.h>
|
||||
|
||||
|
||||
bool is_ge(__m128i a, __m128i b)
|
||||
{
|
||||
word aa[2], bb[2];
|
||||
_mm_storeu_si128((__m128i*)aa, a);
|
||||
_mm_storeu_si128((__m128i*)bb, b);
|
||||
// cout << hex << "is_ge " << aa[1] << " " << bb[1] << " " << (aa[1] > bb[1]) << " ";
|
||||
// cout << aa[0] << " " << bb[0] << " " << (aa[0] >= bb[0]) << endl;
|
||||
return aa[1] == bb[1] ? aa[0] >= bb[0] : aa[1] > bb[1];
|
||||
}
|
||||
|
||||
|
||||
ostream& operator<<(ostream& s, const int128& a)
|
||||
{
|
||||
word* tmp = (word*)&a.a;
|
||||
s << hex;
|
||||
s.width(16);
|
||||
s.fill('0');
|
||||
s << tmp[1];
|
||||
s.width(16);
|
||||
s << tmp[0] << dec;
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
int gf2n_long::n;
|
||||
int gf2n_long::t1;
|
||||
int gf2n_long::t2;
|
||||
int gf2n_long::t3;
|
||||
int gf2n_long::l0;
|
||||
int gf2n_long::l1;
|
||||
int gf2n_long::l2;
|
||||
int gf2n_long::l3;
|
||||
int gf2n_long::nterms;
|
||||
int128 gf2n_long::mask;
|
||||
int128 gf2n_long::lowermask;
|
||||
int128 gf2n_long::uppermask;
|
||||
bool gf2n_long::rewind = false;
|
||||
|
||||
#define num_2_fields 1
|
||||
|
||||
/* Require
|
||||
* 2*(n-1)-64+t1<64
|
||||
*/
|
||||
int long_fields_2[num_2_fields][4] = {
|
||||
{128,7,2,1},
|
||||
};
|
||||
|
||||
|
||||
void gf2n_long::init_field(int nn)
|
||||
{
|
||||
if (nn!=128) {
|
||||
cout << "Compiled for GF(2^128) only. Change parameters or compile "
|
||||
"without USE_GF2N_LONG" << endl;
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
int i,j=-1;
|
||||
for (i=0; i<num_2_fields && j==-1; i++)
|
||||
{ if (nn==long_fields_2[i][0]) { j=i; } }
|
||||
if (j==-1) { throw invalid_params(); }
|
||||
|
||||
n=nn;
|
||||
nterms=1;
|
||||
l0=128-n;
|
||||
t1=long_fields_2[j][1];
|
||||
l1=128+t1-n;
|
||||
if (long_fields_2[j][2]!=0)
|
||||
{ nterms=3;
|
||||
t2=long_fields_2[j][2];
|
||||
l2=128+t2-n;
|
||||
t3=long_fields_2[j][3];
|
||||
l3=128+t3-n;
|
||||
}
|
||||
// 2^128 has a pentanomial
|
||||
// if (nterms==1 && 2*(n-1)-128+t1>=128) { throw not_implemented(); }
|
||||
// if (nterms==3 && n!=128) { throw not_implemented(); }
|
||||
|
||||
mask=_mm_set_epi64x(-1,-1);
|
||||
lowermask=_mm_set_epi64x((1LL<<(64-7))-1,-1);
|
||||
uppermask=_mm_set_epi64x(((word)-1)<<(64-7),0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void gf2n_long::reduce_trinomial(int128 xh,int128 xl)
|
||||
{
|
||||
// Deal with xh first
|
||||
a=xl;
|
||||
a^=(xh<<l0);
|
||||
a^=(xh<<l1);
|
||||
|
||||
// Now deal with last int128
|
||||
int128 hi=a>>n;
|
||||
while (hi==0)
|
||||
{ a&=mask;
|
||||
|
||||
a^=hi;
|
||||
a^=(hi<<t1);
|
||||
hi=a>>n;
|
||||
}
|
||||
}
|
||||
|
||||
void gf2n_long::reduce_pentanomial(int128 xh, int128 xl)
|
||||
{
|
||||
// Deal with xh first
|
||||
a=xl;
|
||||
int128 upper, lower;
|
||||
upper=xh&uppermask;
|
||||
lower=xh&lowermask;
|
||||
// Upper part
|
||||
int128 tmp = 0;
|
||||
tmp^=(upper>>(n-t1-l0));
|
||||
tmp^=(upper>>(n-t1-l1));
|
||||
tmp^=(upper>>(n-t1-l2));
|
||||
tmp^=(upper>>(n-t1-l3));
|
||||
lower^=(tmp>>(l1));
|
||||
a^=(tmp<<(n-l1));
|
||||
// Lower part
|
||||
a^=(lower<<l0);
|
||||
a^=(lower<<l1);
|
||||
a^=(lower<<l2);
|
||||
a^=(lower<<l3);
|
||||
|
||||
/*
|
||||
// Now deal with last int128
|
||||
int128 hi=a>>n;
|
||||
while (hi!=0)
|
||||
{ a&=mask;
|
||||
|
||||
a^=hi;
|
||||
a^=(hi<<t1);
|
||||
a^=(hi<<t2);
|
||||
a^=(hi<<t3);
|
||||
|
||||
hi=a>>n;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
gf2n_long& gf2n_long::mul(const gf2n_long& x,const gf2n_long& y)
|
||||
{
|
||||
__m128i res[2];
|
||||
memset(res,0,sizeof(res));
|
||||
|
||||
mul128(x.a.a,y.a.a,res,res+1);
|
||||
|
||||
reduce(res[1],res[0]);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
class int129
|
||||
{
|
||||
int128 lower;
|
||||
bool msb;
|
||||
|
||||
public:
|
||||
int129() : lower(_mm_setzero_si128()), msb(false) { }
|
||||
int129(int128 lower, bool msb) : lower(lower), msb(msb) { }
|
||||
int129(int128 a) : lower(a), msb(false) { }
|
||||
int129(word a)
|
||||
{ *this = a; }
|
||||
int128 get_lower() { return lower; }
|
||||
int129& operator=(const __m128i& other)
|
||||
{ lower = other; msb = false; return *this; }
|
||||
int129& operator=(const word& other)
|
||||
{ lower = _mm_set_epi64x(0, other); msb = false; return *this; }
|
||||
bool operator==(const int129& other)
|
||||
{ return (lower == other.lower) && (msb == other.msb); }
|
||||
bool operator!=(const int129& other)
|
||||
{ return !(*this == other); }
|
||||
bool operator>=(const int129& other)
|
||||
{ //cout << ">= " << msb << other.msb << (msb > other.msb) << is_ge(lower.a, other.lower.a) << endl;
|
||||
return msb == other.msb ? is_ge(lower.a, other.lower.a) : msb > other.msb; }
|
||||
int129 operator<<(int other)
|
||||
{ return int129(lower << other, _mm_cvtsi128_si32(((lower >> (128-other)) & 1).a)); }
|
||||
int129& operator>>=(int other)
|
||||
{ lower >>= other; lower |= (int128(msb) << (128-other)); msb = !other; return *this; }
|
||||
int129 operator^(const int129& other)
|
||||
{ return int129(lower ^ other.lower, msb ^ other.msb); }
|
||||
int129& operator^=(const int129& other)
|
||||
{ lower ^= other.lower; msb ^= other.msb; return *this; }
|
||||
int129 operator&(const word& other)
|
||||
{ return int129(lower & other, false); }
|
||||
friend ostream& operator<<(ostream& s, const int129& a)
|
||||
{ s << a.msb << a.lower; return s; }
|
||||
};
|
||||
|
||||
void gf2n_long::invert()
|
||||
{
|
||||
if (is_one()) { return; }
|
||||
if (is_zero()) { throw division_by_zero(); }
|
||||
|
||||
int129 u,v=a,B=0,D=1,mod=1;
|
||||
|
||||
mod^=(int129(1)<<n);
|
||||
mod^=(int129(1)<<t1);
|
||||
if (nterms==3)
|
||||
{ mod^=(int129(1)<<t2);
|
||||
mod^=(int129(1)<<t3);
|
||||
}
|
||||
u=mod; v=a;
|
||||
|
||||
while (u!=0)
|
||||
{ while ((u&1)==0)
|
||||
{ u>>=1;
|
||||
if ((B&1)!=0) { B^=mod; }
|
||||
B>>=1;
|
||||
}
|
||||
while ((v&1)==0 && v!=0)
|
||||
{ v>>=1;
|
||||
if ((D&1)!=0) { D^=mod; }
|
||||
D>>=1;
|
||||
}
|
||||
|
||||
if (u>=v) { u=u^v; B=B^D; }
|
||||
else { v=v^u; D=D^B; }
|
||||
}
|
||||
|
||||
a=D.get_lower();
|
||||
}
|
||||
|
||||
|
||||
void gf2n_long::randomize(PRNG& G)
|
||||
{
|
||||
a=G.get_doubleword();
|
||||
a&=mask;
|
||||
}
|
||||
|
||||
|
||||
void gf2n_long::output(ostream& s,bool human) const
|
||||
{
|
||||
if (human)
|
||||
{ s << *this; }
|
||||
else
|
||||
{ s.write((char*) &a,sizeof(__m128i)); }
|
||||
}
|
||||
|
||||
void gf2n_long::input(istream& s,bool human)
|
||||
{
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
}
|
||||
//throw end_of_file();
|
||||
s.clear(); // unset EOF flag
|
||||
s.seekg(0);
|
||||
if (!rewind)
|
||||
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
|
||||
rewind = true;
|
||||
}
|
||||
|
||||
if (human)
|
||||
{ s >> *this; }
|
||||
else
|
||||
{ s.read((char*) &a,sizeof(__m128i)); }
|
||||
}
|
||||
274
Math/gf2nlong.h
Normal file
274
Math/gf2nlong.h
Normal file
@@ -0,0 +1,274 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* gf2nlong.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MATH_GF2NLONG_H_
|
||||
#define MATH_GF2NLONG_H_
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include <smmintrin.h>
|
||||
|
||||
#include "Tools/random.h"
|
||||
#include "Math/field_types.h"
|
||||
|
||||
|
||||
class int128
|
||||
{
|
||||
public:
|
||||
__m128i a;
|
||||
|
||||
int128() : a(_mm_setzero_si128()) { }
|
||||
int128(const int128& a) : a(a.a) { }
|
||||
int128(const __m128i& a) : a(a) { }
|
||||
int128(const word& a) : a(_mm_cvtsi64_si128(a)) { }
|
||||
int128(const word& upper, const word& lower) : a(_mm_set_epi64x(upper, lower)) { }
|
||||
|
||||
word get_lower() { return (word)_mm_cvtsi128_si64(a); }
|
||||
|
||||
bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); }
|
||||
bool operator!=(const int128& other) const { return !(*this == other); }
|
||||
|
||||
int128 operator<<(const int& other) const;
|
||||
int128 operator>>(const int& other) const;
|
||||
|
||||
int128 operator^(const int128& other) const { return a ^ other.a; }
|
||||
int128 operator|(const int128& other) const { return a | other.a; }
|
||||
int128 operator&(const int128& other) const { return a & other.a; }
|
||||
|
||||
int128 operator~() const { return ~a; }
|
||||
|
||||
int128& operator<<=(const int& other) { return *this = *this << other; }
|
||||
int128& operator>>=(const int& other) { return *this = *this >> other; }
|
||||
|
||||
int128& operator^=(const int128& other) { a ^= other.a; return *this; }
|
||||
int128& operator|=(const int128& other) { a |= other.a; return *this; }
|
||||
int128& operator&=(const int128& other) { a &= other.a; return *this; }
|
||||
|
||||
friend ostream& operator<<(ostream& s, const int128& a);
|
||||
};
|
||||
|
||||
|
||||
/* This interface compatible with the gfp interface
|
||||
* which then allows us to template the Share
|
||||
* data type.
|
||||
*/
|
||||
|
||||
|
||||
/*
|
||||
Arithmetic in Gf_{2^n} with n<=128
|
||||
*/
|
||||
|
||||
class gf2n_long
|
||||
{
|
||||
int128 a;
|
||||
|
||||
static int n,t1,t2,t3,nterms;
|
||||
static int l0,l1,l2,l3;
|
||||
static int128 mask,lowermask,uppermask;
|
||||
static bool rewind;
|
||||
|
||||
/* Assign x[0..2*nwords] to a and reduce it... */
|
||||
void reduce_trinomial(int128 xh,int128 xl);
|
||||
void reduce_pentanomial(int128 xh,int128 xl);
|
||||
|
||||
public:
|
||||
|
||||
typedef gf2n_long value_type;
|
||||
typedef int128 internal_type;
|
||||
|
||||
void reduce(int128 xh,int128 xl)
|
||||
{
|
||||
if (nterms==3)
|
||||
{ reduce_pentanomial(xh,xl); }
|
||||
else
|
||||
{ reduce_trinomial(xh,xl); }
|
||||
}
|
||||
|
||||
static void init_field(int nn);
|
||||
static int degree() { return n; }
|
||||
static int get_nterms() { return nterms; }
|
||||
static int get_t(int i)
|
||||
{ if (i==0) { return t1; }
|
||||
else if (i==1) { return t2; }
|
||||
else if (i==2) { return t3; }
|
||||
return -1;
|
||||
}
|
||||
|
||||
static DataFieldType field_type() { return DATA_GF2N; }
|
||||
static char type_char() { return '2'; }
|
||||
static string type_string() { return "gf2n_long"; }
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
static int t() { return 0; }
|
||||
|
||||
int128 get() const { return a; }
|
||||
__m128i to_m128i() const { return a.a; }
|
||||
word get_word() const { return _mm_cvtsi128_si64x(a.a); }
|
||||
|
||||
void assign(const gf2n_long& g) { a=g.a; }
|
||||
|
||||
void assign_zero() { a=_mm_setzero_si128(); }
|
||||
void assign_one() { a=int128(0,1); }
|
||||
void assign_x() { a=int128(0,2); }
|
||||
void assign(int128 aa) { a=aa&mask; }
|
||||
void assign(int aa) { a=int128(static_cast<unsigned int>(aa))&mask; }
|
||||
void assign(const char* buffer) { a = _mm_loadu_si128((__m128i*)buffer); }
|
||||
|
||||
int get_bit(int i) const
|
||||
{ return ((a>>i)&1).get_lower(); }
|
||||
void set_bit(int i,unsigned int b)
|
||||
{ if (b==1)
|
||||
{ a |= (1UL<<i); }
|
||||
else
|
||||
{ a &= ~(1UL<<i); }
|
||||
}
|
||||
|
||||
gf2n_long() { assign_zero(); }
|
||||
gf2n_long(const gf2n_long& g) { assign(g); }
|
||||
gf2n_long(const int128& g) { assign(g); }
|
||||
gf2n_long(int g) { assign(g); }
|
||||
~gf2n_long() { ; }
|
||||
|
||||
gf2n_long& operator=(const gf2n_long& g)
|
||||
{ if (&g!=this) { assign(g); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
int is_zero() const { return a==int128(0); }
|
||||
int is_one() const { return a==int128(1); }
|
||||
int equal(const gf2n_long& y) const { return (a==y.a); }
|
||||
bool operator==(const gf2n_long& y) const { return a==y.a; }
|
||||
bool operator!=(const gf2n_long& y) const { return a!=y.a; }
|
||||
|
||||
// x+y
|
||||
void add(const gf2n_long& x,const gf2n_long& y)
|
||||
{ a=x.a^y.a; }
|
||||
void add(const gf2n_long& x)
|
||||
{ a^=x.a; }
|
||||
template<int T>
|
||||
void add(octet* x)
|
||||
{ a^=int128(_mm_loadu_si128((__m128i*)x)); }
|
||||
void add(octet* x)
|
||||
{ add<0>(x); }
|
||||
void sub(const gf2n_long& x,const gf2n_long& y)
|
||||
{ a=x.a^y.a; }
|
||||
void sub(const gf2n_long& x)
|
||||
{ a^=x.a; }
|
||||
// = x * y
|
||||
gf2n_long& mul(const gf2n_long& x,const gf2n_long& y);
|
||||
void mul(const gf2n_long& x) { mul(*this,x); }
|
||||
// x * y when one of x,y is a bit
|
||||
void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; }
|
||||
|
||||
gf2n_long operator+(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; }
|
||||
gf2n_long operator*(const gf2n_long& x) { gf2n_long res; res.mul(*this, x); return res; }
|
||||
gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; }
|
||||
gf2n_long& operator*=(const gf2n_long& x) { mul(x); return *this; }
|
||||
gf2n_long operator-(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; }
|
||||
gf2n_long& operator-=(const gf2n_long& x) { sub(x); return *this; }
|
||||
|
||||
void square();
|
||||
void square(const gf2n_long& aa);
|
||||
void invert();
|
||||
void invert(const gf2n_long& aa)
|
||||
{ *this=aa; invert(); }
|
||||
void negate() { return; }
|
||||
void power(long i);
|
||||
|
||||
/* Bitwise Ops */
|
||||
void AND(const gf2n_long& x,const gf2n_long& y) { a=x.a&y.a; }
|
||||
void XOR(const gf2n_long& x,const gf2n_long& y) { a=x.a^y.a; }
|
||||
void OR(const gf2n_long& x,const gf2n_long& y) { a=x.a|y.a; }
|
||||
void NOT(const gf2n_long& x) { a=(~x.a)&mask; }
|
||||
void SHL(const gf2n_long& x,int n) { a=(x.a<<n)&mask; }
|
||||
void SHR(const gf2n_long& x,int n) { a=x.a>>n; }
|
||||
|
||||
gf2n_long operator&(const gf2n_long& x) { gf2n_long res; res.AND(*this, x); return res; }
|
||||
gf2n_long operator^(const gf2n_long& x) { gf2n_long res; res.XOR(*this, x); return res; }
|
||||
gf2n_long operator|(const gf2n_long& x) { gf2n_long res; res.OR(*this, x); return res; }
|
||||
gf2n_long operator!() { gf2n_long res; res.NOT(*this); return res; }
|
||||
gf2n_long operator<<(int i) { gf2n_long res; res.SHL(*this, i); return res; }
|
||||
gf2n_long operator>>(int i) { gf2n_long res; res.SHR(*this, i); return res; }
|
||||
|
||||
/* Crap RNG */
|
||||
void randomize(PRNG& G);
|
||||
// compatibility with gfp
|
||||
void almost_randomize(PRNG& G) { randomize(G); }
|
||||
|
||||
void output(ostream& s,bool human) const;
|
||||
void input(istream& s,bool human);
|
||||
|
||||
friend ostream& operator<<(ostream& s,const gf2n_long& x)
|
||||
{ s << hex << x.a << dec;
|
||||
return s;
|
||||
}
|
||||
friend istream& operator>>(istream& s,gf2n_long& x)
|
||||
{ bigint tmp;
|
||||
s >> hex >> tmp >> dec;
|
||||
x.a = 0;
|
||||
mpn_copyi((word*)&x.a.a, tmp.get_mpz_t()->_mp_d, tmp.get_mpz_t()->_mp_size);
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const
|
||||
{ o.append((octet*) &a,sizeof(__m128i)); }
|
||||
void unpack(octetStream& o)
|
||||
{ o.consume((octet*) &a,sizeof(__m128i)); }
|
||||
};
|
||||
|
||||
|
||||
inline int128 int128::operator<<(const int& other) const
|
||||
{
|
||||
int128 res(_mm_slli_epi64(a, other));
|
||||
__m128i mask;
|
||||
if (other < 64)
|
||||
mask = _mm_srli_epi64(a, 64 - other);
|
||||
else
|
||||
mask = _mm_slli_epi64(a, other - 64);
|
||||
res.a ^= _mm_slli_si128(mask, 8);
|
||||
return res;
|
||||
}
|
||||
|
||||
inline int128 int128::operator>>(const int& other) const
|
||||
{
|
||||
int128 res(_mm_srli_epi64(a, other));
|
||||
__m128i mask;
|
||||
if (other < 64)
|
||||
mask = _mm_slli_epi64(a, 64 - other);
|
||||
else
|
||||
mask = _mm_srli_epi64(a, other - 64);
|
||||
res.a ^= _mm_srli_si128(mask, 8);
|
||||
return res;
|
||||
}
|
||||
|
||||
inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2)
|
||||
{
|
||||
__m128i tmp3, tmp4, tmp5, tmp6;
|
||||
|
||||
tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
|
||||
tmp4 = _mm_clmulepi64_si128(a, b, 0x10);
|
||||
tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
|
||||
tmp6 = _mm_clmulepi64_si128(a, b, 0x11);
|
||||
|
||||
tmp4 = _mm_xor_si128(tmp4, tmp5);
|
||||
tmp5 = _mm_slli_si128(tmp4, 8);
|
||||
tmp4 = _mm_srli_si128(tmp4, 8);
|
||||
tmp3 = _mm_xor_si128(tmp3, tmp5);
|
||||
tmp6 = _mm_xor_si128(tmp6, tmp4);
|
||||
// initial mul now in tmp3, tmp6
|
||||
*res1 = tmp3;
|
||||
*res2 = tmp6;
|
||||
}
|
||||
|
||||
#endif /* MATH_GF2NLONG_H_ */
|
||||
125
Math/gfp.cpp
Normal file
125
Math/gfp.cpp
Normal file
@@ -0,0 +1,125 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Math/gfp.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
Zp_Data gfp::ZpD;
|
||||
|
||||
void gfp::almost_randomize(PRNG& G)
|
||||
{
|
||||
G.get_octets((octet*)a.x,t()*sizeof(mp_limb_t));
|
||||
a.x[t()-1]&=ZpD.mask;
|
||||
}
|
||||
|
||||
void gfp::AND(const gfp& x,const gfp& y)
|
||||
{
|
||||
bigint bi1,bi2;
|
||||
to_bigint(bi1,x);
|
||||
to_bigint(bi2,y);
|
||||
mpz_and(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
|
||||
to_gfp(*this, bi1);
|
||||
}
|
||||
|
||||
void gfp::OR(const gfp& x,const gfp& y)
|
||||
{
|
||||
bigint bi1,bi2;
|
||||
to_bigint(bi1,x);
|
||||
to_bigint(bi2,y);
|
||||
mpz_ior(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
|
||||
to_gfp(*this, bi1);
|
||||
}
|
||||
|
||||
void gfp::XOR(const gfp& x,const gfp& y)
|
||||
{
|
||||
bigint bi1,bi2;
|
||||
to_bigint(bi1,x);
|
||||
to_bigint(bi2,y);
|
||||
mpz_xor(bi1.get_mpz_t(), bi1.get_mpz_t(), bi2.get_mpz_t());
|
||||
to_gfp(*this, bi1);
|
||||
}
|
||||
|
||||
void gfp::AND(const gfp& x,const bigint& y)
|
||||
{
|
||||
bigint bi;
|
||||
to_bigint(bi,x);
|
||||
mpz_and(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
|
||||
to_gfp(*this, bi);
|
||||
}
|
||||
|
||||
void gfp::OR(const gfp& x,const bigint& y)
|
||||
{
|
||||
bigint bi;
|
||||
to_bigint(bi,x);
|
||||
mpz_ior(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
|
||||
to_gfp(*this, bi);
|
||||
}
|
||||
|
||||
void gfp::XOR(const gfp& x,const bigint& y)
|
||||
{
|
||||
bigint bi;
|
||||
to_bigint(bi,x);
|
||||
mpz_xor(bi.get_mpz_t(), bi.get_mpz_t(), y.get_mpz_t());
|
||||
to_gfp(*this, bi);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void gfp::SHL(const gfp& x,int n)
|
||||
{
|
||||
if (!x.is_zero())
|
||||
{
|
||||
bigint bi;
|
||||
to_bigint(bi,x,false);
|
||||
mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
|
||||
to_gfp(*this, bi);
|
||||
}
|
||||
else
|
||||
{
|
||||
assign_zero();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gfp::SHR(const gfp& x,int n)
|
||||
{
|
||||
if (!x.is_zero())
|
||||
{
|
||||
bigint bi;
|
||||
to_bigint(bi,x);
|
||||
mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
|
||||
to_gfp(*this, bi);
|
||||
}
|
||||
else
|
||||
{
|
||||
assign_zero();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gfp::SHL(const gfp& x,const bigint& n)
|
||||
{
|
||||
SHL(x,mpz_get_si(n.get_mpz_t()));
|
||||
}
|
||||
|
||||
|
||||
void gfp::SHR(const gfp& x,const bigint& n)
|
||||
{
|
||||
SHR(x,mpz_get_si(n.get_mpz_t()));
|
||||
}
|
||||
|
||||
|
||||
gfp gfp::sqrRoot()
|
||||
{
|
||||
// Temp move to bigint so as to call sqrRootMod
|
||||
bigint ti;
|
||||
to_bigint(ti, *this);
|
||||
ti = sqrRootMod(ti, ZpD.pr);
|
||||
if (!isOdd(ti))
|
||||
ti = ZpD.pr - ti;
|
||||
gfp temp;
|
||||
to_gfp(temp, ti);
|
||||
return temp;
|
||||
}
|
||||
205
Math/gfp.h
Normal file
205
Math/gfp.h
Normal file
@@ -0,0 +1,205 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _gfp
|
||||
#define _gfp
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/modp.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Math/field_types.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
/* This is a wrapper class for the modp data type
|
||||
* It is used to be interface compatible with the gfp
|
||||
* type, which then allows us to template the Share
|
||||
* data type.
|
||||
*
|
||||
* So gfp is used ONLY for the stuff in the finite fields
|
||||
* we are going to be doing MPC over, not the modp stuff
|
||||
* for the FHE scheme
|
||||
*/
|
||||
|
||||
|
||||
class gfp
|
||||
{
|
||||
modp a;
|
||||
static Zp_Data ZpD;
|
||||
|
||||
public:
|
||||
|
||||
typedef gfp value_type;
|
||||
|
||||
static void init_field(const bigint& p,bool mont=true)
|
||||
{ ZpD.init(p,mont); }
|
||||
static bigint pr()
|
||||
{ return ZpD.pr; }
|
||||
static int t()
|
||||
{ return ZpD.get_t(); }
|
||||
static Zp_Data& get_ZpD()
|
||||
{ return ZpD; }
|
||||
|
||||
static DataFieldType field_type() { return DATA_MODP; }
|
||||
static char type_char() { return 'p'; }
|
||||
static string type_string() { return "gfp"; }
|
||||
|
||||
static int size() { return t() * sizeof(mp_limb_t); }
|
||||
|
||||
void assign(const gfp& g) { a=g.a; }
|
||||
void assign_zero() { assignZero(a,ZpD); }
|
||||
void assign_one() { assignOne(a,ZpD); }
|
||||
void assign(word aa) { bigint b=aa; to_gfp(*this,b); }
|
||||
void assign(long aa) { bigint b=aa; to_gfp(*this,b); }
|
||||
void assign(int aa) { bigint b=aa; to_gfp(*this,b); }
|
||||
void assign(const char* buffer) { a.assign(buffer, ZpD.get_t()); }
|
||||
|
||||
modp get() const { return a; }
|
||||
|
||||
// Assumes prD behind x is equal to ZpD
|
||||
void assign(modp& x) { a=x; }
|
||||
|
||||
gfp() { assignZero(a,ZpD); }
|
||||
gfp(const gfp& g) { a=g.a; }
|
||||
gfp(const modp& g) { a=g; }
|
||||
gfp(const __m128i& x) { *this=x; }
|
||||
gfp(const int128& x) { *this=x.a; }
|
||||
gfp(const bigint& x) { to_modp(a, x, ZpD); }
|
||||
gfp(int x) { assign(x); }
|
||||
~gfp() { ; }
|
||||
|
||||
gfp& operator=(const gfp& g)
|
||||
{ if (&g!=this) { a=g.a; }
|
||||
return *this;
|
||||
}
|
||||
|
||||
gfp& operator=(const __m128i other)
|
||||
{
|
||||
memcpy(a.x, &other, sizeof(other));
|
||||
a.x[2] = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void to_m128i(__m128i& ans)
|
||||
{
|
||||
memcpy(&ans, a.x, sizeof(ans));
|
||||
}
|
||||
|
||||
__m128i to_m128i()
|
||||
{
|
||||
return _mm_loadu_si128((__m128i*)a.x);
|
||||
}
|
||||
|
||||
|
||||
bool is_zero() const { return isZero(a,ZpD); }
|
||||
bool is_one() const { return isOne(a,ZpD); }
|
||||
bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); }
|
||||
bool operator==(const gfp& y) const { return equal(y); }
|
||||
bool operator!=(const gfp& y) const { return !equal(y); }
|
||||
|
||||
// x+y
|
||||
template <int T>
|
||||
void add(const gfp& x,const gfp& y)
|
||||
{ Add<T>(a,x.a,y.a,ZpD); }
|
||||
template <int T>
|
||||
void add(const gfp& x)
|
||||
{ Add<T>(a,a,x.a,ZpD); }
|
||||
template <int T>
|
||||
void add(void* x)
|
||||
{ ZpD.Add<T>(a.x,a.x,(mp_limb_t*)x); }
|
||||
void add(const gfp& x,const gfp& y)
|
||||
{ Add(a,x.a,y.a,ZpD); }
|
||||
void add(const gfp& x)
|
||||
{ Add(a,a,x.a,ZpD); }
|
||||
void add(void* x)
|
||||
{ ZpD.Add(a.x,a.x,(mp_limb_t*)x); }
|
||||
void sub(const gfp& x,const gfp& y)
|
||||
{ Sub(a,x.a,y.a,ZpD); }
|
||||
void sub(const gfp& x)
|
||||
{ Sub(a,a,x.a,ZpD); }
|
||||
// = x * y
|
||||
void mul(const gfp& x,const gfp& y)
|
||||
{ Mul(a,x.a,y.a,ZpD); }
|
||||
void mul(const gfp& x)
|
||||
{ Mul(a,a,x.a,ZpD); }
|
||||
|
||||
gfp operator+(const gfp& x) { gfp res; res.add(*this, x); return res; }
|
||||
gfp operator-(const gfp& x) { gfp res; res.sub(*this, x); return res; }
|
||||
gfp operator*(const gfp& x) { gfp res; res.mul(*this, x); return res; }
|
||||
gfp& operator+=(const gfp& x) { add(x); return *this; }
|
||||
gfp& operator-=(const gfp& x) { sub(x); return *this; }
|
||||
gfp& operator*=(const gfp& x) { mul(x); return *this; }
|
||||
|
||||
void square(const gfp& aa)
|
||||
{ Sqr(a,aa.a,ZpD); }
|
||||
void square()
|
||||
{ Sqr(a,a,ZpD); }
|
||||
void invert()
|
||||
{ Inv(a,a,ZpD); }
|
||||
void invert(const gfp& aa)
|
||||
{ Inv(a,aa.a,ZpD); }
|
||||
void negate()
|
||||
{ Negate(a,a,ZpD); }
|
||||
void power(long i)
|
||||
{ Power(a,a,i,ZpD); }
|
||||
|
||||
// deterministic square root
|
||||
gfp sqrRoot();
|
||||
|
||||
void randomize(PRNG& G)
|
||||
{ a.randomize(G,ZpD); }
|
||||
// faster randomization, see implementation for explanation
|
||||
void almost_randomize(PRNG& G);
|
||||
|
||||
void output(ostream& s,bool human) const
|
||||
{ a.output(s,ZpD,human); }
|
||||
void input(istream& s,bool human)
|
||||
{ a.input(s,ZpD,human); }
|
||||
|
||||
friend ostream& operator<<(ostream& s,const gfp& x)
|
||||
{ x.output(s,true);
|
||||
return s;
|
||||
}
|
||||
friend istream& operator>>(istream& s,gfp& x)
|
||||
{ x.input(s,true);
|
||||
return s;
|
||||
}
|
||||
|
||||
/* Bitwise Ops
|
||||
* - Converts gfp args to bigints and then converts answer back to gfp
|
||||
*/
|
||||
void AND(const gfp& x,const gfp& y);
|
||||
void XOR(const gfp& x,const gfp& y);
|
||||
void OR(const gfp& x,const gfp& y);
|
||||
void AND(const gfp& x,const bigint& y);
|
||||
void XOR(const gfp& x,const bigint& y);
|
||||
void OR(const gfp& x,const bigint& y);
|
||||
void SHL(const gfp& x,int n);
|
||||
void SHR(const gfp& x,int n);
|
||||
void SHL(const gfp& x,const bigint& n);
|
||||
void SHR(const gfp& x,const bigint& n);
|
||||
|
||||
gfp operator&(const gfp& x) { gfp res; res.AND(*this, x); return res; }
|
||||
gfp operator^(const gfp& x) { gfp res; res.XOR(*this, x); return res; }
|
||||
gfp operator|(const gfp& x) { gfp res; res.OR(*this, x); return res; }
|
||||
gfp operator<<(int i) { gfp res; res.SHL(*this, i); return res; }
|
||||
gfp operator>>(int i) { gfp res; res.SHR(*this, i); return res; }
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const
|
||||
{ a.pack(o,ZpD); }
|
||||
void unpack(octetStream& o)
|
||||
{ a.unpack(o,ZpD); }
|
||||
|
||||
|
||||
// Convert representation to and from a bigint number
|
||||
friend void to_bigint(bigint& ans,const gfp& x,bool reduce=true)
|
||||
{ to_bigint(ans,x.a,x.ZpD,reduce); }
|
||||
friend void to_gfp(gfp& ans,const bigint& x)
|
||||
{ to_modp(ans.a,x,ans.ZpD); }
|
||||
};
|
||||
|
||||
|
||||
#endif
|
||||
263
Math/modp.cpp
Normal file
263
Math/modp.cpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Zp_Data.h"
|
||||
#include "modp.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
bool modp::rewind = false;
|
||||
|
||||
/***********************************************************************
|
||||
* The following functions remain the same in Real and Montgomery rep *
|
||||
***********************************************************************/
|
||||
|
||||
void modp::randomize(PRNG& G, const Zp_Data& ZpD)
|
||||
{
|
||||
bigint x=G.randomBnd(ZpD.pr);
|
||||
to_modp(*this,x,ZpD);
|
||||
}
|
||||
|
||||
void modp::pack(octetStream& o,const Zp_Data& ZpD) const
|
||||
{
|
||||
o.append((octet*) x,ZpD.t*sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
|
||||
void modp::unpack(octetStream& o,const Zp_Data& ZpD)
|
||||
{
|
||||
o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
|
||||
void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
|
||||
{
|
||||
ZpD.Sub(ans.x, x.x, y.x);
|
||||
}
|
||||
|
||||
|
||||
void Negate(modp& ans,const modp& x,const Zp_Data& ZpD)
|
||||
{
|
||||
if (isZero(x,ZpD)) { ans=x; return; }
|
||||
mpn_sub_n(ans.x,ZpD.prA,x.x,ZpD.t);
|
||||
}
|
||||
|
||||
|
||||
bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD)
|
||||
{ if (mpn_cmp(x.x,y.x,ZpD.t)!=0)
|
||||
{ return false; }
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isZero(const modp& ans,const Zp_Data& ZpD)
|
||||
{
|
||||
for (int i=0; i<ZpD.t; i++)
|
||||
{ if (ans.x[i]!=0) { return false; } }
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/***********************************************************************
|
||||
* All the remaining functions have Montgomery variants which we need *
|
||||
* to deal with *
|
||||
***********************************************************************/
|
||||
|
||||
|
||||
void assignOne(modp& x,const Zp_Data& ZpD)
|
||||
{ if (ZpD.montgomery)
|
||||
{ mpn_copyi(x.x,ZpD.R,ZpD.t+1); }
|
||||
else
|
||||
{ assignZero(x,ZpD);
|
||||
x.x[0]=1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool isOne(const modp& x,const Zp_Data& ZpD)
|
||||
{ if (ZpD.montgomery)
|
||||
{ if (mpn_cmp(x.x,ZpD.R,ZpD.t)!=0)
|
||||
{ return false; }
|
||||
}
|
||||
else
|
||||
{ if (x.x[0]!=1) { return false; }
|
||||
for (int i=1; i<ZpD.t; i++)
|
||||
{ if (x.x[i]!=0) { return false; } }
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce)
|
||||
{
|
||||
mpz_t a;
|
||||
mpz_init2(a,MAX_MOD_SZ*sizeof(mp_limb_t)*8);
|
||||
if (ZpD.montgomery)
|
||||
{ mp_limb_t one[MAX_MOD_SZ];
|
||||
mpn_zero(one,ZpD.t+1);
|
||||
one[0]=1;
|
||||
ZpD.Mont_Mult(a->_mp_d,x.x,one);
|
||||
}
|
||||
else
|
||||
{ mpn_copyi(a->_mp_d,x.x,ZpD.t+1); }
|
||||
a->_mp_size=ZpD.t;
|
||||
if (reduce)
|
||||
while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0)
|
||||
{ a->_mp_size--; }
|
||||
ans=bigint(a);
|
||||
|
||||
mpz_clear(a);
|
||||
}
|
||||
|
||||
|
||||
void to_modp(modp& ans,int x,const Zp_Data& ZpD)
|
||||
{
|
||||
mpn_zero(ans.x,ZpD.t+1);
|
||||
if (x>=0)
|
||||
{ ans.x[0]=x;
|
||||
if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; }
|
||||
}
|
||||
else
|
||||
{ if (ZpD.t==1)
|
||||
{ ans.x[0]=(ZpD.prA[0]+x)%ZpD.prA[0]; }
|
||||
else
|
||||
{ bigint xx=ZpD.pr+x;
|
||||
to_modp(ans,xx,ZpD);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (ZpD.montgomery)
|
||||
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); }
|
||||
}
|
||||
|
||||
|
||||
void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD)
|
||||
{
|
||||
bigint xx=x%ZpD.pr;
|
||||
if (xx<0) { xx+=ZpD.pr; }
|
||||
//mpz_mod(xx.get_mpz_t(),x.get_mpz_t(),ZpD.pr.get_mpz_t());
|
||||
mpn_zero(ans.x,ZpD.t+1);
|
||||
mpn_copyi(ans.x,xx.get_mpz_t()->_mp_d,xx.get_mpz_t()->_mp_size);
|
||||
if (ZpD.montgomery)
|
||||
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R2); }
|
||||
}
|
||||
|
||||
|
||||
|
||||
void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
|
||||
{
|
||||
if (ZpD.montgomery)
|
||||
{ ZpD.Mont_Mult(ans.x,x.x,y.x); }
|
||||
else
|
||||
{ //ans.x=(x.x*y.x)%ZpD.pr;
|
||||
mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ];
|
||||
mpn_mul_n(aa,x.x,y.x,ZpD.t);
|
||||
mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD)
|
||||
{
|
||||
if (ZpD.montgomery)
|
||||
{ ZpD.Mont_Mult(ans.x,x.x,x.x); }
|
||||
else
|
||||
{ //ans.x=(x.x*x.x)%ZpD.pr;
|
||||
mp_limb_t aa[2*MAX_MOD_SZ],q[2*MAX_MOD_SZ];
|
||||
mpn_sqr(aa,x.x,ZpD.t);
|
||||
mpn_tdiv_qr(q,ans.x,0,aa,2*ZpD.t,ZpD.prA,ZpD.t);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Inv(modp& ans,const modp& x,const Zp_Data& ZpD)
|
||||
{
|
||||
mp_limb_t g[MAX_MOD_SZ],xx[MAX_MOD_SZ+1],yy[MAX_MOD_SZ+1];
|
||||
mp_size_t sz;
|
||||
mpn_copyi(xx,x.x,ZpD.t);
|
||||
mpn_copyi(yy,ZpD.prA,ZpD.t);
|
||||
mpn_gcdext(g,ans.x,&sz,xx,ZpD.t,yy,ZpD.t);
|
||||
if (sz<0)
|
||||
{ mpn_sub(ans.x,ZpD.prA,ZpD.t,ans.x,-sz);
|
||||
sz=-sz;
|
||||
}
|
||||
else
|
||||
{ for (int i=sz; i<ZpD.t; i++) { ans.x[i]=0; } }
|
||||
if (ZpD.montgomery)
|
||||
{ ZpD.Mont_Mult(ans.x,ans.x,ZpD.R3); }
|
||||
}
|
||||
|
||||
|
||||
// XXXX This is a crap version. Hopefully this is not time critical
|
||||
void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD)
|
||||
{
|
||||
if (exp==1) { ans=x; return; }
|
||||
if (exp==0) { assignOne(ans,ZpD); return; }
|
||||
if (exp<0) { throw not_implemented(); }
|
||||
modp t=x;
|
||||
assignOne(ans,ZpD);
|
||||
while (exp!=0)
|
||||
{ if ((exp&1)==1) { Mul(ans,ans,t,ZpD); }
|
||||
exp>>=1;
|
||||
Sqr(t,t,ZpD);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// XXXX This is a crap version. Hopefully this is not time critical
|
||||
void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD)
|
||||
{
|
||||
if (exp==1) { ans=x; return; }
|
||||
if (exp==0) { assignOne(ans,ZpD); return; }
|
||||
if (exp<0) { throw not_implemented(); }
|
||||
modp t=x;
|
||||
assignOne(ans,ZpD);
|
||||
bigint e=exp;
|
||||
while (e!=0)
|
||||
{ if ((e&1)==1) { Mul(ans,ans,t,ZpD); }
|
||||
e>>=1;
|
||||
Sqr(t,t,ZpD);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void modp::output(ostream& s,const Zp_Data& ZpD,bool human) const
|
||||
{
|
||||
if (human)
|
||||
{ bigint te;
|
||||
to_bigint(te,*this,ZpD);
|
||||
if (te < ZpD.pr / 2)
|
||||
s << te;
|
||||
else
|
||||
s << (te - ZpD.pr);
|
||||
}
|
||||
else
|
||||
{ s.write((char*) x,ZpD.t*sizeof(mp_limb_t)); }
|
||||
}
|
||||
|
||||
void modp::input(istream& s,const Zp_Data& ZpD,bool human)
|
||||
{
|
||||
if (s.peek() == EOF)
|
||||
{ if (s.tellg() == 0)
|
||||
{ cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
}
|
||||
//throw end_of_file();
|
||||
s.clear(); // unset EOF flag
|
||||
s.seekg(0);
|
||||
if (!rewind)
|
||||
cout << "REWINDING - ONLY FOR BENCHMARKING" << endl;
|
||||
rewind = true;
|
||||
}
|
||||
|
||||
if (human)
|
||||
{ bigint te;
|
||||
s >> te;
|
||||
to_modp(*this,te,ZpD);
|
||||
}
|
||||
else
|
||||
{ s.read((char*) x,ZpD.t*sizeof(mp_limb_t)); }
|
||||
}
|
||||
|
||||
|
||||
116
Math/modp.h
Normal file
116
Math/modp.h
Normal file
@@ -0,0 +1,116 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Modp
|
||||
#define _Modp
|
||||
|
||||
/*
|
||||
* Currently we only support an MPIR based implementation.
|
||||
*
|
||||
* What ever is type-def'd to bigint is assumed to have
|
||||
* operator overloading for all standard operators, has
|
||||
* comparison operations and istream/ostream operators >>/<<.
|
||||
*
|
||||
* All "integer" operations will be done using operator notation
|
||||
* all "modp" operations should be done using the function calls
|
||||
* below (interchange with Montgomery arithmetic).
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
#include "Math/bigint.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
|
||||
class modp
|
||||
{
|
||||
static bool rewind;
|
||||
|
||||
mp_limb_t x[MAX_MOD_SZ];
|
||||
|
||||
public:
|
||||
|
||||
// NEXT FUNCTION IS FOR DEBUG PURPOSES ONLY
|
||||
mp_limb_t get_limb(int i) { return x[i]; }
|
||||
|
||||
// use mem* functions instead of mpn_*, so the compiler can optimize
|
||||
modp()
|
||||
{ memset(x, 0, sizeof(x)); }
|
||||
modp(const modp& y)
|
||||
{ memcpy(x, y.x, sizeof(x)); }
|
||||
modp& operator=(const modp& y)
|
||||
{ if (this!=&y) { memcpy(x, y.x, sizeof(x)); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
void assign(const char* buffer, int t) { memcpy(x, buffer, t * sizeof(mp_limb_t)); }
|
||||
|
||||
void randomize(PRNG& G, const Zp_Data& ZpD);
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
// i.e. When we do montgomery we dont care about decoding
|
||||
void pack(octetStream& o,const Zp_Data& ZpD) const;
|
||||
void unpack(octetStream& o,const Zp_Data& ZpD);
|
||||
|
||||
|
||||
/**********************************
|
||||
* Modp Operations *
|
||||
**********************************/
|
||||
|
||||
// Convert representation to and from a modp number
|
||||
friend void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce=true);
|
||||
|
||||
friend void to_modp(modp& ans,int x,const Zp_Data& ZpD);
|
||||
friend void to_modp(modp& ans,const bigint& x,const Zp_Data& ZpD);
|
||||
|
||||
template <int T>
|
||||
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
|
||||
{ ZpD.Add(ans.x, x.x, y.x); }
|
||||
friend void Sub(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Mul(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
friend void Sqr(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
friend void Negate(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
friend void Inv(modp& ans,const modp& x,const Zp_Data& ZpD);
|
||||
|
||||
friend void Power(modp& ans,const modp& x,int exp,const Zp_Data& ZpD);
|
||||
friend void Power(modp& ans,const modp& x,const bigint& exp,const Zp_Data& ZpD);
|
||||
|
||||
friend void assignOne(modp& x,const Zp_Data& ZpD);
|
||||
friend void assignZero(modp& x,const Zp_Data& ZpD);
|
||||
friend bool isZero(const modp& x,const Zp_Data& ZpD);
|
||||
friend bool isOne(const modp& x,const Zp_Data& ZpD);
|
||||
friend bool areEqual(const modp& x,const modp& y,const Zp_Data& ZpD);
|
||||
|
||||
// Input and output from a stream
|
||||
// - Can do in human or machine only format (later should be faster)
|
||||
// - If human output appends a space to help with reading
|
||||
// and also convert back/forth from Montgomery if needed
|
||||
void output(ostream& s,const Zp_Data& ZpD,bool human) const;
|
||||
void input(istream& s,const Zp_Data& ZpD,bool human);
|
||||
|
||||
friend class gfp;
|
||||
|
||||
};
|
||||
|
||||
|
||||
inline void assignZero(modp& x,const Zp_Data& ZpD)
|
||||
{
|
||||
if (sizeof(x.x) <= 3 * 16)
|
||||
// use memset to allow the compiler to optimize
|
||||
// if x.x is at most 3*128 bits
|
||||
memset(x.x, 0, sizeof(x.x));
|
||||
else
|
||||
mpn_zero(x.x, ZpD.t + 1);
|
||||
}
|
||||
|
||||
template <int T>
|
||||
inline void Add(modp& ans,const modp& x,const modp& y,const Zp_Data& ZpD)
|
||||
{
|
||||
ZpD.Add<T>(ans.x, x.x, y.x);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
37
Math/operators.h
Normal file
37
Math/operators.h
Normal file
@@ -0,0 +1,37 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* operations.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MATH_OPERATORS_H_
|
||||
#define MATH_OPERATORS_H_
|
||||
|
||||
template <class T>
|
||||
T operator*(const bool& x, const T& y) { return x ? y : T(); }
|
||||
template <class T>
|
||||
T operator*(const T& y, const bool& x) { return x ? y : T(); }
|
||||
template <class T>
|
||||
T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; }
|
||||
|
||||
template <class T, class U>
|
||||
T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; }
|
||||
template <class T, class U>
|
||||
T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; }
|
||||
template <class T, class U>
|
||||
T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; }
|
||||
|
||||
template <class T, class U>
|
||||
T& operator+=(T& x, const U& y) { x.add(y); return x; }
|
||||
template <class T, class U>
|
||||
T& operator*=(T& x, const U& y) { x.mul(y); return x; }
|
||||
template <class T, class U>
|
||||
T& operator-=(T& x, const U& y) { x.sub(y); return x; }
|
||||
|
||||
template <class T, class U>
|
||||
T operator/(const T& x, const U& y) { U inv = y; inv.invert(); return x * inv; }
|
||||
template <class T, class U>
|
||||
T& operator/=(const T& x, const U& y) { U inv = y; inv.invert(); return x *= inv; }
|
||||
|
||||
#endif /* MATH_OPERATORS_H_ */
|
||||
411
Networking/Player.cpp
Normal file
411
Networking/Player.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Player.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <sys/select.h>
|
||||
|
||||
// Use printf rather than cout so valgrind can detect thread issues
|
||||
|
||||
void Names::init(int player,int pnb,const char* servername)
|
||||
{
|
||||
player_no=player;
|
||||
portnum_base=pnb;
|
||||
setup_names(servername);
|
||||
setup_server();
|
||||
}
|
||||
|
||||
|
||||
void Names::init(int player,int pnb,vector<octet*> Nms)
|
||||
{
|
||||
player_no=player;
|
||||
portnum_base=pnb;
|
||||
nplayers=Nms.size();
|
||||
names.resize(nplayers);
|
||||
for (int i=0; i<nplayers; i++)
|
||||
{ names[i]=(char*)Nms[i]; }
|
||||
setup_server();
|
||||
}
|
||||
|
||||
|
||||
void Names::init(int player,int pnb,vector<string> Nms)
|
||||
{
|
||||
player_no=player;
|
||||
portnum_base=pnb;
|
||||
nplayers=Nms.size();
|
||||
names=Nms;
|
||||
setup_server();
|
||||
}
|
||||
|
||||
// initialize hostnames from file
|
||||
void Names::init(int player, int _nplayers, int pnb, const string& filename)
|
||||
{
|
||||
ifstream hostsfile(filename.c_str());
|
||||
if (hostsfile.fail())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Error opening " << filename << ". See HOSTS.example for an example.";
|
||||
throw file_error(ss.str().c_str());
|
||||
}
|
||||
player_no = player;
|
||||
nplayers = _nplayers;
|
||||
portnum_base = pnb;
|
||||
string line;
|
||||
while (getline(hostsfile, line))
|
||||
{
|
||||
if (line.length() > 0 && line.at(0) != '#')
|
||||
names.push_back(line);
|
||||
}
|
||||
if ((int)names.size() < nplayers)
|
||||
throw invalid_params();
|
||||
names.resize(nplayers);
|
||||
for (unsigned int i = 0; i < names.size(); i++)
|
||||
cerr << "name: " << names[i] << endl;
|
||||
setup_server();
|
||||
}
|
||||
|
||||
void Names::setup_names(const char *servername)
|
||||
{
|
||||
int socket_num;
|
||||
int pn = portnum_base - 1;
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
send(socket_num, (octet*)&player_no, sizeof(player_no));
|
||||
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
|
||||
|
||||
int inst=-1; // wait until instruction to start.
|
||||
while (inst != GO) { receive(socket_num, inst); }
|
||||
|
||||
// Send my name
|
||||
octet my_name[512];
|
||||
memset(my_name,0,512*sizeof(octet));
|
||||
gethostname((char*)my_name,512);
|
||||
fprintf(stderr, "My Name = %s\n",my_name);
|
||||
send(socket_num,my_name,512);
|
||||
cerr << "My number = " << player_no << endl;
|
||||
|
||||
// Now get the set of names
|
||||
int i;
|
||||
receive(socket_num,nplayers);
|
||||
cerr << nplayers << " players\n";
|
||||
names.resize(nplayers);
|
||||
for (i=0; i<nplayers; i++)
|
||||
{ octet tmp[512];
|
||||
receive(socket_num,tmp,512);
|
||||
names[i]=(char*)tmp;
|
||||
cerr << "Player " << i << " is running on machine " << names[i] << endl;
|
||||
}
|
||||
close_client_socket(socket_num);
|
||||
}
|
||||
|
||||
|
||||
void Names::setup_server()
|
||||
{
|
||||
server = new ServerSocket(portnum_base + player_no);
|
||||
}
|
||||
|
||||
|
||||
Names::Names(const Names& other)
|
||||
{
|
||||
if (other.server != 0)
|
||||
throw runtime_error("Can copy Names only when uninitialized");
|
||||
player_no = other.player_no;
|
||||
nplayers = other.nplayers;
|
||||
portnum_base = other.portnum_base;
|
||||
names = other.names;
|
||||
server = 0;
|
||||
}
|
||||
|
||||
|
||||
Names::~Names()
|
||||
{
|
||||
if (server != 0)
|
||||
delete server;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Player::Player(const Names& Nms, int id) : PlayerBase(Nms), send_to_self_socket(-1)
|
||||
{
|
||||
nplayers=Nms.nplayers;
|
||||
player_no=Nms.player_no;
|
||||
setup_sockets(Nms.names, Nms.portnum_base, id, *Nms.server);
|
||||
blk_SHA1_Init(&ctx);
|
||||
}
|
||||
|
||||
|
||||
Player::~Player()
|
||||
{
|
||||
/* Close down the sockets */
|
||||
for (int i=0; i<nplayers; i++)
|
||||
close_client_socket(sockets[i]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Set up nmachines client and server sockets to send data back and fro
|
||||
// A machine is a server between it and player i if i<=my_number
|
||||
// Can also communicate with myself, but only with send_to and receive_from
|
||||
void Player::setup_sockets(const vector<string>& names,int portnum_base,int id_base,ServerSocket& server)
|
||||
{
|
||||
sockets.resize(nplayers);
|
||||
// Set up the client side
|
||||
for (int i=player_no; i<nplayers; i++)
|
||||
{ int pn=id_base+i*nplayers+player_no;
|
||||
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),portnum_base+i,pn);
|
||||
set_up_client_socket(sockets[i],names[i].c_str(),portnum_base+i);
|
||||
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
|
||||
}
|
||||
send_to_self_socket = sockets[player_no];
|
||||
// Setting up the server side
|
||||
for (int i=0; i<=player_no; i++)
|
||||
{ int id=id_base+player_no*nplayers+i;
|
||||
fprintf(stderr, "Setting up server with id 0x%x\n",id);
|
||||
sockets[i] = server.get_connection_socket(id);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nplayers; i++)
|
||||
{
|
||||
// timeout of 5 minutes
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 300;
|
||||
tv.tv_usec = 0;
|
||||
int fl = setsockopt(sockets[i], SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(struct timeval));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
socket_players[sockets[i]] = i;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Player::send_to(int player,const octetStream& o,bool donthash) const
|
||||
{
|
||||
int socket = socket_to_send(player);
|
||||
o.Send(socket);
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
}
|
||||
|
||||
|
||||
void Player::send_all(const octetStream& o,bool donthash) const
|
||||
{ for (int i=0; i<nplayers; i++)
|
||||
{ if (i!=player_no)
|
||||
{ o.Send(sockets[i]); }
|
||||
}
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
}
|
||||
|
||||
|
||||
void Player::receive_player(int i,octetStream& o,bool donthash) const
|
||||
{
|
||||
o.reset_write_head();
|
||||
o.Receive(sockets[i]);
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
}
|
||||
|
||||
/* This is deliberately weird to avoid problems with OS max buffer
|
||||
* size getting in the way
|
||||
*/
|
||||
void Player::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
|
||||
{ for (int i=0; i<nplayers; i++)
|
||||
{ if (i>player_no)
|
||||
{ o[player_no].Send(sockets[i]); }
|
||||
else if (i<player_no)
|
||||
{ o[i].reset_write_head();
|
||||
o[i].Receive(sockets[i]);
|
||||
}
|
||||
}
|
||||
for (int i=0; i<nplayers; i++)
|
||||
{ if (i<player_no)
|
||||
{ o[player_no].Send(sockets[i]); }
|
||||
else if (i>player_no)
|
||||
{ o[i].reset_write_head();
|
||||
o[i].Receive(sockets[i]);
|
||||
}
|
||||
}
|
||||
if (!donthash)
|
||||
{ for (int i=0; i<nplayers; i++)
|
||||
{ blk_SHA1_Update(&ctx,o[i].get_data(),o[i].get_length()); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Player::Check_Broadcast() const
|
||||
{
|
||||
octet hashVal[HASH_SIZE];
|
||||
vector<octetStream> h(nplayers);
|
||||
blk_SHA1_Final(hashVal,&ctx);
|
||||
h[player_no].append(hashVal,HASH_SIZE);
|
||||
|
||||
Broadcast_Receive(h,true);
|
||||
for (int i=0; i<nplayers; i++)
|
||||
{ if (i!=player_no)
|
||||
{ if (!h[i].equals(h[player_no]))
|
||||
{ throw broadcast_invalid(); }
|
||||
}
|
||||
}
|
||||
blk_SHA1_Init(&ctx);
|
||||
}
|
||||
|
||||
|
||||
void Player::wait_for_available(vector<int>& players, vector<int>& result) const
|
||||
{
|
||||
fd_set rfds;
|
||||
FD_ZERO(&rfds);
|
||||
int highest = 0;
|
||||
vector<int>::iterator it;
|
||||
for (it = players.begin(); it != players.end(); it++)
|
||||
{
|
||||
if (*it >= 0)
|
||||
{
|
||||
FD_SET(sockets[*it], &rfds);
|
||||
highest = max(highest, sockets[*it]);
|
||||
}
|
||||
}
|
||||
|
||||
int res = select(highest + 1, &rfds, 0, 0, 0);
|
||||
|
||||
if (res < 0)
|
||||
error("select()");
|
||||
|
||||
result.clear();
|
||||
result.reserve(res);
|
||||
for (it = players.begin(); it != players.end(); it++)
|
||||
{
|
||||
if (res == 0)
|
||||
break;
|
||||
|
||||
if (*it >= 0 && FD_ISSET(sockets[*it], &rfds))
|
||||
{
|
||||
res--;
|
||||
result.push_back(*it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ThreadPlayer::ThreadPlayer(const Names& Nms, int id_base) : Player(Nms, id_base)
|
||||
{
|
||||
for (int i = 0; i < Nms.num_players(); i++)
|
||||
{
|
||||
receivers.push_back(new Receiver(sockets[i]));
|
||||
receivers[i]->start();
|
||||
|
||||
senders.push_back(new Sender(socket_to_send(i)));
|
||||
senders[i]->start();
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPlayer::~ThreadPlayer()
|
||||
{
|
||||
for (unsigned int i = 0; i < receivers.size(); i++)
|
||||
{
|
||||
receivers[i]->stop();
|
||||
if (receivers[i]->timer.elapsed() > 0)
|
||||
cerr << "Waiting for receiving from " << i << ": " << receivers[i]->timer.elapsed() << endl;
|
||||
delete receivers[i];
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < senders.size(); i++)
|
||||
{
|
||||
senders[i]->stop();
|
||||
if (senders[i]->timer.elapsed() > 0)
|
||||
cerr << "Waiting for sending to " << i << ": " << senders[i]->timer.elapsed() << endl;
|
||||
delete senders[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPlayer::request_receive(int i, octetStream& o) const
|
||||
{
|
||||
receivers[i]->request(o);
|
||||
}
|
||||
|
||||
void ThreadPlayer::wait_receive(int i, octetStream& o, bool donthash) const
|
||||
{
|
||||
receivers[i]->wait(o);
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
}
|
||||
|
||||
void ThreadPlayer::receive_player(int i, octetStream& o, bool donthash) const
|
||||
{
|
||||
request_receive(i, o);
|
||||
wait_receive(i, o, donthash);
|
||||
}
|
||||
|
||||
void ThreadPlayer::send_all(const octetStream& o,bool donthash) const
|
||||
{
|
||||
for (int i=0; i<nplayers; i++)
|
||||
{ if (i!=player_no)
|
||||
senders[i]->request(o);
|
||||
}
|
||||
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
|
||||
for (int i = 0; i < nplayers; i++)
|
||||
if (i != player_no)
|
||||
senders[i]->wait(o);
|
||||
}
|
||||
|
||||
|
||||
TwoPartyPlayer::TwoPartyPlayer(const Names& Nms, int other_player, int id) : PlayerBase(Nms), other_player(other_player)
|
||||
{
|
||||
is_server = Nms.my_num() > other_player;
|
||||
setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id);
|
||||
}
|
||||
|
||||
TwoPartyPlayer::~TwoPartyPlayer()
|
||||
{
|
||||
close_client_socket(socket);
|
||||
}
|
||||
|
||||
void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id)
|
||||
{
|
||||
if (is_server)
|
||||
{
|
||||
fprintf(stderr, "Setting up server with id %d\n",id);
|
||||
socket = server.get_connection_socket(id);
|
||||
}
|
||||
else
|
||||
{
|
||||
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id);
|
||||
set_up_client_socket(socket, hostname, pn);
|
||||
::send(socket, (unsigned char*)&id, sizeof(id));
|
||||
}
|
||||
}
|
||||
|
||||
int TwoPartyPlayer::other_player_num() const
|
||||
{
|
||||
return other_player;
|
||||
}
|
||||
|
||||
void TwoPartyPlayer::send(octetStream& o) const
|
||||
{
|
||||
o.Send(socket);
|
||||
}
|
||||
|
||||
void TwoPartyPlayer::receive(octetStream& o) const
|
||||
{
|
||||
o.reset_write_head();
|
||||
o.Receive(socket);
|
||||
}
|
||||
|
||||
void TwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
{
|
||||
{
|
||||
if (is_server)
|
||||
{
|
||||
o[0].Send(socket);
|
||||
o[1].reset_write_head();
|
||||
o[1].Receive(socket);
|
||||
}
|
||||
else
|
||||
{
|
||||
o[1].reset_write_head();
|
||||
o[1].Receive(socket);
|
||||
o[0].Send(socket);
|
||||
}
|
||||
}
|
||||
}
|
||||
185
Networking/Player.h
Normal file
185
Networking/Player.h
Normal file
@@ -0,0 +1,185 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Player
|
||||
#define _Player
|
||||
|
||||
/* Class to create a player, for KeyGen, Offline and Online phases.
|
||||
*
|
||||
* Basically handles connection to the server to obtain the names
|
||||
* of the other players. Plus sending and receiving of data
|
||||
*
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Networking/sockets.h"
|
||||
#include "Networking/ServerSocket.h"
|
||||
#include "Tools/sha1.h"
|
||||
#include "Networking/Receiver.h"
|
||||
#include "Networking/Sender.h"
|
||||
|
||||
/* Class to get the names off the server */
|
||||
class Names
|
||||
{
|
||||
vector<string> names;
|
||||
int nplayers;
|
||||
int portnum_base;
|
||||
int player_no;
|
||||
|
||||
void setup_names(const char *servername);
|
||||
|
||||
void setup_server();
|
||||
|
||||
public:
|
||||
|
||||
mutable ServerSocket* server;
|
||||
|
||||
// Usual setup names
|
||||
void init(int player,int pnb,const char* servername);
|
||||
Names(int player,int pnb,const char* servername)
|
||||
{ init(player,pnb,servername); }
|
||||
// Set up names when we KNOW who we are going to be using before hand
|
||||
void init(int player,int pnb,vector<octet*> Nms);
|
||||
Names(int player,int pnb,vector<octet*> Nms)
|
||||
{ init(player,pnb,Nms); }
|
||||
void init(int player,int pnb,vector<string> Nms);
|
||||
Names(int player,int pnb,vector<string> Nms)
|
||||
{ init(player,pnb,Nms); }
|
||||
// Set up names from file -- reads the first nplayers names in the file
|
||||
void init(int player, int nplayers, int pnb, const string& hostsfile);
|
||||
Names(int player, int nplayers, int pnb, const string& hostsfile)
|
||||
{ init(player, nplayers, pnb, hostsfile); }
|
||||
|
||||
|
||||
Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; }
|
||||
Names(const Names& other);
|
||||
~Names();
|
||||
|
||||
int num_players() const { return nplayers; }
|
||||
int my_num() const { return player_no; }
|
||||
const string get_name(int i) const { return names[i]; }
|
||||
int get_portnum_base() const { return portnum_base; }
|
||||
|
||||
friend class PlayerBase;
|
||||
friend class Player;
|
||||
friend class TwoPartyPlayer;
|
||||
};
|
||||
|
||||
|
||||
class PlayerBase
|
||||
{
|
||||
protected:
|
||||
int player_no;
|
||||
|
||||
public:
|
||||
PlayerBase(const Names& Nms) : player_no(Nms.my_num()) {}
|
||||
int my_num() const { return player_no; }
|
||||
};
|
||||
|
||||
|
||||
class Player : public PlayerBase
|
||||
{
|
||||
protected:
|
||||
vector<int> sockets;
|
||||
int send_to_self_socket;
|
||||
|
||||
void setup_sockets(const vector<string>& names,int portnum_base,int id_base,ServerSocket& server);
|
||||
|
||||
int nplayers;
|
||||
|
||||
mutable blk_SHA_CTX ctx;
|
||||
|
||||
map<int,int> socket_players;
|
||||
|
||||
int socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
|
||||
|
||||
public:
|
||||
// The offset is used for the multi-threaded call, to ensure different
|
||||
// portnum bases in each thread
|
||||
Player(const Names& Nms,int id_base=0);
|
||||
|
||||
virtual ~Player();
|
||||
|
||||
int num_players() const { return nplayers; }
|
||||
int socket(int i) const { return sockets[i]; }
|
||||
|
||||
// Send/Receive data to/from player i
|
||||
// 8-bit ints only (mainly for testing)
|
||||
void send_int(int i,int a) const { send(sockets[i],a); }
|
||||
void receive_int(int i,int& a) const { receive(sockets[i],a); }
|
||||
|
||||
// Send an octetStream to all other players
|
||||
// -- And corresponding receive
|
||||
virtual void send_all(const octetStream& o,bool donthash=false) const;
|
||||
void send_to(int player,const octetStream& o,bool donthash=false) const;
|
||||
virtual void receive_player(int i,octetStream& o,bool donthash=false) const;
|
||||
|
||||
// Receive one from player i
|
||||
|
||||
/* Broadcast and Receive data to/from all players
|
||||
* - Assumes o[player_no] contains the thing broadcast by me
|
||||
*/
|
||||
void Broadcast_Receive(vector<octetStream>& o,bool donthash=false) const;
|
||||
|
||||
/* Run Protocol To Verify Broadcast Is Correct
|
||||
* - Resets the blk_SHA_CTX at the same time
|
||||
*/
|
||||
void Check_Broadcast() const;
|
||||
|
||||
// wait for available inputs
|
||||
void wait_for_available(vector<int>& players, vector<int>& result) const;
|
||||
|
||||
// dummy functions for compatibility
|
||||
virtual void request_receive(int i, octetStream& o) const { sockets[i]; o.get_length(); }
|
||||
virtual void wait_receive(int i, octetStream& o, bool donthash=false) const { receive_player(i, o, donthash); }
|
||||
};
|
||||
|
||||
|
||||
class ThreadPlayer : public Player
|
||||
{
|
||||
public:
|
||||
mutable vector<Receiver*> receivers;
|
||||
mutable vector<Sender*> senders;
|
||||
|
||||
ThreadPlayer(const Names& Nms,int id_base=0);
|
||||
virtual ~ThreadPlayer();
|
||||
|
||||
void request_receive(int i, octetStream& o) const;
|
||||
void wait_receive(int i, octetStream& o, bool donthash=false) const;
|
||||
void receive_player(int i,octetStream& o,bool donthash=false) const;
|
||||
|
||||
void send_all(const octetStream& o,bool donthash=false) const;
|
||||
};
|
||||
|
||||
|
||||
class TwoPartyPlayer : public PlayerBase
|
||||
{
|
||||
private:
|
||||
// setup sockets for comm. with only one other player
|
||||
void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id);
|
||||
|
||||
int socket;
|
||||
bool is_server;
|
||||
int other_player;
|
||||
|
||||
public:
|
||||
TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
|
||||
~TwoPartyPlayer();
|
||||
|
||||
void send(octetStream& o) const;
|
||||
void receive(octetStream& o) const;
|
||||
|
||||
int other_player_num() const;
|
||||
|
||||
/* Send and receive to/from the other player
|
||||
* - o[0] contains my data, received data put in o[1]
|
||||
*/
|
||||
void send_receive_player(vector<octetStream>& o) const;
|
||||
};
|
||||
|
||||
#endif
|
||||
58
Networking/Receiver.cpp
Normal file
58
Networking/Receiver.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Receiver.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Receiver.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
void* run_receiver_thread(void* receiver)
|
||||
{
|
||||
((Receiver*)receiver)->run();
|
||||
return 0;
|
||||
}
|
||||
|
||||
Receiver::Receiver(int socket) : socket(socket), thread(0)
|
||||
{
|
||||
}
|
||||
|
||||
void Receiver::start()
|
||||
{
|
||||
pthread_create(&thread, 0, run_receiver_thread, this);
|
||||
}
|
||||
|
||||
void Receiver::stop()
|
||||
{
|
||||
in.stop();
|
||||
pthread_join(thread, 0);
|
||||
}
|
||||
|
||||
void Receiver::run()
|
||||
{
|
||||
octetStream* os = 0;
|
||||
while (in.pop(os))
|
||||
{
|
||||
os->reset_write_head();
|
||||
timer.start();
|
||||
os->Receive(socket);
|
||||
timer.stop();
|
||||
out.push(os);
|
||||
}
|
||||
}
|
||||
|
||||
void Receiver::request(octetStream& os)
|
||||
{
|
||||
in.push(&os);
|
||||
}
|
||||
|
||||
void Receiver::wait(octetStream& os)
|
||||
{
|
||||
octetStream* queued = 0;
|
||||
out.pop(queued);
|
||||
if (queued != &os)
|
||||
throw not_implemented();
|
||||
}
|
||||
40
Networking/Receiver.h
Normal file
40
Networking/Receiver.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Receiver.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef NETWORKING_RECEIVER_H_
|
||||
#define NETWORKING_RECEIVER_H_
|
||||
|
||||
#include <pthread.h>
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
class Receiver
|
||||
{
|
||||
int socket;
|
||||
WaitQueue<octetStream*> in;
|
||||
WaitQueue<octetStream*> out;
|
||||
pthread_t thread;
|
||||
|
||||
// prevent copying
|
||||
Receiver(const Receiver& other);
|
||||
|
||||
public:
|
||||
Timer timer;
|
||||
|
||||
Receiver(int socket);
|
||||
|
||||
void start();
|
||||
void stop();
|
||||
void run();
|
||||
|
||||
void request(octetStream& os);
|
||||
void wait(octetStream& os);
|
||||
};
|
||||
|
||||
#endif /* NETWORKING_RECEIVER_H_ */
|
||||
54
Networking/Sender.cpp
Normal file
54
Networking/Sender.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Sender.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Sender.h"
|
||||
|
||||
void* run_sender_thread(void* sender)
|
||||
{
|
||||
((Sender*)sender)->run();
|
||||
return 0;
|
||||
}
|
||||
|
||||
Sender::Sender(int socket) : socket(socket), thread(0)
|
||||
{
|
||||
}
|
||||
|
||||
void Sender::start()
|
||||
{
|
||||
pthread_create(&thread, 0, run_sender_thread, this);
|
||||
}
|
||||
|
||||
void Sender::stop()
|
||||
{
|
||||
in.stop();
|
||||
pthread_join(thread, 0);
|
||||
}
|
||||
|
||||
void Sender::run()
|
||||
{
|
||||
const octetStream* os = 0;
|
||||
while (in.pop(os))
|
||||
{
|
||||
// timer.start();
|
||||
os->Send(socket);
|
||||
// timer.stop();
|
||||
out.push(os);
|
||||
}
|
||||
}
|
||||
|
||||
void Sender::request(const octetStream& os)
|
||||
{
|
||||
in.push(&os);
|
||||
}
|
||||
|
||||
void Sender::wait(const octetStream& os)
|
||||
{
|
||||
const octetStream* queued = 0;
|
||||
out.pop(queued);
|
||||
if (queued != &os)
|
||||
throw not_implemented();
|
||||
}
|
||||
40
Networking/Sender.h
Normal file
40
Networking/Sender.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Sender.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef NETWORKING_SENDER_H_
|
||||
#define NETWORKING_SENDER_H_
|
||||
|
||||
#include <pthread.h>
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
class Sender
|
||||
{
|
||||
int socket;
|
||||
WaitQueue<const octetStream*> in;
|
||||
WaitQueue<const octetStream*> out;
|
||||
pthread_t thread;
|
||||
|
||||
// prevent copying
|
||||
Sender(const Sender& other);
|
||||
|
||||
public:
|
||||
Timer timer;
|
||||
|
||||
Sender(int socket);
|
||||
|
||||
void start();
|
||||
void stop();
|
||||
void run();
|
||||
|
||||
void request(const octetStream& os);
|
||||
void wait(const octetStream& os);
|
||||
};
|
||||
|
||||
#endif /* NETWORKING_SENDER_H_ */
|
||||
115
Networking/ServerSocket.cpp
Normal file
115
Networking/ServerSocket.cpp
Normal file
@@ -0,0 +1,115 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* ServerSocket.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <Networking/ServerSocket.h>
|
||||
#include <Networking/sockets.h>
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
using namespace std;
|
||||
|
||||
void* accept_thread(void* server_socket)
|
||||
{
|
||||
((ServerSocket*)server_socket)->accept_clients();
|
||||
return 0;
|
||||
}
|
||||
|
||||
ServerSocket::ServerSocket(int Portnum) : portnum(Portnum)
|
||||
{
|
||||
struct sockaddr_in serv; /* socket info about our server */
|
||||
|
||||
memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */
|
||||
serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */
|
||||
serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */
|
||||
serv.sin_port = htons(Portnum); /* set the server port number */
|
||||
|
||||
main_socket = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (main_socket<0) { error("set_up_socket:socket"); }
|
||||
|
||||
int one=1;
|
||||
int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
/* disable Nagle's algorithm */
|
||||
fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
octet my_name[512];
|
||||
memset(my_name,0,512*sizeof(octet));
|
||||
gethostname((char*)my_name,512);
|
||||
|
||||
/* bind serv information to mysocket
|
||||
* - Just assume it will eventually wake up
|
||||
*/
|
||||
fl=1;
|
||||
while (fl!=0)
|
||||
{ fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
|
||||
if (fl != 0)
|
||||
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
|
||||
sleep(1);
|
||||
}
|
||||
else
|
||||
{ cerr << "Bound on port " << Portnum << endl; }
|
||||
}
|
||||
if (fl<0) { error("set_up_socket:bind"); }
|
||||
|
||||
/* start listening, allowing a queue of up to 1000 pending connection */
|
||||
fl=listen(main_socket, 1000);
|
||||
if (fl<0) { error("set_up_socket:listen"); }
|
||||
|
||||
pthread_create(&thread, 0, accept_thread, this);
|
||||
}
|
||||
|
||||
ServerSocket::~ServerSocket()
|
||||
{
|
||||
pthread_cancel(thread);
|
||||
pthread_join(thread, 0);
|
||||
if (close(main_socket)) { error("close(main_socket"); };
|
||||
}
|
||||
|
||||
void ServerSocket::accept_clients()
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
struct sockaddr dest;
|
||||
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
|
||||
int socksize = sizeof(dest);
|
||||
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
|
||||
if (consocket<0) { error("set_up_socket:accept"); }
|
||||
|
||||
int client_id;
|
||||
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
|
||||
|
||||
data_signal.lock();
|
||||
clients[client_id] = consocket;
|
||||
data_signal.broadcast();
|
||||
data_signal.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
int ServerSocket::get_connection_socket(int id)
|
||||
{
|
||||
data_signal.lock();
|
||||
if (used.find(id) != used.end())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Connection id " << hex << id << " already used";
|
||||
throw IO_Error(ss.str());
|
||||
}
|
||||
|
||||
while (clients.find(id) == clients.end())
|
||||
data_signal.wait();
|
||||
|
||||
int client = clients[id];
|
||||
used.insert(id);
|
||||
data_signal.unlock();
|
||||
return client;
|
||||
}
|
||||
44
Networking/ServerSocket.h
Normal file
44
Networking/ServerSocket.h
Normal file
@@ -0,0 +1,44 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* ServerSocket.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef NETWORKING_SERVERSOCKET_H_
|
||||
#define NETWORKING_SERVERSOCKET_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
using namespace std;
|
||||
|
||||
#include <pthread.h>
|
||||
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/Signal.h"
|
||||
|
||||
class ServerSocket
|
||||
{
|
||||
int main_socket, portnum;
|
||||
map<int,int> clients;
|
||||
set<int> used;
|
||||
Signal data_signal;
|
||||
pthread_t thread;
|
||||
|
||||
// disable copying
|
||||
ServerSocket(const ServerSocket& other);
|
||||
|
||||
public:
|
||||
ServerSocket(int Portnum);
|
||||
~ServerSocket();
|
||||
|
||||
void accept_clients();
|
||||
|
||||
// This depends on clients sending their id as int.
|
||||
// Has to be thread-safe.
|
||||
int get_connection_socket(int number);
|
||||
|
||||
void close_socket();
|
||||
};
|
||||
|
||||
#endif /* NETWORKING_SERVERSOCKET_H_ */
|
||||
45
Networking/data.h
Normal file
45
Networking/data.h
Normal file
@@ -0,0 +1,45 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Data
|
||||
#define _Data
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
|
||||
typedef unsigned char octet;
|
||||
|
||||
// Assumes word is a 64 bit value
|
||||
#ifdef WIN32
|
||||
typedef unsigned __int64 word;
|
||||
#else
|
||||
typedef unsigned long word;
|
||||
#endif
|
||||
|
||||
#define BROADCAST 0
|
||||
#define ROUTE 1
|
||||
#define TERMINATE 2
|
||||
#define GO 3
|
||||
|
||||
void encode_length(octet *buff,int len);
|
||||
int decode_length(octet *buff);
|
||||
|
||||
|
||||
inline void encode_length(octet *buff,int len)
|
||||
{
|
||||
if (len<0) { throw invalid_length(); }
|
||||
buff[0]=len&255;
|
||||
buff[1]=(len>>8)&255;
|
||||
buff[2]=(len>>16)&255;
|
||||
buff[3]=(len>>24)&255;
|
||||
}
|
||||
|
||||
inline int decode_length(octet *buff)
|
||||
{
|
||||
int len=buff[0]+256*buff[1]+65536*buff[2]+16777216*buff[3];
|
||||
if (len<0) { throw invalid_length(); }
|
||||
return len;
|
||||
}
|
||||
|
||||
#endif
|
||||
225
Networking/sockets.cpp
Normal file
225
Networking/sockets.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "sockets.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
void error(const char *str)
|
||||
{
|
||||
char err[1000];
|
||||
gethostname(err,1000);
|
||||
strcat(err," : ");
|
||||
strcat(err,str);
|
||||
perror(err);
|
||||
throw bad_value();
|
||||
}
|
||||
|
||||
void error(const char *str1,const char *str2)
|
||||
{
|
||||
char err[1000];
|
||||
gethostname(err,1000);
|
||||
strcat(err," : ");
|
||||
strcat(err,str1);
|
||||
strcat(err,str2);
|
||||
perror(err);
|
||||
throw bad_value();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum)
|
||||
{
|
||||
|
||||
struct sockaddr_in serv; /* socket info about our server */
|
||||
int socksize = sizeof(struct sockaddr_in);
|
||||
|
||||
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
|
||||
memset(&serv, 0, sizeof(serv)); /* zero the struct before filling the fields */
|
||||
serv.sin_family = AF_INET; /* set the type of connection to TCP/IP */
|
||||
serv.sin_addr.s_addr = INADDR_ANY; /* set our address to any interface */
|
||||
serv.sin_port = htons(Portnum); /* set the server port number */
|
||||
|
||||
main_socket = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (main_socket<0) { error("set_up_socket:socket"); }
|
||||
|
||||
int one=1;
|
||||
int fl=setsockopt(main_socket,SOL_SOCKET,SO_REUSEADDR,(char*)&one,sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
/* disable Nagle's algorithm */
|
||||
fl= setsockopt(main_socket, IPPROTO_TCP, TCP_NODELAY, (char*)&one,sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
octet my_name[512];
|
||||
memset(my_name,0,512*sizeof(octet));
|
||||
gethostname((char*)my_name,512);
|
||||
|
||||
/* bind serv information to mysocket
|
||||
* - Just assume it will eventually wake up
|
||||
*/
|
||||
fl=1;
|
||||
while (fl!=0)
|
||||
{ fl=bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
|
||||
if (fl != 0)
|
||||
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
|
||||
sleep(1);
|
||||
}
|
||||
else
|
||||
{ cerr << "Bound on port " << Portnum << endl; }
|
||||
}
|
||||
if (fl<0) { error("set_up_socket:bind"); }
|
||||
|
||||
/* start listening, allowing a queue of up to 1 pending connection */
|
||||
fl=listen(main_socket, 1);
|
||||
if (fl<0) { error("set_up_socket:listen"); }
|
||||
|
||||
consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
|
||||
|
||||
if (consocket<0) { error("set_up_socket:accept"); }
|
||||
|
||||
}
|
||||
|
||||
|
||||
void close_server_socket(int consocket,int main_socket)
|
||||
{
|
||||
if (close(consocket)) { error("close(socket)"); }
|
||||
if (close(main_socket)) { error("close(main_socket"); };
|
||||
}
|
||||
|
||||
|
||||
|
||||
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
{
|
||||
mysocket = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (mysocket<0) { error("set_up_socket:socket"); }
|
||||
|
||||
/* disable Nagle's algorithm */
|
||||
int one=1;
|
||||
int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
fl=setsockopt(mysocket, SOL_SOCKET, SO_REUSEADDR, (char*)&one, sizeof(int));
|
||||
if (fl<0) { error("set_up_socket:setsockopt"); }
|
||||
|
||||
struct sockaddr_in dest;
|
||||
dest.sin_family = AF_INET;
|
||||
dest.sin_port = htons(Portnum); // set destination port number
|
||||
|
||||
/*
|
||||
struct hostent *server;
|
||||
server=gethostbyname(hostname);
|
||||
if (server== NULL)
|
||||
{ error("set_up_socket:gethostbyname"); }
|
||||
bcopy((char *)server->h_addr,
|
||||
(char *)&dest.sin_addr.s_addr,
|
||||
server->h_length); // set destination IP number
|
||||
*/
|
||||
struct addrinfo hints, *ai=NULL,*rp;
|
||||
memset (&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_INET;
|
||||
hints.ai_flags = AI_CANONNAME;
|
||||
|
||||
octet my_name[512];
|
||||
memset(my_name,0,512*sizeof(octet));
|
||||
gethostname((char*)my_name,512);
|
||||
|
||||
int erp;
|
||||
for (int i = 0; i < 60; i++)
|
||||
{ erp=getaddrinfo (hostname, NULL, &hints, &ai);
|
||||
if (erp == 0)
|
||||
{ break; }
|
||||
else
|
||||
{ cerr << "getaddrinfo on " << my_name << " has returned '" << gai_strerror(erp) <<
|
||||
"' for " << hostname << ", trying again in a second ..." << endl;
|
||||
if (ai)
|
||||
freeaddrinfo(ai);
|
||||
sleep(1);
|
||||
}
|
||||
}
|
||||
if (erp!=0)
|
||||
{ error("set_up_socket:getaddrinfo"); }
|
||||
|
||||
for (rp=ai; rp!=NULL; rp=rp->ai_next)
|
||||
{ const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr;
|
||||
|
||||
if (ai->ai_family == AF_INET)
|
||||
{ memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
freeaddrinfo(ai);
|
||||
|
||||
|
||||
do
|
||||
{ fl=1;
|
||||
while (fl==1 || errno==EINPROGRESS)
|
||||
{ fl=connect(mysocket, (struct sockaddr *)&dest, sizeof(struct sockaddr)); }
|
||||
}
|
||||
while (fl==-1 && errno==ECONNREFUSED);
|
||||
if (fl<0) { error("set_up_socket:connect:",hostname); }
|
||||
}
|
||||
|
||||
|
||||
|
||||
void close_client_socket(int socket)
|
||||
{
|
||||
if (close(socket))
|
||||
{
|
||||
char tmp[1000];
|
||||
sprintf(tmp, "close(%d)", socket);
|
||||
error(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned long long sent_amount = 0, sent_counter = 0;
|
||||
|
||||
|
||||
void send(int socket,int a)
|
||||
{
|
||||
unsigned char msg[1];
|
||||
msg[0]=a&255;
|
||||
if (send(socket,msg,1,0)!=1)
|
||||
{ error("Send error - 2 "); }
|
||||
}
|
||||
|
||||
|
||||
void receive(int socket,int& a)
|
||||
{
|
||||
unsigned char msg[1];
|
||||
int i=0;
|
||||
while (i==0)
|
||||
{ i=recv(socket,msg,1,0);
|
||||
if (i<0) { error("Receiving error - 2"); }
|
||||
}
|
||||
a=msg[0];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void send_ack(int socket)
|
||||
{
|
||||
char msg[]="OK";
|
||||
if (send(socket,msg,2,0)!=2)
|
||||
{ error("Send Ack"); }
|
||||
}
|
||||
|
||||
|
||||
int get_ack(int socket)
|
||||
{
|
||||
char msg[]="OK";
|
||||
char msg_r[2];
|
||||
int i=0,j;
|
||||
while (2-i>0)
|
||||
{ j=recv(socket,msg_r+i,2-i,0);
|
||||
i=i+j;
|
||||
}
|
||||
|
||||
if (msg_r[0]!=msg[0] || msg_r[1]!=msg[1]) { return 1; }
|
||||
return 0;
|
||||
}
|
||||
|
||||
67
Networking/sockets.h
Normal file
67
Networking/sockets.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _sockets
|
||||
#define _sockets
|
||||
|
||||
#include "Networking/data.h"
|
||||
|
||||
#include <errno.h> /* Errors */
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
|
||||
#include <arpa/inet.h>
|
||||
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/wait.h> /* Wait for Process Termination */
|
||||
|
||||
|
||||
void error(const char *str1,const char *str2);
|
||||
void error(const char *str);
|
||||
|
||||
void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum);
|
||||
void close_server_socket(int consocket,int main_socket);
|
||||
|
||||
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum);
|
||||
void close_client_socket(int socket);
|
||||
|
||||
void send(int socket,octet *msg,int len);
|
||||
void receive(int socket,octet *msg,int len);
|
||||
|
||||
/* Send and receive 8 bit integers */
|
||||
void send(int socket,int a);
|
||||
void receive(int socket,int& a);
|
||||
|
||||
void send_ack(int socket);
|
||||
int get_ack(int socket);
|
||||
|
||||
|
||||
extern unsigned long long sent_amount, sent_counter;
|
||||
|
||||
inline void send(int socket,octet *msg,int len)
|
||||
{
|
||||
if (send(socket,msg,len,0)!=len)
|
||||
{ error("Send error - 1 "); }
|
||||
|
||||
sent_amount += len;
|
||||
sent_counter++;
|
||||
}
|
||||
|
||||
inline void receive(int socket,octet *msg,int len)
|
||||
{
|
||||
int i=0,j;
|
||||
while (len-i>0)
|
||||
{ j=recv(socket,msg+i,len-i,0);
|
||||
if (j<0) { error("Receiving error - 1"); }
|
||||
i=i+j;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
294
OT/BaseOT.cpp
Normal file
294
OT/BaseOT.cpp
Normal file
@@ -0,0 +1,294 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "OT/BaseOT.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <pthread.h>
|
||||
|
||||
extern "C" {
|
||||
#include "SimpleOT/ot_sender.h"
|
||||
#include "SimpleOT/ot_receiver.h"
|
||||
}
|
||||
|
||||
using namespace std;
|
||||
|
||||
const char* role_to_str(OT_ROLE role)
|
||||
{
|
||||
if (role == RECEIVER)
|
||||
return "RECEIVER";
|
||||
if (role == SENDER)
|
||||
return "SENDER";
|
||||
return "BOTH";
|
||||
}
|
||||
|
||||
OT_ROLE INV_ROLE(OT_ROLE role)
|
||||
{
|
||||
if (role == RECEIVER)
|
||||
return SENDER;
|
||||
if (role == SENDER)
|
||||
return RECEIVER;
|
||||
else
|
||||
return BOTH;
|
||||
}
|
||||
|
||||
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
|
||||
{
|
||||
if (role == SENDER)
|
||||
{
|
||||
P->send(os[0]);
|
||||
}
|
||||
else if (role == RECEIVER)
|
||||
{
|
||||
P->receive(os[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// both sender + receiver
|
||||
P->send_receive_player(os);
|
||||
}
|
||||
}
|
||||
|
||||
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
|
||||
{
|
||||
if (role == RECEIVER)
|
||||
{
|
||||
P->send(os[0]);
|
||||
}
|
||||
else if (role == SENDER)
|
||||
{
|
||||
P->receive(os[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// both
|
||||
P->send_receive_player(os);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
{
|
||||
int i, j, k, len;
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
vector<octetStream> os(2);
|
||||
SIMPLEOT_SENDER sender;
|
||||
SIMPLEOT_RECEIVER receiver;
|
||||
|
||||
unsigned char S_pack[ PACKBYTES ];
|
||||
unsigned char Rs_pack[ 2 ][ 4 * PACKBYTES ];
|
||||
unsigned char sender_keys[ 2 ][ 4 ][ HASHBYTES ];
|
||||
unsigned char receiver_keys[ 4 ][ HASHBYTES ];
|
||||
unsigned char cs[ 4 ];
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
sender_genS(&sender, S_pack);
|
||||
os[0].store_bytes(S_pack, sizeof(S_pack));
|
||||
}
|
||||
send_if_ot_sender(P, os, ot_role);
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
os[1].get_bytes((octet*) receiver.S_pack, len);
|
||||
if (len != HASHBYTES)
|
||||
{
|
||||
cerr << "Received invalid length in base OT\n";
|
||||
exit(1);
|
||||
}
|
||||
receiver_procS(&receiver);
|
||||
receiver_maketable(&receiver);
|
||||
}
|
||||
|
||||
for (i = 0; i < nOT; i += 4)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
if (new_receiver_inputs)
|
||||
receiver_inputs[i + j] = G.get_uchar()&1;
|
||||
cs[j] = receiver_inputs[i + j];
|
||||
}
|
||||
receiver_rsgen(&receiver, Rs_pack[0], cs);
|
||||
os[0].reset_write_head();
|
||||
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
|
||||
receiver_keygen(&receiver, receiver_keys);
|
||||
}
|
||||
send_if_ot_receiver(P, os, ot_role);
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
os[1].get_bytes((octet*) Rs_pack[1], len);
|
||||
if (len != sizeof(Rs_pack[1]))
|
||||
{
|
||||
cerr << "Received invalid length in base OT\n";
|
||||
exit(1);
|
||||
}
|
||||
sender_keygen(&sender, Rs_pack[1], sender_keys);
|
||||
|
||||
// Copy 128 bits of keys to sender_inputs
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
for (k = 0; k < AES_BLK_SIZE; k++)
|
||||
{
|
||||
sender_inputs[i + j][0].set_byte(k, sender_keys[0][j][k]);
|
||||
sender_inputs[i + j][1].set_byte(k, sender_keys[1][j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
// Copy keys to receiver_outputs
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
for (k = 0; k < AES_BLK_SIZE; k++)
|
||||
{
|
||||
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef BASE_OT_DEBUG
|
||||
for (j = 0; j < 4; j++)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
printf("%4d-th sender keys:", i+j);
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[0][j][k]);
|
||||
printf(" ");
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]);
|
||||
printf("\n");
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
printf("%4d-th receiver key:", i+j);
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
set_seeds();
|
||||
}
|
||||
|
||||
void BaseOT::set_seeds()
|
||||
{
|
||||
for (int i = 0; i < nOT; i++)
|
||||
{
|
||||
// Set PRG seeds
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
G_sender[i][0].SetSeed(sender_inputs[i][0].get_ptr());
|
||||
G_sender[i][1].SetSeed(sender_inputs[i][1].get_ptr());
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
G_receiver[i].SetSeed(receiver_outputs[i].get_ptr());
|
||||
}
|
||||
}
|
||||
extend_length();
|
||||
}
|
||||
|
||||
void BaseOT::extend_length()
|
||||
{
|
||||
for (int i = 0; i < nOT; i++)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
sender_inputs[i][0].randomize(G_sender[i][0]);
|
||||
sender_inputs[i][1].randomize(G_sender[i][1]);
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
receiver_outputs[i].randomize(G_receiver[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void BaseOT::check()
|
||||
{
|
||||
vector<octetStream> os(2);
|
||||
BitVector tmp_vector(8 * AES_BLK_SIZE);
|
||||
|
||||
|
||||
for (int i = 0; i < nOT; i++)
|
||||
{
|
||||
if (ot_role == SENDER)
|
||||
{
|
||||
// send both inputs over
|
||||
sender_inputs[i][0].pack(os[0]);
|
||||
sender_inputs[i][1].pack(os[0]);
|
||||
P->send(os[0]);
|
||||
}
|
||||
else if (ot_role == RECEIVER)
|
||||
{
|
||||
P->receive(os[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// both sender + receiver
|
||||
sender_inputs[i][0].pack(os[0]);
|
||||
sender_inputs[i][1].pack(os[0]);
|
||||
P->send_receive_player(os);
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
tmp_vector.unpack(os[1]);
|
||||
|
||||
if (receiver_inputs[i] == 1)
|
||||
{
|
||||
tmp_vector.unpack(os[1]);
|
||||
}
|
||||
if (!tmp_vector.equals(receiver_outputs[i]))
|
||||
{
|
||||
cerr << "Incorrect OT\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
os[0].reset_write_head();
|
||||
os[1].reset_write_head();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void FakeOT::exec_base(bool new_receiver_inputs)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
vector<octetStream> os(2);
|
||||
vector<BitVector> bv(2);
|
||||
|
||||
if ((ot_role & RECEIVER) && new_receiver_inputs)
|
||||
{
|
||||
for (int i = 0; i < nOT; i++)
|
||||
// Generate my receiver inputs
|
||||
receiver_inputs[i] = G.get_uchar()&1;
|
||||
}
|
||||
|
||||
if (ot_role & SENDER)
|
||||
for (int i = 0; i < nOT; i++)
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
sender_inputs[i][j].randomize(G);
|
||||
sender_inputs[i][j].pack(os[0]);
|
||||
}
|
||||
|
||||
send_if_ot_sender(P, os, ot_role);
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
for (int i = 0; i < nOT; i++)
|
||||
{
|
||||
for (int j = 0; j < 2; j++)
|
||||
bv[j].unpack(os[1]);
|
||||
receiver_outputs[i] = bv[receiver_inputs[i]];
|
||||
}
|
||||
|
||||
set_seeds();
|
||||
}
|
||||
93
OT/BaseOT.h
Normal file
93
OT/BaseOT.h
Normal file
@@ -0,0 +1,93 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _BASE_OT
|
||||
#define _BASE_OT
|
||||
|
||||
/* The OT thread uses the Miracl library, which is not thread safe.
|
||||
* Thus all Miracl based code is contained in this one thread so as
|
||||
* to avoid locking issues etc.
|
||||
*
|
||||
* Thus this thread serves all base OTs to all other threads
|
||||
*/
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "Tools/random.h"
|
||||
#include "OT/BitVector.h"
|
||||
|
||||
// currently always assumes BOTH, i.e. do 2 sets of OT symmetrically,
|
||||
// use bitwise & to check for role
|
||||
enum OT_ROLE
|
||||
{
|
||||
RECEIVER = 0x01,
|
||||
SENDER = 0x10,
|
||||
BOTH = 0x11
|
||||
};
|
||||
|
||||
OT_ROLE INV_ROLE(OT_ROLE role);
|
||||
|
||||
const char* role_to_str(OT_ROLE role);
|
||||
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
|
||||
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
|
||||
|
||||
class BaseOT
|
||||
{
|
||||
public:
|
||||
vector<int> receiver_inputs;
|
||||
vector< vector<BitVector> > sender_inputs;
|
||||
vector<BitVector> receiver_outputs;
|
||||
TwoPartyPlayer* P;
|
||||
|
||||
BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH)
|
||||
: P(player), nOT(nOT), ot_length(ot_length), ot_role(role)
|
||||
{
|
||||
receiver_inputs.resize(nOT);
|
||||
sender_inputs.resize(nOT, vector<BitVector>(2));
|
||||
receiver_outputs.resize(nOT);
|
||||
G_sender.resize(nOT, vector<PRNG>(2));
|
||||
G_receiver.resize(nOT);
|
||||
|
||||
for (int i = 0; i < nOT; i++)
|
||||
{
|
||||
sender_inputs[i][0] = BitVector(8 * AES_BLK_SIZE);
|
||||
sender_inputs[i][1] = BitVector(8 * AES_BLK_SIZE);
|
||||
receiver_outputs[i] = BitVector(8 * AES_BLK_SIZE);
|
||||
}
|
||||
}
|
||||
virtual ~BaseOT() {}
|
||||
|
||||
int length() { return ot_length; }
|
||||
|
||||
void set_receiver_inputs(const vector<int>& new_inputs)
|
||||
{
|
||||
if ((int)new_inputs.size() != nOT)
|
||||
throw invalid_length();
|
||||
receiver_inputs = new_inputs;
|
||||
}
|
||||
|
||||
// do the OTs -- generate fresh random choice bits by default
|
||||
virtual void exec_base(bool new_receiver_inputs=true);
|
||||
// use PRG to get the next ot_length bits
|
||||
void extend_length();
|
||||
void check();
|
||||
protected:
|
||||
int nOT, ot_length;
|
||||
OT_ROLE ot_role;
|
||||
|
||||
vector< vector<PRNG> > G_sender;
|
||||
vector<PRNG> G_receiver;
|
||||
|
||||
bool is_sender() { return (bool) (ot_role & SENDER); }
|
||||
bool is_receiver() { return (bool) (ot_role & RECEIVER); }
|
||||
|
||||
void set_seeds();
|
||||
};
|
||||
|
||||
class FakeOT : public BaseOT
|
||||
{
|
||||
public:
|
||||
FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
|
||||
BaseOT(nOT, ot_length, player, role) {}
|
||||
void exec_base(bool new_receiver_inputs=true);
|
||||
};
|
||||
|
||||
#endif
|
||||
646
OT/BitMatrix.cpp
Normal file
646
OT/BitMatrix.cpp
Normal file
@@ -0,0 +1,646 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* BitMatrix.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <smmintrin.h>
|
||||
#include <immintrin.h>
|
||||
#include <mpirxx.h>
|
||||
|
||||
#include "BitMatrix.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
union matrix16x8
|
||||
{
|
||||
__m128i whole;
|
||||
octet rows[16];
|
||||
|
||||
bool get_bit(int x, int y)
|
||||
{ return (rows[x] >> y) & 1; }
|
||||
|
||||
void input(square128& input, int x, int y);
|
||||
void transpose(square128& output, int x, int y);
|
||||
};
|
||||
|
||||
class square16
|
||||
{
|
||||
public:
|
||||
// 16x16 in two halves, 128 bits each
|
||||
matrix16x8 halves[2];
|
||||
|
||||
bool get_bit(int x, int y)
|
||||
{ return halves[y/8].get_bit(x, y % 8); }
|
||||
|
||||
void input(square128& output, int x, int y);
|
||||
void transpose(square128& output, int x, int y);
|
||||
|
||||
void check_transpose(square16& dual);
|
||||
void print();
|
||||
};
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void matrix16x8::input(square128& input, int x, int y)
|
||||
{
|
||||
for (int l = 0; l < 16; l++)
|
||||
rows[l] = input.bytes[16*x+l][y];
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void square16::input(square128& input, int x, int y)
|
||||
{
|
||||
for (int i = 0; i < 2; i++)
|
||||
halves[i].input(input, x, 2 * y + i);
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void matrix16x8::transpose(square128& output, int x, int y)
|
||||
{
|
||||
for (int j = 0; j < 8; j++)
|
||||
{
|
||||
int row = _mm_movemask_epi8(whole);
|
||||
whole = _mm_slli_epi64(whole, 1);
|
||||
|
||||
// _mm_movemask_epi8 uses most significant bit, hence +7-j
|
||||
output.doublebytes[8*x+7-j][y] = row;
|
||||
}
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void square16::transpose(square128& output, int x, int y)
|
||||
{
|
||||
for (int i = 0; i < 2; i++)
|
||||
halves[i].transpose(output, 2 * x + i, y);
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
union matrix32x8
|
||||
{
|
||||
__m256i whole;
|
||||
octet rows[32];
|
||||
|
||||
void input(square128& input, int x, int y);
|
||||
void transpose(square128& output, int x, int y);
|
||||
};
|
||||
|
||||
class square32
|
||||
{
|
||||
public:
|
||||
matrix32x8 quarters[4];
|
||||
|
||||
void input(square128& input, int x, int y);
|
||||
void transpose(square128& output, int x, int y);
|
||||
};
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void matrix32x8::input(square128& input, int x, int y)
|
||||
{
|
||||
for (int l = 0; l < 32; l++)
|
||||
rows[l] = input.bytes[32*x+l][y];
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void square32::input(square128& input, int x, int y)
|
||||
{
|
||||
for (int i = 0; i < 4; i++)
|
||||
quarters[i].input(input, x, 4 * y + i);
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void matrix32x8::transpose(square128& output, int x, int y)
|
||||
{
|
||||
for (int j = 0; j < 8; j++)
|
||||
{
|
||||
int row = _mm256_movemask_epi8(whole);
|
||||
whole = _mm256_slli_epi64(whole, 1);
|
||||
|
||||
// _mm_movemask_epi8 uses most significant bit, hence +7-j
|
||||
output.words[8*x+7-j][y] = row;
|
||||
}
|
||||
}
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
inline void square32::transpose(square128& output, int x, int y)
|
||||
{
|
||||
for (int i = 0; i < 4; i++)
|
||||
quarters[i].transpose(output, 4 * x + i, y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX2__
|
||||
#warning Using AVX2 for transpose
|
||||
typedef square32 subsquare;
|
||||
#define N_SUBSQUARES 4
|
||||
#else
|
||||
typedef square16 subsquare;
|
||||
#define N_SUBSQUARES 8
|
||||
#endif
|
||||
|
||||
__attribute__((optimize("unroll-loops")))
|
||||
void square128::transpose()
|
||||
{
|
||||
for (int j = 0; j < N_SUBSQUARES; j++)
|
||||
for (int k = 0; k < j; k++)
|
||||
{
|
||||
subsquare a, b;
|
||||
a.input(*this, k, j);
|
||||
b.input(*this, j, k);
|
||||
a.transpose(*this, j, k);
|
||||
b.transpose(*this, k, j);
|
||||
}
|
||||
|
||||
for (int j = 0; j < N_SUBSQUARES; j++)
|
||||
{
|
||||
subsquare a;
|
||||
a.input(*this, j, j);
|
||||
a.transpose(*this, j, j);
|
||||
}
|
||||
}
|
||||
|
||||
void square128::randomize(PRNG& G)
|
||||
{
|
||||
G.get_octets((octet*)&rows, sizeof(rows));
|
||||
}
|
||||
|
||||
template <>
|
||||
void square128::randomize<gf2n>(int row, PRNG& G)
|
||||
{
|
||||
rows[row] = G.get_doubleword();
|
||||
}
|
||||
|
||||
template <>
|
||||
void square128::randomize<gfp>(int row, PRNG& G)
|
||||
{
|
||||
rows[row] = gfp::get_ZpD().get_random128(G);
|
||||
}
|
||||
|
||||
|
||||
void gfp_iadd(__m128i& a, __m128i& b)
|
||||
{
|
||||
gfp::get_ZpD().Add((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b);
|
||||
}
|
||||
|
||||
void gfp_isub(__m128i& a, __m128i& b)
|
||||
{
|
||||
gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&a, (mp_limb_t*)&b);
|
||||
}
|
||||
|
||||
void gfp_irsub(__m128i& a, __m128i& b)
|
||||
{
|
||||
gfp::get_ZpD().Sub((mp_limb_t*)&a, (mp_limb_t*)&b, (mp_limb_t*)&a);
|
||||
}
|
||||
|
||||
template<>
|
||||
void square128::conditional_add<gf2n>(BitVector& conditions, square128& other, int offset)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
if (conditions.get_bit(128 * offset + i))
|
||||
rows[i] ^= other.rows[i];
|
||||
}
|
||||
|
||||
template<>
|
||||
void square128::conditional_add<gfp>(BitVector& conditions, square128& other, int offset)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
if (conditions.get_bit(128 * offset + i))
|
||||
gfp_iadd(rows[i], other.rows[i]);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void square128::hash_row_wise(MMO& mmo, square128& input)
|
||||
{
|
||||
mmo.hashBlockWise<T,128>((octet*)rows, (octet*)input.rows);
|
||||
}
|
||||
|
||||
template <>
|
||||
void square128::to(gf2n_long& result)
|
||||
{
|
||||
int128 high, low;
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
low ^= int128(rows[i]) << i;
|
||||
high ^= int128(rows[i]) >> (128 - i);
|
||||
}
|
||||
result.reduce(high, low);
|
||||
}
|
||||
|
||||
template <>
|
||||
void square128::to(gfp& result)
|
||||
{
|
||||
mp_limb_t product[4], sum[4], tmp[2][4];
|
||||
memset(tmp, 0, sizeof(tmp));
|
||||
memset(sum, 0, sizeof(sum));
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i]));
|
||||
mpn_lshift(product, tmp[i/64], 4, i % 64);
|
||||
mpn_add_n(sum, product, sum, 4);
|
||||
}
|
||||
mp_limb_t q[4], ans[4];
|
||||
mpn_tdiv_qr(q, ans, 0, sum, 4, gfp::get_ZpD().get_prA(), 2);
|
||||
result = *(__m128i*)ans;
|
||||
}
|
||||
|
||||
void square128::check_transpose(square128& dual, int i, int k)
|
||||
{
|
||||
for (int j = 0; j < 16; j++)
|
||||
for (int l = 0; l < 16; l++)
|
||||
if (get_bit(16 * i + j, 16 * k + l) != dual.get_bit(16 * k + l, 16 * i + j))
|
||||
{
|
||||
cout << "Error in 16x16 square (" << i << "," << k << ")" << endl;
|
||||
print(i, k);
|
||||
dual.print(i, k);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
void square16::print()
|
||||
{
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
for (int j = 0; j < 8; j++)
|
||||
{
|
||||
for (int k = 0; k < 2; k++)
|
||||
{
|
||||
for (int l = 0; l < 8; l++)
|
||||
cout << halves[k].get_bit(8 * i + j, l);
|
||||
cout << " ";
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void square128::print(int i, int k)
|
||||
{
|
||||
square16 a;
|
||||
a.input(*this, i, k);
|
||||
a.print();
|
||||
}
|
||||
|
||||
void square128::print()
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
for (int j = 0; j < 128; j++)
|
||||
cout << get_bit(i, j);
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void square128::set_zero()
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
rows[i] = _mm_setzero_si128();
|
||||
}
|
||||
|
||||
square128& square128::operator^=(square128& other)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
rows[i] ^= other.rows[i];
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::add<gf2n>(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::add<gfp>(square128& other)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
gfp_iadd(rows[i], other.rows[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::sub<gf2n>(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::sub<gfp>(square128& other)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
gfp_isub(rows[i], other.rows[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::rsub<gf2n>(square128& other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template<>
|
||||
square128& square128::rsub<gfp>(square128& other)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
gfp_irsub(rows[i], other.rows[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
square128& square128::operator^=(__m128i* other)
|
||||
{
|
||||
__m128i value = _mm_loadu_si128(other);
|
||||
for (int i = 0; i < 128; i++)
|
||||
rows[i] ^= value;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <>
|
||||
square128& square128::sub<gf2n>(__m128i* other)
|
||||
{
|
||||
return *this ^= other;
|
||||
}
|
||||
|
||||
template <>
|
||||
square128& square128::sub<gfp>(__m128i* other)
|
||||
{
|
||||
__m128i value = _mm_loadu_si128(other);
|
||||
for (int i = 0; i < 128; i++)
|
||||
gfp_isub(rows[i], value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
square128& square128::operator^=(BitVector& other)
|
||||
{
|
||||
return *this ^= (__m128i*)other.get_ptr();
|
||||
}
|
||||
|
||||
bool square128::operator==(square128& other)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
__m128i tmp = rows[i] ^ other.rows[i];
|
||||
if (not _mm_test_all_zeros(tmp, tmp))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void square128::pack(octetStream& o) const
|
||||
{
|
||||
o.append((octet*)this->bytes, sizeof(bytes));
|
||||
}
|
||||
|
||||
void square128::unpack(octetStream &o)
|
||||
{
|
||||
o.consume((octet*)this->bytes, sizeof(bytes));
|
||||
}
|
||||
|
||||
|
||||
BitMatrix::BitMatrix(int length)
|
||||
{
|
||||
resize(length);
|
||||
}
|
||||
|
||||
void BitMatrix::resize(int length)
|
||||
{
|
||||
if (length % 128 != 0)
|
||||
throw invalid_length();
|
||||
squares.resize(length / 128);
|
||||
}
|
||||
|
||||
int BitMatrix::size()
|
||||
{
|
||||
return squares.size() * 128;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrix& BitMatrix::add(BitMatrix& other)
|
||||
{
|
||||
if (squares.size() != other.squares.size())
|
||||
throw invalid_length();
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].add<T>(other.squares[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrix& BitMatrix::sub(BitMatrix& other)
|
||||
{
|
||||
if (squares.size() != other.squares.size())
|
||||
throw invalid_length();
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].sub<T>(other.squares[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrix& BitMatrix::rsub(BitMatrixSlice& other)
|
||||
{
|
||||
if (squares.size() < other.end)
|
||||
throw invalid_length();
|
||||
for (size_t i = other.start; i < other.end; i++)
|
||||
squares[i].rsub<T>(other.bm.squares[i]);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrix& BitMatrix::sub(BitVector& other)
|
||||
{
|
||||
if (squares.size() * 128 != other.size())
|
||||
throw invalid_length();
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].sub<T>((__m128i*)other.get_ptr() + i);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool BitMatrix::operator==(BitMatrix& other)
|
||||
{
|
||||
if (squares.size() != other.squares.size())
|
||||
throw invalid_length();
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
if (not(squares[i] == other.squares[i]))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BitMatrix::operator!=(BitMatrix& other)
|
||||
{
|
||||
return not (*this == other);
|
||||
}
|
||||
|
||||
void BitMatrix::randomize(PRNG& G)
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].randomize(G);
|
||||
}
|
||||
|
||||
void BitMatrix::randomize(int row, PRNG& G)
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].randomize<gf2n>(row, G);
|
||||
}
|
||||
|
||||
void BitMatrix::transpose()
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].transpose();
|
||||
}
|
||||
|
||||
void BitMatrix::check_transpose(BitMatrix& dual)
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
{
|
||||
for (int j = 0; j < 128; j++)
|
||||
for (int k = 0; k < 128; k++)
|
||||
if (squares[i].get_bit(j, k) != dual.squares[i].get_bit(k, j))
|
||||
{
|
||||
cout << "First error in square " << i << " row " << j
|
||||
<< " column " << k << endl;
|
||||
squares[i].print(i / 8, j / 8);
|
||||
dual.squares[i].print(i / 8, j / 8);
|
||||
return;
|
||||
}
|
||||
}
|
||||
cout << "No errors in transpose" << endl;
|
||||
}
|
||||
|
||||
void BitMatrix::print_side_by_side(BitMatrix& other)
|
||||
{
|
||||
for (int i = 0; i < 32; i++)
|
||||
{
|
||||
for (int j = 0; j < 64; j++)
|
||||
cout << squares[0].get_bit(i,j);
|
||||
cout << " ";
|
||||
for (int j = 0; j < 64; j++)
|
||||
cout << other.squares[0].get_bit(i,j);
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void BitMatrix::print_conditional(BitVector& conditions)
|
||||
{
|
||||
for (int i = 0; i < 32; i++)
|
||||
{
|
||||
if (conditions.get_bit(i))
|
||||
for (int j = 0; j < 65; j++)
|
||||
cout << " ";
|
||||
for (int j = 0; j < 64; j++)
|
||||
cout << squares[0].get_bit(i,j);
|
||||
if (!conditions.get_bit(i))
|
||||
for (int j = 0; j < 65; j++)
|
||||
cout << " ";
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void BitMatrix::pack(octetStream& os) const
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].pack(os);
|
||||
}
|
||||
|
||||
void BitMatrix::unpack(octetStream& os)
|
||||
{
|
||||
for (size_t i = 0; i < squares.size(); i++)
|
||||
squares[i].unpack(os);
|
||||
}
|
||||
|
||||
void BitMatrix::to(vector<BitVector>& output)
|
||||
{
|
||||
output.resize(128);
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
output[i].resize(128 * squares.size());
|
||||
for (size_t j = 0; j < squares.size(); j++)
|
||||
output[i].set_int128(j, squares[j].rows[i]);
|
||||
}
|
||||
}
|
||||
|
||||
BitMatrixSlice::BitMatrixSlice(BitMatrix& bm, size_t start, size_t size) :
|
||||
bm(bm), start(start), size(size)
|
||||
{
|
||||
end = start + size;
|
||||
if (end > bm.squares.size())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Matrix slice (" << start << "," << end << ") larger than matrix (" << bm.squares.size() << ")";
|
||||
throw invalid_argument(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrixSlice& BitMatrixSlice::rsub(BitMatrixSlice& other)
|
||||
{
|
||||
bm.rsub<T>(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
BitMatrixSlice& BitMatrixSlice::add(BitVector& other, int repeat)
|
||||
{
|
||||
if (end * 128 > other.size() * repeat)
|
||||
throw invalid_length();
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].sub<T>((__m128i*)other.get_ptr() + i / repeat);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void BitMatrixSlice::randomize(int row, PRNG& G)
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].randomize<T>(row, G);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void BitMatrixSlice::conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset)
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].conditional_add<T>(conditions, other.squares[i], useOffset * i);
|
||||
}
|
||||
|
||||
void BitMatrixSlice::transpose()
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].transpose();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void BitMatrixSlice::print()
|
||||
{
|
||||
cout << "hex / value" << endl;
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
void BitMatrixSlice::pack(octetStream& os) const
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].pack(os);
|
||||
}
|
||||
|
||||
void BitMatrixSlice::unpack(octetStream& os)
|
||||
{
|
||||
for (size_t i = start; i < end; i++)
|
||||
bm.squares[i].unpack(os);
|
||||
}
|
||||
|
||||
template void BitMatrixSlice::conditional_add<gf2n>(BitVector& conditions, BitMatrix& other, bool useOffset);
|
||||
template void BitMatrixSlice::conditional_add<gfp>(BitVector& conditions, BitMatrix& other, bool useOffset);
|
||||
template BitMatrixSlice& BitMatrixSlice::rsub<gf2n>(BitMatrixSlice& other);
|
||||
template BitMatrixSlice& BitMatrixSlice::rsub<gfp>(BitMatrixSlice& other);
|
||||
template BitMatrixSlice& BitMatrixSlice::add<gf2n>(BitVector& other, int repeat);
|
||||
template BitMatrixSlice& BitMatrixSlice::add<gfp>(BitVector& other, int repeat);
|
||||
template BitMatrix& BitMatrix::add<gf2n>(BitMatrix& other);
|
||||
template BitMatrix& BitMatrix::add<gfp>(BitMatrix& other);
|
||||
template BitMatrix& BitMatrix::sub<gf2n>(BitMatrix& other);
|
||||
template BitMatrix& BitMatrix::sub<gfp>(BitMatrix& other);
|
||||
template void BitMatrixSlice::print<gf2n>();
|
||||
template void BitMatrixSlice::print<gfp>();
|
||||
template void BitMatrixSlice::randomize<gf2n>(int row, PRNG& G);
|
||||
template void BitMatrixSlice::randomize<gfp>(int row, PRNG& G);
|
||||
template void square128::hash_row_wise<gf2n>(MMO& mmo, square128& input);
|
||||
template void square128::hash_row_wise<gfp>(MMO& mmo, square128& input);
|
||||
136
OT/BitMatrix.h
Normal file
136
OT/BitMatrix.h
Normal file
@@ -0,0 +1,136 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* BitMatrix.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_BITMATRIX_H_
|
||||
#define OT_BITMATRIX_H_
|
||||
|
||||
#include <vector>
|
||||
#include <emmintrin.h>
|
||||
|
||||
#include "BitVector.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/MMO.h"
|
||||
#include "Math/gf2nlong.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
union square128 {
|
||||
__m128i rows[128];
|
||||
octet bytes[128][16];
|
||||
int16_t doublebytes[128][8];
|
||||
int32_t words[128][4];
|
||||
|
||||
bool get_bit(int x, int y)
|
||||
{ return (bytes[x][y/8] >> (y % 8)) & 1; }
|
||||
|
||||
void set_zero();
|
||||
|
||||
square128& operator^=(square128& other);
|
||||
square128& operator^=(__m128i* other);
|
||||
square128& operator^=(BitVector& other);
|
||||
bool operator==(square128& other);
|
||||
|
||||
template <class T>
|
||||
square128& add(square128& other);
|
||||
template <class T>
|
||||
square128& sub(square128& other);
|
||||
template <class T>
|
||||
square128& rsub(square128& other);
|
||||
template <class T>
|
||||
square128& sub(__m128i* other);
|
||||
|
||||
void randomize(PRNG& G);
|
||||
template <class T>
|
||||
void randomize(int row, PRNG& G);
|
||||
template <class T>
|
||||
void conditional_add(BitVector& conditions, square128& other, int offset);
|
||||
void transpose();
|
||||
template <class T>
|
||||
void hash_row_wise(MMO& mmo, square128& input);
|
||||
template <class T>
|
||||
void to(T& result);
|
||||
|
||||
void check_transpose(square128& dual, int i, int k);
|
||||
void print(int i, int k);
|
||||
void print();
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
};
|
||||
|
||||
class BitMatrixSlice;
|
||||
|
||||
class BitMatrix
|
||||
{
|
||||
public:
|
||||
vector<square128> squares;
|
||||
|
||||
BitMatrix() {}
|
||||
BitMatrix(int length);
|
||||
void resize(int length);
|
||||
int size();
|
||||
|
||||
template <class T>
|
||||
BitMatrix& add(BitMatrix& other);
|
||||
template <class T>
|
||||
BitMatrix& sub(BitMatrix& other);
|
||||
template <class T>
|
||||
BitMatrix& rsub(BitMatrixSlice& other);
|
||||
template <class T>
|
||||
BitMatrix& sub(BitVector& other);
|
||||
bool operator==(BitMatrix& other);
|
||||
bool operator!=(BitMatrix& other);
|
||||
|
||||
void randomize(PRNG& G);
|
||||
void randomize(int row, PRNG& G);
|
||||
void transpose();
|
||||
|
||||
void check_transpose(BitMatrix& dual);
|
||||
void print_side_by_side(BitMatrix& other);
|
||||
void print_conditional(BitVector& conditions);
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
void to(vector<BitVector>& output);
|
||||
};
|
||||
|
||||
class BitMatrixSlice
|
||||
{
|
||||
friend class BitMatrix;
|
||||
|
||||
BitMatrix& bm;
|
||||
size_t start, size, end;
|
||||
|
||||
public:
|
||||
BitMatrixSlice(BitMatrix& bm, size_t start, size_t size);
|
||||
|
||||
template <class T>
|
||||
BitMatrixSlice& rsub(BitMatrixSlice& other);
|
||||
template <class T>
|
||||
BitMatrixSlice& add(BitVector& other, int repeat = 1);
|
||||
|
||||
template <class T>
|
||||
void randomize(int row, PRNG& G);
|
||||
template <class T>
|
||||
void conditional_add(BitVector& conditions, BitMatrix& other, bool useOffset = false);
|
||||
void transpose();
|
||||
|
||||
template <class T>
|
||||
void print();
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
};
|
||||
|
||||
#endif /* OT_BITMATRIX_H_ */
|
||||
107
OT/BitVector.cpp
Normal file
107
OT/BitVector.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "OT/BitVector.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
void BitVector::randomize(PRNG& G)
|
||||
{
|
||||
G.get_octets(bytes, nbytes);
|
||||
}
|
||||
|
||||
template<>
|
||||
void BitVector::randomize_blocks<gf2n>(PRNG& G)
|
||||
{
|
||||
randomize(G);
|
||||
}
|
||||
|
||||
template<>
|
||||
void BitVector::randomize_blocks<gfp>(PRNG& G)
|
||||
{
|
||||
gfp tmp;
|
||||
for (size_t i = 0; i < (nbits / 128); i++)
|
||||
{
|
||||
tmp.randomize(G);
|
||||
for (int j = 0; j < 2; j++)
|
||||
((mp_limb_t*)bytes)[2*i+j] = tmp.get().get_limb(j);
|
||||
}
|
||||
}
|
||||
|
||||
void BitVector::randomize_at(int a, int nb, PRNG& G)
|
||||
{
|
||||
if (nb < 1)
|
||||
throw invalid_length();
|
||||
G.get_octets(bytes + a, nb);
|
||||
}
|
||||
|
||||
/*
|
||||
*/
|
||||
|
||||
void BitVector::output(ostream& s,bool human) const
|
||||
{
|
||||
if (human)
|
||||
{
|
||||
s << nbits << " " << hex;
|
||||
for (unsigned int i = 0; i < nbytes; i++)
|
||||
{
|
||||
s << int(bytes[i]) << " ";
|
||||
}
|
||||
s << dec << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
int len = nbits;
|
||||
s.write((char*) &len, sizeof(int));
|
||||
s.write((char*) bytes, nbytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void BitVector::input(istream& s,bool human)
|
||||
{
|
||||
if (s.peek() == EOF)
|
||||
{
|
||||
if (s.tellg() == 0)
|
||||
{
|
||||
cout << "IO problem. Empty file?" << endl;
|
||||
throw file_error();
|
||||
}
|
||||
throw end_of_file();
|
||||
}
|
||||
int len;
|
||||
if (human)
|
||||
{
|
||||
s >> len >> hex;
|
||||
resize(len);
|
||||
for (size_t i = 0; i < nbytes; i++)
|
||||
{
|
||||
s >> bytes[i];
|
||||
}
|
||||
s >> dec;
|
||||
}
|
||||
else
|
||||
{
|
||||
s.read((char*) &len, sizeof(int));
|
||||
resize(len);
|
||||
s.read((char*) bytes, nbytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void BitVector::pack(octetStream& o) const
|
||||
{
|
||||
o.append((octet*)bytes, nbytes);
|
||||
}
|
||||
|
||||
|
||||
void BitVector::unpack(octetStream& o)
|
||||
{
|
||||
o.consume((octet*)bytes, nbytes);
|
||||
}
|
||||
|
||||
|
||||
212
OT/BitVector.h
Normal file
212
OT/BitVector.h
Normal file
@@ -0,0 +1,212 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _BITVECTOR
|
||||
#define _BITVECTOR
|
||||
|
||||
/* Vector of bits */
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
#include <stdlib.h>
|
||||
#include <pmmintrin.h>
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Networking/data.h"
|
||||
// just for util functions
|
||||
#include "Math/bigint.h"
|
||||
#include "Math/gf2nlong.h"
|
||||
|
||||
class PRNG;
|
||||
class octetStream;
|
||||
|
||||
|
||||
class BitVector
|
||||
{
|
||||
octet* bytes;
|
||||
|
||||
size_t nbytes;
|
||||
size_t nbits;
|
||||
size_t length;
|
||||
|
||||
public:
|
||||
|
||||
void assign(const BitVector& K)
|
||||
{
|
||||
if (nbits != K.nbits)
|
||||
{
|
||||
resize(K.nbits);
|
||||
}
|
||||
memcpy(bytes, K.bytes, nbytes);
|
||||
}
|
||||
void assign_bytes(char* new_bytes, int len)
|
||||
{
|
||||
resize(len*8);
|
||||
memcpy(bytes, new_bytes, len);
|
||||
}
|
||||
void assign_zero()
|
||||
{
|
||||
memset(bytes, 0, nbytes);
|
||||
}
|
||||
// only grows, never destroys
|
||||
void resize(size_t new_nbits)
|
||||
{
|
||||
if (nbits != new_nbits)
|
||||
{
|
||||
int new_nbytes = DIV_CEIL(new_nbits,8);
|
||||
|
||||
if (nbits < new_nbits)
|
||||
{
|
||||
octet* tmp = new octet[new_nbytes];
|
||||
memcpy(tmp, bytes, nbytes);
|
||||
delete[] bytes;
|
||||
bytes = tmp;
|
||||
}
|
||||
|
||||
nbits = new_nbits;
|
||||
nbytes = new_nbytes;
|
||||
/*
|
||||
// use realloc to preserve original contents
|
||||
if (new_nbits < nbits)
|
||||
{
|
||||
memcpy(tmp, bytes, new_nbytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
memset(tmp, 0, new_nbytes);
|
||||
memcpy(tmp, bytes, nbytes);
|
||||
}*/
|
||||
|
||||
// realloc may fail on size 0
|
||||
/*if (new_nbits == 0)
|
||||
{
|
||||
free(bytes);
|
||||
bytes = (octet*) malloc(0);//new octet[0];
|
||||
//free(bytes);
|
||||
return;
|
||||
}
|
||||
bytes = (octet*)realloc(bytes, nbytes);
|
||||
if (bytes == NULL)
|
||||
{
|
||||
cerr << "realloc failed\n";
|
||||
exit(1);
|
||||
}*/
|
||||
/*delete[] bytes;
|
||||
nbits = new_nbits;
|
||||
nbytes = DIV_CEIL(nbits, 8);
|
||||
bytes = new octet[nbytes];*/
|
||||
}
|
||||
}
|
||||
unsigned int size() const { return nbits; }
|
||||
unsigned int size_bytes() const { return nbytes; }
|
||||
octet* get_ptr() { return bytes; }
|
||||
|
||||
BitVector(size_t n=128)
|
||||
{
|
||||
nbits = n;
|
||||
nbytes = DIV_CEIL(nbits, 8);
|
||||
bytes = new octet[nbytes];
|
||||
length = n;
|
||||
assign_zero();
|
||||
}
|
||||
BitVector(const BitVector& K)
|
||||
{
|
||||
bytes = new octet[K.nbytes];
|
||||
nbytes = K.nbytes;
|
||||
nbits = K.nbits;
|
||||
assign(K);
|
||||
}
|
||||
~BitVector() {
|
||||
//cout << "Destroy, size = " << nbytes << endl;
|
||||
delete[] bytes;
|
||||
}
|
||||
BitVector& operator=(const BitVector& K)
|
||||
{
|
||||
if (this!=&K) { assign(K); }
|
||||
return *this;
|
||||
}
|
||||
|
||||
octet get_byte(int i) const { return bytes[i]; }
|
||||
|
||||
void set_byte(int i, octet b) { bytes[i] = b; }
|
||||
|
||||
// get the i-th 64-bit word
|
||||
word get_word(int i) const { return *(word*)(bytes + i*8); }
|
||||
|
||||
void set_word(int i, word w)
|
||||
{
|
||||
int offset = i * sizeof(word);
|
||||
memcpy(bytes + offset, (octet*)&w, sizeof(word));
|
||||
}
|
||||
|
||||
int128 get_int128(int i) const { return _mm_lddqu_si128((__m128i*)bytes + i); }
|
||||
void set_int128(int i, int128 a) { *((__m128i*)bytes + i) = a.a; }
|
||||
|
||||
int get_bit(int i) const
|
||||
{
|
||||
return (bytes[i/8] >> (i % 8)) & 1;
|
||||
}
|
||||
void set_bit(int i,unsigned int a)
|
||||
{
|
||||
int j = i/8, k = i&7;
|
||||
if (a==1)
|
||||
{ bytes[j] |= (octet)(1UL<<k); }
|
||||
else
|
||||
{ bytes[j] &= (octet)~(1UL<<k); }
|
||||
}
|
||||
|
||||
void add(const BitVector& A, const BitVector& B)
|
||||
{
|
||||
if (A.nbits != B.nbits)
|
||||
{ throw invalid_length(); }
|
||||
resize(A.nbits);
|
||||
for (unsigned int i=0; i < nbytes; i++)
|
||||
{
|
||||
bytes[i] = A.bytes[i] ^ B.bytes[i];
|
||||
}
|
||||
}
|
||||
|
||||
void add(const BitVector& A)
|
||||
{
|
||||
if (nbits != A.nbits)
|
||||
{ throw invalid_length(); }
|
||||
for (unsigned int i = 0; i < nbytes; i++)
|
||||
{
|
||||
bytes[i] ^= A.bytes[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool equals(const BitVector& K) const
|
||||
{
|
||||
if (nbits != K.nbits)
|
||||
{ throw invalid_length(); }
|
||||
for (unsigned int i = 0; i < nbytes; i++)
|
||||
{ if (bytes[i] != K.bytes[i]) { return false; } }
|
||||
return true;
|
||||
}
|
||||
|
||||
void randomize(PRNG& G);
|
||||
template <class T>
|
||||
void randomize_blocks(PRNG& G);
|
||||
// randomize bytes a, ..., a+nb-1
|
||||
void randomize_at(int a, int nb, PRNG& G);
|
||||
|
||||
void output(ostream& s,bool human) const;
|
||||
void input(istream& s,bool human);
|
||||
|
||||
// Pack and unpack in native format
|
||||
// i.e. Dont care about conversion to human readable form
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
string str()
|
||||
{
|
||||
stringstream ss;
|
||||
ss << hex;
|
||||
for(size_t i(0);i < nbytes;++i)
|
||||
ss << (int)bytes[i] << " ";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
555
OT/NPartyTripleGenerator.cpp
Normal file
555
OT/NPartyTripleGenerator.cpp
Normal file
@@ -0,0 +1,555 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "NPartyTripleGenerator.h"
|
||||
|
||||
#include "OT/OTExtensionWithMatrix.h"
|
||||
#include "OT/OTMultiplier.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Share.h"
|
||||
#include "Math/operators.h"
|
||||
#include "Auth/Subroutines.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <math.h>
|
||||
|
||||
template <class T, int N>
|
||||
class Triple
|
||||
{
|
||||
public:
|
||||
T a[N];
|
||||
T b;
|
||||
T c[N];
|
||||
|
||||
int repeat(int l)
|
||||
{
|
||||
switch (l)
|
||||
{
|
||||
case 0:
|
||||
case 2:
|
||||
return N;
|
||||
case 1:
|
||||
return 1;
|
||||
default:
|
||||
throw bad_value();
|
||||
}
|
||||
}
|
||||
|
||||
T& byIndex(int l, int j)
|
||||
{
|
||||
switch (l)
|
||||
{
|
||||
case 0:
|
||||
return a[j];
|
||||
case 1:
|
||||
return b;
|
||||
case 2:
|
||||
return c[j];
|
||||
default:
|
||||
throw bad_value();
|
||||
}
|
||||
}
|
||||
|
||||
template <int M>
|
||||
void amplify(const Triple<T,M>& uncheckedTriple, PRNG& G)
|
||||
{
|
||||
b = uncheckedTriple.b;
|
||||
for (int i = 0; i < N; i++)
|
||||
for (int j = 0; j < M; j++)
|
||||
{
|
||||
typename T::value_type r;
|
||||
r.randomize(G);
|
||||
a[i] += r * uncheckedTriple.a[j];
|
||||
c[i] += r * uncheckedTriple.c[j];
|
||||
}
|
||||
}
|
||||
|
||||
void output(ostream& outputStream, int n = N, bool human = false)
|
||||
{
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
a[i].output(outputStream, human);
|
||||
b.output(outputStream, human);
|
||||
c[i].output(outputStream, human);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, int N>
|
||||
class PlainTriple : public Triple<T,N>
|
||||
{
|
||||
public:
|
||||
// this assumes that valueBits[1] is still set to the bits of b
|
||||
void to(vector<BitVector>& valueBits, int i)
|
||||
{
|
||||
for (int j = 0; j < N; j++)
|
||||
{
|
||||
valueBits[0].set_int128(i * N + j, this->a[j].to_m128i());
|
||||
valueBits[2].set_int128(i * N + j, this->c[j].to_m128i());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, int N>
|
||||
class ShareTriple : public Triple<Share<T>, N>
|
||||
{
|
||||
public:
|
||||
void from(PlainTriple<T,N>& triple, vector<OTMultiplier<T>*>& ot_multipliers,
|
||||
int iTriple, const NPartyTripleGenerator& generator)
|
||||
{
|
||||
for (int l = 0; l < 3; l++)
|
||||
{
|
||||
int repeat = this->repeat(l);
|
||||
for (int j = 0; j < repeat; j++)
|
||||
{
|
||||
T value = triple.byIndex(l,j);
|
||||
T mac = value * generator.machine.get_mac_key<T>();
|
||||
for (int i = 0; i < generator.nparties-1; i++)
|
||||
mac += ot_multipliers[i]->macs[l][iTriple * repeat + j];
|
||||
Share<T>& share = this->byIndex(l,j);
|
||||
share.set_share(value);
|
||||
share.set_mac(mac);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
T computeCheckMAC(const T& maskedA)
|
||||
{
|
||||
return this->c[0].get_mac() - maskedA * this->b.get_mac();
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* Copies the relevant base OTs from setup
|
||||
* N.B. setup must not be stored as it will be used by other threads
|
||||
*/
|
||||
NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
TripleMachine& machine) :
|
||||
globalPlayer(names, - thread_num * machine.nplayers * machine.nplayers),
|
||||
thread_num(thread_num),
|
||||
my_num(setup.get_my_num()),
|
||||
nloops(nloops),
|
||||
nparties(setup.get_nparties()),
|
||||
machine(machine)
|
||||
{
|
||||
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
|
||||
nTriples = nTriplesPerLoop * nloops;
|
||||
field_size = 128;
|
||||
nAmplify = machine.amplify ? N_AMPLIFY : 1;
|
||||
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
|
||||
|
||||
int n = nparties;
|
||||
//baseReceiverInput = machines[0]->baseReceiverInput;
|
||||
//baseSenderInputs.resize(n-1);
|
||||
//baseReceiverOutputs.resize(n-1);
|
||||
nbase = setup.get_nbase();
|
||||
baseReceiverInput.resize(nbase);
|
||||
baseReceiverOutputs.resize(n - 1);
|
||||
baseSenderInputs.resize(n - 1);
|
||||
players.resize(n-1);
|
||||
|
||||
gf2n_long::init_field(128);
|
||||
|
||||
for (int i = 0; i < n-1; i++)
|
||||
{
|
||||
// i for indexing, other_player is actual number
|
||||
int other_player, id;
|
||||
if (i >= my_num)
|
||||
other_player = i + 1;
|
||||
else
|
||||
other_player = i;
|
||||
|
||||
// copy base OT inputs + outputs
|
||||
for (int j = 0; j < 128; j++)
|
||||
{
|
||||
baseReceiverInput.set_bit(j, (unsigned int)setup.get_base_receiver_input(j));
|
||||
}
|
||||
baseReceiverOutputs[i] = setup.baseOTs[i]->receiver_outputs;
|
||||
baseSenderInputs[i] = setup.baseOTs[i]->sender_inputs;
|
||||
|
||||
// new TwoPartyPlayer with unique id for each thread + pair of players
|
||||
if (my_num < other_player)
|
||||
id = (thread_num+1)*n*n + my_num*n + other_player;
|
||||
else
|
||||
id = (thread_num+1)*n*n + other_player*n + my_num;
|
||||
players[i] = new TwoPartyPlayer(names, other_player, id);
|
||||
cout << "Set up with player " << other_player << " in thread " << thread_num << " with id " << id << endl;
|
||||
}
|
||||
|
||||
pthread_mutex_init(&mutex, 0);
|
||||
pthread_cond_init(&ready, 0);
|
||||
}
|
||||
|
||||
NPartyTripleGenerator::~NPartyTripleGenerator()
|
||||
{
|
||||
for (size_t i = 0; i < players.size(); i++)
|
||||
delete players[i];
|
||||
//delete nplayer;
|
||||
pthread_mutex_destroy(&mutex);
|
||||
pthread_cond_destroy(&ready);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void* run_ot_thread(void* ptr)
|
||||
{
|
||||
((OTMultiplier<T>*)ptr)->multiply();
|
||||
return NULL;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void NPartyTripleGenerator::generate()
|
||||
{
|
||||
vector< OTMultiplier<T>* > ot_multipliers(nparties-1);
|
||||
|
||||
timers["Generator thread"].start();
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
{
|
||||
ot_multipliers[i] = new OTMultiplier<T>(*this, i);
|
||||
pthread_mutex_lock(&ot_multipliers[i]->mutex);
|
||||
pthread_create(&(ot_multipliers[i]->thread), 0, run_ot_thread<T>, ot_multipliers[i]);
|
||||
}
|
||||
|
||||
// add up the shares from each thread and write to file
|
||||
stringstream ss;
|
||||
ss << machine.prep_data_dir;
|
||||
if (machine.generateBits)
|
||||
ss << "Bits-";
|
||||
else
|
||||
ss << "Triples-";
|
||||
ss << T::type_char() << "-P" << my_num;
|
||||
if (thread_num != 0)
|
||||
ss << "-" << thread_num;
|
||||
ofstream outputFile(ss.str().c_str());
|
||||
|
||||
if (machine.generateBits)
|
||||
generateBits(ot_multipliers, outputFile);
|
||||
else
|
||||
generateTriples(ot_multipliers, outputFile);
|
||||
|
||||
timers["Generator thread"].stop();
|
||||
if (machine.output)
|
||||
cout << "Written " << nTriples << " outputs to " << ss.str() << endl;
|
||||
else
|
||||
cout << "Generated " << nTriples << " outputs" << endl;
|
||||
|
||||
// wait for threads to finish
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
{
|
||||
pthread_mutex_unlock(&ot_multipliers[i]->mutex);
|
||||
pthread_join(ot_multipliers[i]->thread, NULL);
|
||||
cout << "OT thread " << i << " finished\n" << flush;
|
||||
}
|
||||
cout << "OT threads finished\n";
|
||||
|
||||
for (size_t i = 0; i < ot_multipliers.size(); i++)
|
||||
delete ot_multipliers[i];
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator::generateBits(vector< OTMultiplier<gf2n>* >& ot_multipliers,
|
||||
ofstream& outputFile)
|
||||
{
|
||||
PRNG share_prg;
|
||||
share_prg.ReSeed();
|
||||
|
||||
int nBitsToCheck = nTriplesPerLoop + field_size;
|
||||
valueBits.resize(1);
|
||||
valueBits[0].resize(ceil(1.0 * nBitsToCheck / field_size) * field_size);
|
||||
MAC_Check<gf2n> MC(machine.get_mac_key<gf2n>());
|
||||
vector< Share<gf2n> > bits(nBitsToCheck);
|
||||
vector< Share<gf2n> > to_open(1);
|
||||
vector<gf2n> opened(1);
|
||||
|
||||
start_progress(ot_multipliers);
|
||||
|
||||
for (int k = 0; k < nloops; k++)
|
||||
{
|
||||
print_progress(k);
|
||||
|
||||
valueBits[0].randomize_blocks<gf2n>(share_prg);
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_signal(&ot_multipliers[i]->ready);
|
||||
timers["Authentication OTs"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
|
||||
timers["Authentication OTs"].stop();
|
||||
|
||||
octet seed[SEED_SIZE];
|
||||
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
|
||||
PRNG G;
|
||||
G.SetSeed(seed);
|
||||
|
||||
Share<gf2n> check_sum;
|
||||
gf2n r;
|
||||
for (int j = 0; j < nBitsToCheck; j++)
|
||||
{
|
||||
gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key<gf2n>();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
mac_sum += ot_multipliers[i]->macs[0][j];
|
||||
bits[j].set_share(valueBits[0].get_bit(j));
|
||||
bits[j].set_mac(mac_sum);
|
||||
r.randomize(G);
|
||||
check_sum += r * bits[j];
|
||||
}
|
||||
|
||||
to_open[0] = check_sum;
|
||||
MC.POpen_Begin(opened, to_open, globalPlayer);
|
||||
MC.POpen_End(opened, to_open, globalPlayer);
|
||||
MC.Check(globalPlayer);
|
||||
|
||||
if (machine.output)
|
||||
for (int j = 0; j < nTriplesPerLoop; j++)
|
||||
bits[j].output(outputFile, false);
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_signal(&ot_multipliers[i]->ready);
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator::generateBits(vector< OTMultiplier<gfp>* >& ot_multipliers,
|
||||
ofstream& outputFile)
|
||||
{
|
||||
generateTriples(ot_multipliers, outputFile);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void NPartyTripleGenerator::generateTriples(vector< OTMultiplier<T>* >& ot_multipliers,
|
||||
ofstream& outputFile)
|
||||
{
|
||||
PRNG share_prg;
|
||||
share_prg.ReSeed();
|
||||
|
||||
valueBits.resize(3);
|
||||
for (int i = 0; i < 2; i++)
|
||||
valueBits[2*i].resize(field_size * nPreampTriplesPerLoop);
|
||||
valueBits[1].resize(field_size * nTriplesPerLoop);
|
||||
vector< PlainTriple<T,N_AMPLIFY> > preampTriples;
|
||||
vector< PlainTriple<T,2> > amplifiedTriples;
|
||||
vector< ShareTriple<T,2> > uncheckedTriples;
|
||||
MAC_Check<T> MC(machine.get_mac_key<T>());
|
||||
|
||||
if (machine.amplify)
|
||||
preampTriples.resize(nTriplesPerLoop);
|
||||
if (machine.generateMACs)
|
||||
{
|
||||
amplifiedTriples.resize(nTriplesPerLoop);
|
||||
uncheckedTriples.resize(nTriplesPerLoop);
|
||||
}
|
||||
|
||||
start_progress(ot_multipliers);
|
||||
|
||||
for (int k = 0; k < nloops; k++)
|
||||
{
|
||||
print_progress(k);
|
||||
|
||||
for (int j = 0; j < 2; j++)
|
||||
valueBits[j].randomize_blocks<T>(share_prg);
|
||||
|
||||
timers["OTs"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
|
||||
timers["OTs"].stop();
|
||||
|
||||
for (int j = 0; j < nPreampTriplesPerLoop; j++)
|
||||
{
|
||||
T a(valueBits[0].get_int128(j));
|
||||
T b(valueBits[1].get_int128(j / nAmplify));
|
||||
T c = a * b;
|
||||
timers["Triple computation"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
{
|
||||
c += ot_multipliers[i]->c_output[j];
|
||||
}
|
||||
timers["Triple computation"].stop();
|
||||
if (machine.amplify)
|
||||
{
|
||||
preampTriples[j/nAmplify].a[j%nAmplify] = a;
|
||||
preampTriples[j/nAmplify].b = b;
|
||||
preampTriples[j/nAmplify].c[j%nAmplify] = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
timers["Writing"].start();
|
||||
a.output(outputFile, false);
|
||||
b.output(outputFile, false);
|
||||
c.output(outputFile, false);
|
||||
timers["Writing"].stop();
|
||||
}
|
||||
}
|
||||
|
||||
if (machine.amplify)
|
||||
{
|
||||
octet seed[SEED_SIZE];
|
||||
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
|
||||
PRNG G;
|
||||
G.SetSeed(seed);
|
||||
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
|
||||
{
|
||||
PlainTriple<T,2> triple;
|
||||
triple.amplify(preampTriples[iTriple], G);
|
||||
|
||||
if (machine.generateMACs)
|
||||
amplifiedTriples[iTriple] = triple;
|
||||
else
|
||||
{
|
||||
timers["Writing"].start();
|
||||
triple.output(outputFile);
|
||||
timers["Writing"].stop();
|
||||
}
|
||||
}
|
||||
|
||||
if (machine.generateMACs)
|
||||
{
|
||||
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
|
||||
amplifiedTriples[iTriple].to(valueBits, iTriple);
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_signal(&ot_multipliers[i]->ready);
|
||||
timers["Authentication OTs"].start();
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
|
||||
timers["Authentication OTs"].stop();
|
||||
|
||||
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
|
||||
{
|
||||
uncheckedTriples[iTriple].from(amplifiedTriples[iTriple], ot_multipliers, iTriple, *this);
|
||||
|
||||
if (!machine.check)
|
||||
{
|
||||
timers["Writing"].start();
|
||||
amplifiedTriples[iTriple].output(outputFile);
|
||||
timers["Writing"].stop();
|
||||
}
|
||||
}
|
||||
|
||||
if (machine.check)
|
||||
{
|
||||
vector< Share<T> > maskedAs(nTriplesPerLoop);
|
||||
vector< ShareTriple<T,1> > maskedTriples(nTriplesPerLoop);
|
||||
for (int j = 0; j < nTriplesPerLoop; j++)
|
||||
{
|
||||
maskedTriples[j].amplify(uncheckedTriples[j], G);
|
||||
maskedAs[j] = maskedTriples[j].a[0];
|
||||
}
|
||||
|
||||
vector<T> openedAs(nTriplesPerLoop);
|
||||
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
|
||||
MC.POpen_End(openedAs, maskedAs, globalPlayer);
|
||||
|
||||
for (int j = 0; j < nTriplesPerLoop; j++)
|
||||
MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), int128(0), globalPlayer);
|
||||
|
||||
MC.Check(globalPlayer);
|
||||
|
||||
if (machine.generateBits)
|
||||
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
|
||||
else
|
||||
if (machine.output)
|
||||
for (int j = 0; j < nTriplesPerLoop; j++)
|
||||
uncheckedTriples[j].output(outputFile, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_signal(&ot_multipliers[i]->ready);
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator::generateBitsFromTriples(
|
||||
vector< ShareTriple<gfp,2> >& triples, MAC_Check<gfp>& MC, ofstream& outputFile)
|
||||
{
|
||||
vector< Share<gfp> > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop);
|
||||
for (int i = 0; i < nTriplesPerLoop; i++)
|
||||
a_plus_b[i] = triples[i].a[0] + triples[i].b;
|
||||
vector<gfp> opened(nTriplesPerLoop);
|
||||
MC.POpen_Begin(opened, a_plus_b, globalPlayer);
|
||||
MC.POpen_End(opened, a_plus_b, globalPlayer);
|
||||
for (int i = 0; i < nTriplesPerLoop; i++)
|
||||
a_squared[i] = triples[i].a[0] * opened[i] - triples[i].c[0];
|
||||
MC.POpen_Begin(opened, a_squared, globalPlayer);
|
||||
MC.POpen_End(opened, a_squared, globalPlayer);
|
||||
Share<gfp> one(gfp(1), globalPlayer.my_num(), MC.get_alphai());
|
||||
for (int i = 0; i < nTriplesPerLoop; i++)
|
||||
{
|
||||
gfp root = opened[i].sqrRoot();
|
||||
if (root.is_zero())
|
||||
continue;
|
||||
Share<gfp> bit = (triples[i].a[0] / root + one) / gfp(2);
|
||||
if (machine.output)
|
||||
bit.output(outputFile, false);
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void NPartyTripleGenerator::generateBitsFromTriples(
|
||||
vector< ShareTriple<gf2n,2> >& triples, MAC_Check<gf2n>& MC, ofstream& outputFile)
|
||||
{
|
||||
throw how_would_that_work();
|
||||
// warning gymnastics
|
||||
triples[0];
|
||||
MC.number();
|
||||
outputFile << "";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void NPartyTripleGenerator::start_progress(vector< OTMultiplier<T>* >& ot_multipliers)
|
||||
{
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
|
||||
lock();
|
||||
signal();
|
||||
wait();
|
||||
gettimeofday(&last_lap, 0);
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
pthread_cond_signal(&ot_multipliers[i]->ready);
|
||||
}
|
||||
|
||||
void NPartyTripleGenerator::print_progress(int k)
|
||||
{
|
||||
if (thread_num == 0 && my_num == 0)
|
||||
{
|
||||
struct timeval stop;
|
||||
gettimeofday(&stop, 0);
|
||||
if (timeval_diff_in_seconds(&last_lap, &stop) > 1)
|
||||
{
|
||||
double diff = timeval_diff_in_seconds(&machine.start, &stop);
|
||||
double throughput = k * nTriplesPerLoop * machine.nthreads / diff;
|
||||
double remaining = diff * (nloops - k) / k;
|
||||
cout << k << '/' << nloops << ", throughput: " << throughput
|
||||
<< ", time left: " << remaining << ", elapsed: " << diff
|
||||
<< ", estimated total: " << (diff + remaining) << endl;
|
||||
last_lap = stop;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NPartyTripleGenerator::lock()
|
||||
{
|
||||
pthread_mutex_lock(&mutex);
|
||||
}
|
||||
|
||||
void NPartyTripleGenerator::unlock()
|
||||
{
|
||||
pthread_mutex_unlock(&mutex);
|
||||
}
|
||||
|
||||
void NPartyTripleGenerator::signal()
|
||||
{
|
||||
pthread_cond_signal(&ready);
|
||||
}
|
||||
|
||||
void NPartyTripleGenerator::wait()
|
||||
{
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
}
|
||||
|
||||
|
||||
template void NPartyTripleGenerator::generate<gf2n>();
|
||||
template void NPartyTripleGenerator::generate<gfp>();
|
||||
84
OT/NPartyTripleGenerator.h
Normal file
84
OT/NPartyTripleGenerator.h
Normal file
@@ -0,0 +1,84 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef OT_NPARTYTRIPLEGENERATOR_H_
|
||||
#define OT_NPARTYTRIPLEGENERATOR_H_
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "OT/BaseOT.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "OT/TripleMachine.h"
|
||||
#include "OT/OTMultiplier.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#define N_AMPLIFY 3
|
||||
|
||||
template <class T, int N>
|
||||
class ShareTriple;
|
||||
|
||||
class NPartyTripleGenerator
|
||||
{
|
||||
//OTTripleSetup* setup;
|
||||
Player globalPlayer;
|
||||
|
||||
int thread_num;
|
||||
int my_num;
|
||||
int nbase;
|
||||
|
||||
struct timeval last_lap;
|
||||
|
||||
pthread_mutex_t mutex;
|
||||
pthread_cond_t ready;
|
||||
|
||||
template <class T>
|
||||
void generateTriples(vector< OTMultiplier<T>* >& ot_multipliers, ofstream& outputFile);
|
||||
template <class T>
|
||||
void generateBits(vector< OTMultiplier<T>* >& ot_multipliers, ofstream& outputFile);
|
||||
template <class T, int N>
|
||||
void generateBitsFromTriples(vector<ShareTriple<T, N> >& triples,
|
||||
MAC_Check<T>& MC, ofstream& outputFile);
|
||||
|
||||
template <class T>
|
||||
void start_progress(vector< OTMultiplier<T>* >& ot_multipliers);
|
||||
void print_progress(int k);
|
||||
|
||||
public:
|
||||
// TwoPartyPlayer's for OTs, n-party Player for sacrificing
|
||||
vector<TwoPartyPlayer*> players;
|
||||
//vector<OTMachine*> machines;
|
||||
BitVector baseReceiverInput; // same for every set of OTs
|
||||
vector< vector< vector<BitVector> > > baseSenderInputs;
|
||||
vector< vector<BitVector> > baseReceiverOutputs;
|
||||
vector<BitVector> valueBits;
|
||||
|
||||
int nTriples;
|
||||
int nTriplesPerLoop;
|
||||
int nloops;
|
||||
int field_size;
|
||||
int nAmplify;
|
||||
int nPreampTriplesPerLoop;
|
||||
int repeat[3];
|
||||
int nparties;
|
||||
|
||||
TripleMachine& machine;
|
||||
|
||||
map<string,Timer> timers;
|
||||
|
||||
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, TripleMachine& machine);
|
||||
~NPartyTripleGenerator();
|
||||
template <class T>
|
||||
void generate();
|
||||
|
||||
void lock();
|
||||
void unlock();
|
||||
void signal();
|
||||
void wait();
|
||||
};
|
||||
|
||||
#endif
|
||||
791
OT/OTExtension.cpp
Normal file
791
OT/OTExtension.cpp
Normal file
@@ -0,0 +1,791 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "OTExtension.h"
|
||||
|
||||
#include "OT/Tools.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Tools/aes.h"
|
||||
#include "Tools/MMO.h"
|
||||
#include <wmmintrin.h>
|
||||
#include <emmintrin.h>
|
||||
|
||||
|
||||
word TRANSPOSE_MASKS128[7][2] = {
|
||||
{ 0x0000000000000000, 0xFFFFFFFFFFFFFFFF },
|
||||
{ 0x00000000FFFFFFFF, 0x00000000FFFFFFFF },
|
||||
{ 0x0000FFFF0000FFFF, 0x0000FFFF0000FFFF },
|
||||
{ 0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF },
|
||||
{ 0x0F0F0F0F0F0F0F0F, 0x0F0F0F0F0F0F0F0F },
|
||||
{ 0x3333333333333333, 0x3333333333333333 },
|
||||
{ 0x5555555555555555, 0x5555555555555555 }
|
||||
};
|
||||
|
||||
string word_to_str(word a)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << hex;
|
||||
for(int i = 0;i < 8; i++)
|
||||
ss << ((a >> (i*8)) & 255) << " ";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Transpose 16x16 matrix starting at bv[x][y] in-place using SSE2
|
||||
void sse_transpose16(vector<BitVector>& bv, int x, int y)
|
||||
{
|
||||
__m128i input[2];
|
||||
// 16x16 in two halves, 128 bits each
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int j = 0; j < 16; j++)
|
||||
((octet*)&input[i])[j] = bv[x+j].get_byte(y / 8 + i);
|
||||
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int j = 0; j < 8; j++)
|
||||
{
|
||||
int output = _mm_movemask_epi8(input[i]);
|
||||
input[i] = _mm_slli_epi64(input[i], 1);
|
||||
|
||||
for (int k = 0; k < 2; k++)
|
||||
// _mm_movemask_epi8 uses most significant bit, hence +7-j
|
||||
bv[x+8*i+7-j].set_byte(y / 8 + k, ((octet*)&output)[k]);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Transpose 128x128 bit-matrix using Eklundh's algorithm
|
||||
*
|
||||
* Input is in input[i] [ bits <offset> to <offset+127> ], i = 0, ..., 127
|
||||
* Output is in output[i + offset] (entire 128-bit vector), i = 0, ..., 127.
|
||||
*
|
||||
* Transposes 128-bit vectors in little-endian format.
|
||||
*/
|
||||
//void eklundh_transpose64(vector<word>& output, const vector<Big_Keys>& input, int offset)
|
||||
void eklundh_transpose128(vector<BitVector>& output, const vector<BitVector>& input,
|
||||
int offset)
|
||||
{
|
||||
int width = 64;
|
||||
int logn = 7, nswaps = 1;
|
||||
|
||||
#ifdef TRANSPOSE_DEBUG
|
||||
stringstream input_ss[128];
|
||||
stringstream output_ss[128];
|
||||
#endif
|
||||
// first copy input to output
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
//output[i + offset*64] = input[i].get(offset);
|
||||
output[i + offset].set_word(0, input[i].get_word(offset/64));
|
||||
output[i + offset].set_word(1, input[i].get_word(offset/64 + 1));
|
||||
|
||||
#ifdef TRANSPOSE_DEBUG
|
||||
for (int j = 0; j < 128; j++)
|
||||
{
|
||||
input_ss[j] << input[i].get_bit(offset + j);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// now transpose output in-place
|
||||
for (int i = 0; i < logn; i++)
|
||||
{
|
||||
word mask1 = TRANSPOSE_MASKS128[i][1], mask2 = TRANSPOSE_MASKS128[i][0];
|
||||
word inv_mask1 = ~mask1, inv_mask2 = ~mask2;
|
||||
|
||||
if (width == 8)
|
||||
{
|
||||
for (int j = 0; j < 8; j++)
|
||||
for (int k = 0; k < 8; k++)
|
||||
sse_transpose16(output, offset + 16 * j, 16 * k);
|
||||
break;
|
||||
}
|
||||
else
|
||||
// for width >= 64, shift is undefined so treat as a special case
|
||||
// (and avoid branching in inner loop)
|
||||
if (width < 64)
|
||||
{
|
||||
for (int j = 0; j < nswaps; j++)
|
||||
{
|
||||
for (int k = 0; k < width; k++)
|
||||
{
|
||||
int i1 = k + 2*width*j;
|
||||
int i2 = k + width + 2*width*j;
|
||||
|
||||
// t1 is lower 64 bits, t2 is upper 64 bits
|
||||
// (remember we're transposing in little-endian format)
|
||||
word t1 = output[i1 + offset].get_word(0);
|
||||
word t2 = output[i1 + offset].get_word(1);
|
||||
|
||||
word tt1 = output[i2 + offset].get_word(0);
|
||||
word tt2 = output[i2 + offset].get_word(1);
|
||||
|
||||
// swap operations due to little endian-ness
|
||||
output[i1 + offset].set_word(0, (t1 & mask1) ^
|
||||
((tt1 & mask1) << width));
|
||||
output[i1 + offset].set_word(1, (t2 & mask2) ^
|
||||
((tt2 & mask2) << width) ^
|
||||
((tt1 & mask1) >> (64 - width)));
|
||||
|
||||
output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^
|
||||
((t1 & inv_mask1) >> width) ^
|
||||
((t2 & inv_mask2)) << (64 - width));
|
||||
output[i2 + offset].set_word(1, (tt2 & inv_mask2) ^
|
||||
((t2 & inv_mask2) >> width));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < nswaps; j++)
|
||||
{
|
||||
for (int k = 0; k < width; k++)
|
||||
{
|
||||
int i1 = k + 2*width*j;
|
||||
int i2 = k + width + 2*width*j;
|
||||
|
||||
// t1 is lower 64 bits, t2 is upper 64 bits
|
||||
// (remember we're transposing in little-endian format)
|
||||
word t1 = output[i1 + offset].get_word(0);
|
||||
word t2 = output[i1 + offset].get_word(1);
|
||||
|
||||
word tt1 = output[i2 + offset].get_word(0);
|
||||
word tt2 = output[i2 + offset].get_word(1);
|
||||
|
||||
output[i1 + offset].set_word(0, (t1 & mask1));
|
||||
output[i1 + offset].set_word(1, (t2 & mask2) ^
|
||||
((tt1 & mask1) >> (64 - width)));
|
||||
|
||||
output[i2 + offset].set_word(0, (tt1 & inv_mask1) ^
|
||||
((t2 & inv_mask2)) << (64 - width));
|
||||
output[i2 + offset].set_word(1, (tt2 & inv_mask2));
|
||||
}
|
||||
}
|
||||
}
|
||||
nswaps *= 2;
|
||||
width /= 2;
|
||||
}
|
||||
#ifdef TRANSPOSE_DEBUG
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
for (int j = 0; j < 128; j++)
|
||||
{
|
||||
output_ss[j] << output[offset + j].get_bit(i);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
if (output_ss[i].str().compare(input_ss[i].str()) != 0)
|
||||
{
|
||||
cerr << "String " << i << " failed. offset = " << offset << endl;
|
||||
cerr << input_ss[i].str() << endl;
|
||||
cerr << output_ss[i].str() << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
cout << "\ttranspose with offset " << offset << " ok\n";
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// get bit, starting from MSB as bit 0
|
||||
int get_bit(word x, int b)
|
||||
{
|
||||
return (x >> (63 - b)) & 1;
|
||||
}
|
||||
int get_bit128(word x1, word x2, int b)
|
||||
{
|
||||
if (b < 64)
|
||||
{
|
||||
return (x1 >> (b - 64)) & 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (x2 >> b) & 1;
|
||||
}
|
||||
}
|
||||
|
||||
void naive_transpose128(vector<BitVector>& output, const vector<BitVector>& input,
|
||||
int offset)
|
||||
{
|
||||
for (int i = 0; i < 128; i++)
|
||||
{
|
||||
// NB: words are read from input in big-endian format
|
||||
word w1 = input[i].get_word(offset/64);
|
||||
word w2 = input[i].get_word(offset/64 + 1);
|
||||
|
||||
for (int j = 0; j < 128; j++)
|
||||
{
|
||||
//output[j + offset].set_bit(i, input[i].get_bit(j + offset));
|
||||
|
||||
if (j < 64)
|
||||
output[j + offset].set_bit(i, (w1 >> j) & 1);
|
||||
else
|
||||
output[j + offset].set_bit(i, w2 >> (j-64) & 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transpose64(
|
||||
vector<BitVector>::iterator& output_it,
|
||||
vector<BitVector>::iterator& input_it)
|
||||
{
|
||||
for (int i = 0; i < 64; i++)
|
||||
{
|
||||
for (int j = 0; j < 64; j++)
|
||||
{
|
||||
(output_it + j)->set_bit(i, (input_it + i)->get_bit(j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Naive 64x64 bit matrix transpose
|
||||
void naive_transpose64(vector<BitVector>& output, const vector<BitVector>& input,
|
||||
int xoffset, int yoffset)
|
||||
{
|
||||
int word_size = 64;
|
||||
|
||||
for (int i = 0; i < word_size; i++)
|
||||
{
|
||||
word w = input[i + yoffset].get_word(xoffset);
|
||||
for (int j = 0; j < word_size; j++)
|
||||
{
|
||||
//cout << j + xoffset*word_size << ", " << yoffset/word_size << endl;
|
||||
//int wbit = (((w >> j) & 1) << i); cout << "wbit " << wbit << endl;
|
||||
|
||||
// set i-th bit of output to j-th bit of w
|
||||
// scale yoffset by 64 since we're selecting words from the BitVector
|
||||
word tmp = output[j + xoffset*word_size].get_word(yoffset/word_size);
|
||||
output[j + xoffset*word_size].set_word(yoffset/word_size, tmp ^ ((w >> j) & 1) << i);
|
||||
// set i-th bit of output to j-th bit of w
|
||||
//output[j + offset*word_size] ^= ((w >> j) & 1) << i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtension::transfer(int nOTs,
|
||||
const BitVector& receiverInput)
|
||||
{
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval totalstartv, totalendv;
|
||||
gettimeofday(&totalstartv, NULL);
|
||||
#endif
|
||||
cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl;
|
||||
// add k + s to account for discarding k OTs
|
||||
nOTs += 2 * 128;
|
||||
if (nOTs % nbaseOTs != 0)
|
||||
throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
|
||||
if (nOTs == 0)
|
||||
return;
|
||||
|
||||
vector<BitVector> t0(nbaseOTs, BitVector(nOTs)), tmp(nbaseOTs, BitVector(nOTs)), t1(nbaseOTs, BitVector(nOTs));
|
||||
BitVector u(nOTs);
|
||||
senderOutput.resize(2, vector<BitVector>(nOTs, BitVector(nbaseOTs)));
|
||||
// resize to account for extra k OTs that are discarded
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
BitVector newReceiverInput(nOTs);
|
||||
for (unsigned int i = 0; i < receiverInput.size_bytes(); i++)
|
||||
{
|
||||
newReceiverInput.set_byte(i, receiverInput.get_byte(i));
|
||||
}
|
||||
|
||||
//BitVector newReceiverInput(receiverInput);
|
||||
newReceiverInput.resize(nOTs);
|
||||
|
||||
receiverOutput.resize(nOTs, BitVector(nbaseOTs));
|
||||
|
||||
for (int loop = 0; loop < nloops; loop++)
|
||||
{
|
||||
vector<octetStream> os(2), tmp_os(2);
|
||||
|
||||
// randomize last 128 + 128 bits that will be discarded
|
||||
for (int i = 0; i < 4; i++)
|
||||
newReceiverInput.set_word(nOTs/64 - i, G.get_word());
|
||||
|
||||
// expand with PRG and create correlation
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
t0[i].randomize(G_sender[i][0]);
|
||||
t1[i].randomize(G_sender[i][1]);
|
||||
tmp[i].assign(t1[i]);
|
||||
|
||||
tmp[i].add(t0[i]);
|
||||
tmp[i].add(newReceiverInput);
|
||||
tmp[i].pack(os[0]);
|
||||
/*cout << "t0: " << t0[i].str() << endl;
|
||||
cout << "t1: " << t1[i].str() << endl;
|
||||
cout << "Sending tmp: " << tmp[i].str() << endl;*/
|
||||
}
|
||||
}
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval commst1, commst2;
|
||||
gettimeofday(&commst1, NULL);
|
||||
#endif
|
||||
// send t0 + t1 + x
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
|
||||
// sender adjusts using base receiver bits
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
// randomize base receiver output
|
||||
tmp[i].randomize(G_receiver[i]);
|
||||
|
||||
// u = t0 + t1 + x
|
||||
u.unpack(os[1]);
|
||||
|
||||
if (baseReceiverInput.get_bit(i) == 1)
|
||||
{
|
||||
// now tmp is q[i] = t0[i] + Delta[i] * x
|
||||
tmp[i].add(u);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&commst2, NULL);
|
||||
double commstime = timeval_diff(&commst1, &commst2);
|
||||
cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush;
|
||||
times["Communication"] += timeval_diff(&commst1, &commst2);
|
||||
#endif
|
||||
|
||||
// transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0]
|
||||
|
||||
// stupid transpose
|
||||
/*for (int j = 0; j < nOTs; j++)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
senderOutput[0][j].set_bit(i, t0[i].get_bit(j));
|
||||
receiverOutput[j].set_bit(i, tmp[i].get_bit(j));
|
||||
}
|
||||
}*/
|
||||
cout << "Starting matrix transpose\n" << flush << endl;
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval transt1, transt2;
|
||||
gettimeofday(&transt1, NULL);
|
||||
#endif
|
||||
// transpose in 128-bit chunks with Eklundh's algorithm
|
||||
for (int i = 0; i < nOTs / 128; i++)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
eklundh_transpose128(receiverOutput, t0, i*128);
|
||||
//naive_transpose128(receiverOutput, t0, i*128);
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
eklundh_transpose128(senderOutput[0], tmp, i*128);
|
||||
//naive_transpose128(senderOutput[0], tmp, i*128);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&transt2, NULL);
|
||||
double transtime = timeval_diff(&transt1, &transt2);
|
||||
cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush;
|
||||
times["Matrix transpose"] += timeval_diff(&transt1, &transt2);
|
||||
#endif
|
||||
|
||||
#ifdef OTEXT_DEBUG
|
||||
// verify correctness of the OTs
|
||||
// i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i]
|
||||
// (where Delta = baseReceiverOutput)
|
||||
BitVector tmp_vector1(nbaseOTs), tmp_vector2(nOTs);//nbaseOTs);
|
||||
cout << "\tVerifying OT extensions (debugging)\n";
|
||||
for (int i = 0; i < nOTs; i++)
|
||||
{
|
||||
os[0].reset_write_head();
|
||||
os[1].reset_write_head();
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
// send t0 and x over
|
||||
receiverOutput[i].pack(os[0]);
|
||||
//t0[i].pack(os[0]);
|
||||
newReceiverInput.pack(os[0]);
|
||||
}
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
tmp_vector1.unpack(os[1]);
|
||||
tmp_vector2.unpack(os[1]);
|
||||
|
||||
// if x_i = 1, add Delta
|
||||
if (tmp_vector2.get_bit(i) == 1)
|
||||
{
|
||||
tmp_vector1.add(baseReceiverInput);
|
||||
}
|
||||
if (!tmp_vector1.equals(senderOutput[0][i]))
|
||||
{
|
||||
cerr << "Incorrect OT at " << i << "\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
cout << "Correlated OTs all OK\n";
|
||||
#endif
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
double elapsed;
|
||||
#endif
|
||||
// correlation check
|
||||
if (!passive_only)
|
||||
{
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval startv, endv;
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
check_correlation(nOTs, newReceiverInput);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush;
|
||||
times["Total correlation check"] += timeval_diff(&startv, &endv);
|
||||
#endif
|
||||
}
|
||||
|
||||
hash_outputs(nOTs, receiverOutput);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&totalendv, NULL);
|
||||
elapsed = timeval_diff(&totalstartv, &totalendv);
|
||||
cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush;
|
||||
#endif
|
||||
|
||||
#ifdef OTEXT_DEBUG
|
||||
// verify correctness of the random OTs
|
||||
// i.e. senderOutput[0][i] + x_i * Delta = receiverOutput[i]
|
||||
// (where Delta = baseReceiverOutput)
|
||||
cout << "Verifying random OTs (debugging)\n";
|
||||
for (int i = 0; i < nOTs; i++)
|
||||
{
|
||||
os[0].reset_write_head();
|
||||
os[1].reset_write_head();
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
// send receiver's input/output over
|
||||
receiverOutput[i].pack(os[0]);
|
||||
newReceiverInput.pack(os[0]);
|
||||
}
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
//player->send_receive_player(os);
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
tmp_vector1.unpack(os[1]);
|
||||
tmp_vector2.unpack(os[1]);
|
||||
|
||||
// if x_i = 1, comp with sender output[1]
|
||||
if ((tmp_vector2.get_bit(i) == 1))
|
||||
{
|
||||
if (!tmp_vector1.equals(senderOutput[1][i]))
|
||||
{
|
||||
cerr << "Incorrect OT\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
// else should be sender output[0]
|
||||
else if (!tmp_vector1.equals(senderOutput[0][i]))
|
||||
{
|
||||
cerr << "Incorrect OT\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
cout << "Random OTs all OK\n";
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&totalendv, NULL);
|
||||
times["Total thread"] += timeval_diff(&totalstartv, &totalendv);
|
||||
#endif
|
||||
|
||||
receiverOutput.resize(nOTs - 2 * 128);
|
||||
senderOutput[0].resize(nOTs - 2 * 128);
|
||||
senderOutput[1].resize(nOTs - 2 * 128);
|
||||
}
|
||||
|
||||
/*
|
||||
* Hash outputs to make into random OT
|
||||
*/
|
||||
void OTExtension::hash_outputs(int nOTs, vector<BitVector>& receiverOutput)
|
||||
{
|
||||
cout << "Hashing... " << flush;
|
||||
octetStream os, h_os(HASH_SIZE);
|
||||
BitVector tmp(nbaseOTs);
|
||||
MMO mmo;
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval startv, endv;
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < nOTs; i++)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
tmp.add(senderOutput[0][i], baseReceiverInput);
|
||||
|
||||
if (senderOutput[0][i].size() == 128)
|
||||
{
|
||||
mmo.hashOneBlock<gf2n>(senderOutput[0][i].get_ptr(), senderOutput[0][i].get_ptr());
|
||||
mmo.hashOneBlock<gf2n>(senderOutput[1][i].get_ptr(), tmp.get_ptr());
|
||||
}
|
||||
else
|
||||
{
|
||||
os.reset_write_head();
|
||||
h_os.reset_write_head();
|
||||
|
||||
senderOutput[0][i].pack(os);
|
||||
os.hash(h_os);
|
||||
senderOutput[0][i].unpack(h_os);
|
||||
|
||||
os.reset_write_head();
|
||||
h_os.reset_write_head();
|
||||
|
||||
tmp.pack(os);
|
||||
os.hash(h_os);
|
||||
senderOutput[1][i].unpack(h_os);
|
||||
}
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
if (receiverOutput[i].size() == 128)
|
||||
mmo.hashOneBlock<gf2n>(receiverOutput[i].get_ptr(), receiverOutput[i].get_ptr());
|
||||
else
|
||||
{
|
||||
os.reset_write_head();
|
||||
h_os.reset_write_head();
|
||||
|
||||
receiverOutput[i].pack(os);
|
||||
os.hash(h_os);
|
||||
receiverOutput[i].unpack(h_os);
|
||||
}
|
||||
}
|
||||
}
|
||||
cout << "done.\n";
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
double elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush;
|
||||
times["Hashing"] += timeval_diff(&startv, &endv);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// test if a == b
|
||||
int eq_m128i(__m128i a, __m128i b)
|
||||
{
|
||||
__m128i vcmp = _mm_cmpeq_epi8(a, b);
|
||||
uint16_t vmask = _mm_movemask_epi8(vcmp);
|
||||
return (vmask == 0xffff);
|
||||
}
|
||||
|
||||
void random_m128i(PRNG& G, __m128i *r)
|
||||
{
|
||||
BitVector rv(128);
|
||||
rv.randomize(G);
|
||||
*r = _mm_load_si128((__m128i*)&(rv.get_ptr()[0]));
|
||||
}
|
||||
|
||||
void test_mul()
|
||||
{
|
||||
cout << "Testing GF(2^128) multiplication\n";
|
||||
__m128i t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
BitVector r(128);
|
||||
for (int i = 0; i < 1000; i++)
|
||||
{
|
||||
random_m128i(G, &t1);
|
||||
random_m128i(G, &t2);
|
||||
// test commutativity
|
||||
gfmul128(t1, t2, &t3);
|
||||
gfmul128(t2, t1, &t4);
|
||||
if (!eq_m128i(t3, t4))
|
||||
{
|
||||
cerr << "Incorrect multiplication:\n";
|
||||
cerr << "t1 * t2 = " << __m128i_toString<octet>(t3) << endl;
|
||||
cerr << "t2 * t1 = " << __m128i_toString<octet>(t4) << endl;
|
||||
}
|
||||
// test distributivity: t1*t3 + t2*t3 = (t1 + t2) * t3
|
||||
random_m128i(G, &t1);
|
||||
random_m128i(G, &t2);
|
||||
random_m128i(G, &t3);
|
||||
gfmul128(t1, t3, &t4);
|
||||
gfmul128(t2, t3, &t5);
|
||||
t6 = _mm_xor_si128(t4, t5);
|
||||
|
||||
t7 = _mm_xor_si128(t1, t2);
|
||||
gfmul128(t7, t3, &t8);
|
||||
if (!eq_m128i(t6, t8))
|
||||
{
|
||||
cerr << "Incorrect multiplication:\n";
|
||||
cerr << "t1 * t3 + t2 * t3 = " << __m128i_toString<octet>(t6) << endl;
|
||||
cerr << "(t1 + t2) * t3 = " << __m128i_toString<octet>(t8) << endl;
|
||||
}
|
||||
}
|
||||
t1 = _mm_set_epi32(0, 0, 0, 03);
|
||||
t2 = _mm_set_epi32(0, 0, 0, 11);
|
||||
//gfmul128(t1, t2, &t3);
|
||||
mul128(t1, t2, &t3, &t4);
|
||||
cout << "t1 = " << __m128i_toString<octet>(t1) << endl;
|
||||
cout << "t2 = " << __m128i_toString<octet>(t2) << endl;
|
||||
cout << "t3 = " << __m128i_toString<octet>(t3) << endl;
|
||||
cout << "t4 = " << __m128i_toString<octet>(t4) << endl;
|
||||
|
||||
uint64_t cc[] __attribute__((aligned (16))) = { 0,0 };
|
||||
_mm_store_si128((__m128i*)cc, t1);
|
||||
word t1w = cc[0];
|
||||
_mm_store_si128((__m128i*)cc, t2);
|
||||
word t2w = cc[0];
|
||||
|
||||
cout << "t1w = " << t1w << endl;
|
||||
cout << "t1 = " << word_to_bytes(t1w) << endl;
|
||||
cout << "t2 = " << word_to_bytes(t2w) << endl;
|
||||
cout << "t1 * t2 = " << word_to_bytes(t1w*t2w) << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void OTExtension::check_correlation(int nOTs,
|
||||
const BitVector& receiverInput)
|
||||
{
|
||||
//cout << "\tStarting correlation check\n" << flush;
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval startv, endv;
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
if (nbaseOTs != 128)
|
||||
{
|
||||
cerr << "Correlation check not implemented for length != 128\n";
|
||||
throw not_implemented();
|
||||
}
|
||||
PRNG G;
|
||||
octet* seed = new octet[SEED_SIZE];
|
||||
random_seed_commit(seed, *player, SEED_SIZE);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
double elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tCommitment for seed took time " << elapsed/1000000 << endl << flush;
|
||||
times["Commitment for seed"] += timeval_diff(&startv, &endv);
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
G.SetSeed(seed);
|
||||
delete[] seed;
|
||||
vector<octetStream> os(2);
|
||||
|
||||
if (!Check_CPU_support_AES())
|
||||
{
|
||||
cerr << "Not implemented GF(2^128) multiplication in C\n";
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
__m128i Delta, x128i;
|
||||
Delta = _mm_load_si128((__m128i*)&(baseReceiverInput.get_ptr()[0]));
|
||||
|
||||
BitVector chi(nbaseOTs);
|
||||
BitVector x(nbaseOTs);
|
||||
__m128i t = _mm_setzero_si128();
|
||||
__m128i q = _mm_setzero_si128();
|
||||
__m128i t2 = _mm_setzero_si128();
|
||||
__m128i q2 = _mm_setzero_si128();
|
||||
__m128i chii, ti, qi, ti2, qi2;
|
||||
x128i = _mm_setzero_si128();
|
||||
|
||||
for (int i = 0; i < nOTs; i++)
|
||||
{
|
||||
// chi.randomize(G);
|
||||
// chii = _mm_load_si128((__m128i*)&(chi.get_ptr()[0]));
|
||||
chii = G.get_doubleword();
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
if (receiverInput.get_bit(i) == 1)
|
||||
{
|
||||
x128i = _mm_xor_si128(x128i, chii);
|
||||
}
|
||||
ti = _mm_loadu_si128((__m128i*)get_receiver_output(i));
|
||||
// multiply over polynomial ring to avoid reduction
|
||||
mul128(ti, chii, &ti, &ti2);
|
||||
t = _mm_xor_si128(t, ti);
|
||||
t2 = _mm_xor_si128(t2, ti2);
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
qi = _mm_loadu_si128((__m128i*)(get_sender_output(0, i)));
|
||||
mul128(qi, chii, &qi, &qi2);
|
||||
q = _mm_xor_si128(q, qi);
|
||||
q2 = _mm_xor_si128(q2, qi2);
|
||||
}
|
||||
}
|
||||
#ifdef OTEXT_DEBUG
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
cout << "\tSending x,t\n";
|
||||
cout << "\tsend x = " << __m128i_toString<octet>(x128i) << endl;
|
||||
cout << "\tsend t = " << __m128i_toString<octet>(t) << endl;
|
||||
cout << "\tsend t2 = " << __m128i_toString<octet>(t2) << endl;
|
||||
}
|
||||
#endif
|
||||
check_iteration(Delta, q, q2, t, t2, x128i);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tChecking correlation took time " << elapsed/1000000 << endl << flush;
|
||||
times["Checking correlation"] += timeval_diff(&startv, &endv);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OTExtension::check_iteration(__m128i delta, __m128i q, __m128i q2,
|
||||
__m128i t, __m128i t2, __m128i x)
|
||||
{
|
||||
vector<octetStream> os(2);
|
||||
// send x, t;
|
||||
__m128i received_t, received_t2, received_x, tmp1, tmp2;
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
os[0].append((octet*)&x, sizeof(x));
|
||||
os[0].append((octet*)&t, sizeof(t));
|
||||
os[0].append((octet*)&t2, sizeof(t2));
|
||||
}
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
os[1].consume((octet*)&received_x, sizeof(received_x));
|
||||
os[1].consume((octet*)&received_t, sizeof(received_t));
|
||||
os[1].consume((octet*)&received_t2, sizeof(received_t2));
|
||||
|
||||
// check t = x * Delta + q
|
||||
//gfmul128(received_x, delta, &tmp1);
|
||||
mul128(received_x, delta, &tmp1, &tmp2);
|
||||
tmp1 = _mm_xor_si128(tmp1, q);
|
||||
tmp2 = _mm_xor_si128(tmp2, q2);
|
||||
|
||||
if (eq_m128i(tmp1, received_t) && eq_m128i(tmp2, received_t2))
|
||||
{
|
||||
//cout << "\tCheck passed\n";
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "Correlation check failed\n";
|
||||
cout << "rec t = " << __m128i_toString<octet>(received_t) << endl;
|
||||
cout << "tmp1 = " << __m128i_toString<octet>(tmp1) << endl;
|
||||
cout << "q = " << __m128i_toString<octet>(q) << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
octet* OTExtension::get_receiver_output(int i)
|
||||
{
|
||||
return receiverOutput[i].get_ptr();
|
||||
}
|
||||
|
||||
octet* OTExtension::get_sender_output(int choice, int i)
|
||||
{
|
||||
return senderOutput[choice][i].get_ptr();
|
||||
}
|
||||
122
OT/OTExtension.h
Normal file
122
OT/OTExtension.h
Normal file
@@ -0,0 +1,122 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _OTEXTENSION
|
||||
#define _OTEXTENSION
|
||||
|
||||
#include "OT/BaseOT.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include "Networking/Player.h"
|
||||
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <assert.h>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
using namespace std;
|
||||
|
||||
//#define OTEXT_TIMER
|
||||
//#define OTEXT_DEBUG
|
||||
|
||||
|
||||
class OTExtension
|
||||
{
|
||||
public:
|
||||
BitVector baseReceiverInput;
|
||||
vector< vector<BitVector> > senderOutput;
|
||||
vector<BitVector> receiverOutput;
|
||||
map<string,long long> times;
|
||||
|
||||
OTExtension(int nbaseOTs, int baseLength,
|
||||
int nloops, int nsubloops,
|
||||
TwoPartyPlayer* player,
|
||||
BitVector& baseReceiverInput,
|
||||
vector< vector<BitVector> >& baseSenderInput,
|
||||
vector<BitVector>& baseReceiverOutput,
|
||||
OT_ROLE role=BOTH,
|
||||
bool passive=false)
|
||||
: baseReceiverInput(baseReceiverInput), passive_only(passive), nbaseOTs(nbaseOTs),
|
||||
baseLength(baseLength), nloops(nloops), nsubloops(nsubloops), ot_role(role), player(player)
|
||||
{
|
||||
G_sender.resize(nbaseOTs, vector<PRNG>(2));
|
||||
G_receiver.resize(nbaseOTs);
|
||||
|
||||
// set up PRGs for expanding the seed OTs
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
assert(baseSenderInput[i][0].size_bytes() >= AES_BLK_SIZE);
|
||||
assert(baseSenderInput[i][1].size_bytes() >= AES_BLK_SIZE);
|
||||
assert(baseReceiverOutput[i].size_bytes() >= AES_BLK_SIZE);
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
G_sender[i][0].SetSeed(baseSenderInput[i][0].get_ptr());
|
||||
G_sender[i][1].SetSeed(baseSenderInput[i][1].get_ptr());
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
G_receiver[i].SetSeed(baseReceiverOutput[i].get_ptr());
|
||||
}
|
||||
|
||||
#ifdef OTEXT_DEBUG
|
||||
// sanity check for base OTs
|
||||
vector<octetStream> os(2);
|
||||
BitVector t0(128);
|
||||
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
// send both inputs to test
|
||||
baseSenderInput[i][0].pack(os[0]);
|
||||
baseSenderInput[i][1].pack(os[0]);
|
||||
}
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
// sender checks results
|
||||
t0.unpack(os[1]);
|
||||
if (baseReceiverInput.get_bit(i) == 1)
|
||||
t0.unpack(os[1]);
|
||||
if (!t0.equals(baseReceiverOutput[i]))
|
||||
{
|
||||
cerr << "Incorrect base OT\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
os[0].reset_write_head();
|
||||
os[1].reset_write_head();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~OTExtension() {}
|
||||
|
||||
virtual void transfer(int nOTs, const BitVector& receiverInput);
|
||||
virtual octet* get_receiver_output(int i);
|
||||
virtual octet* get_sender_output(int choice, int i);
|
||||
|
||||
protected:
|
||||
bool passive_only;
|
||||
int nbaseOTs, baseLength, nloops, nsubloops;
|
||||
OT_ROLE ot_role;
|
||||
TwoPartyPlayer* player;
|
||||
vector< vector<PRNG> > G_sender;
|
||||
vector<PRNG> G_receiver;
|
||||
|
||||
void check_correlation(int nOTs,
|
||||
const BitVector& receiverInput);
|
||||
|
||||
void check_iteration(__m128i delta, __m128i q, __m128i q2,
|
||||
__m128i t, __m128i t2, __m128i x);
|
||||
|
||||
void hash_outputs(int nOTs, vector<BitVector>& receiverOutput);
|
||||
};
|
||||
|
||||
#endif
|
||||
466
OT/OTExtensionWithMatrix.cpp
Normal file
466
OT/OTExtensionWithMatrix.cpp
Normal file
@@ -0,0 +1,466 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OTExtensionWithMatrix.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "OTExtensionWithMatrix.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
void OTExtensionWithMatrix::seed(vector<BitMatrix>& baseSenderInput,
|
||||
BitMatrix& baseReceiverOutput)
|
||||
{
|
||||
nbaseOTs = baseReceiverInput.size();
|
||||
//cout << "nbaseOTs " << nbaseOTs << endl;
|
||||
G_sender.resize(nbaseOTs, vector<PRNG>(2));
|
||||
G_receiver.resize(nbaseOTs);
|
||||
|
||||
// set up PRGs for expanding the seed OTs
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
G_sender[i][0].SetSeed((octet*)&baseSenderInput[0].squares[i/128].rows[i%128]);
|
||||
G_sender[i][1].SetSeed((octet*)&baseSenderInput[1].squares[i/128].rows[i%128]);
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
G_receiver[i].SetSeed((octet*)&baseReceiverOutput.squares[i/128].rows[i%128]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::transfer(int nOTs,
|
||||
const BitVector& receiverInput)
|
||||
{
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval totalstartv, totalendv;
|
||||
gettimeofday(&totalstartv, NULL);
|
||||
#endif
|
||||
cout << "\tDoing " << nOTs << " extended OTs as " << role_to_str(ot_role) << endl;
|
||||
|
||||
// resize to account for extra k OTs that are discarded
|
||||
BitVector newReceiverInput(nOTs);
|
||||
for (unsigned int i = 0; i < receiverInput.size_bytes(); i++)
|
||||
{
|
||||
newReceiverInput.set_byte(i, receiverInput.get_byte(i));
|
||||
}
|
||||
|
||||
for (int loop = 0; loop < nloops; loop++)
|
||||
{
|
||||
extend<gf2n>(nOTs, newReceiverInput);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&totalendv, NULL);
|
||||
double elapsed = timeval_diff(&totalstartv, &totalendv);
|
||||
cout << "\t\tTotal thread time: " << elapsed/1000000 << endl << flush;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&totalendv, NULL);
|
||||
times["Total thread"] += timeval_diff(&totalstartv, &totalendv);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::resize(int nOTs)
|
||||
{
|
||||
t1.resize(nOTs);
|
||||
u.resize(nOTs);
|
||||
senderOutputMatrices.resize(2);
|
||||
for (int i = 0; i < 2; i++)
|
||||
senderOutputMatrices[i].resize(nOTs);
|
||||
receiverOutputMatrix.resize(nOTs);
|
||||
}
|
||||
|
||||
// the template is used to denote the field of the hash output
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
|
||||
{
|
||||
// if (nOTs % nbaseOTs != 0)
|
||||
// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
|
||||
if (nOTs_requested == 0)
|
||||
return;
|
||||
// add k + s to account for discarding k OTs
|
||||
int nOTs = nOTs_requested + 2 * 128;
|
||||
|
||||
int slice = nOTs / nsubloops / 128;
|
||||
nOTs = slice * nsubloops * 128;
|
||||
resize(nOTs);
|
||||
newReceiverInput.resize(nOTs);
|
||||
|
||||
// randomize last 128 + 128 bits that will be discarded
|
||||
for (int i = 0; i < 4; i++)
|
||||
newReceiverInput.set_word(nOTs/64 - i - 1, G.get_word());
|
||||
|
||||
// subloop for first part to interleave communication with computation
|
||||
for (int start = 0; start < nOTs / 128; start += slice)
|
||||
{
|
||||
expand<gf2n>(start, slice);
|
||||
correlate<gf2n>(start, slice, newReceiverInput, true);
|
||||
transpose(start, slice);
|
||||
}
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
double elapsed;
|
||||
#endif
|
||||
// correlation check
|
||||
if (!passive_only)
|
||||
{
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval startv, endv;
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
check_correlation(nOTs, newReceiverInput);
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tTotal correlation check time: " << elapsed/1000000 << endl << flush;
|
||||
times["Total correlation check"] += timeval_diff(&startv, &endv);
|
||||
#endif
|
||||
}
|
||||
|
||||
hash_outputs<T>(nOTs);
|
||||
|
||||
receiverOutputMatrix.resize(nOTs_requested);
|
||||
senderOutputMatrices[0].resize(nOTs_requested);
|
||||
senderOutputMatrices[1].resize(nOTs_requested);
|
||||
newReceiverInput.resize(nOTs_requested);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::expand(int start, int slice)
|
||||
{
|
||||
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
|
||||
BitMatrixSlice senderOutputSlices[2] = {
|
||||
BitMatrixSlice(senderOutputMatrices[0], start, slice),
|
||||
BitMatrixSlice(senderOutputMatrices[1], start, slice)
|
||||
};
|
||||
BitMatrixSlice t1Slice(t1, start, slice);
|
||||
|
||||
// expand with PRG
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
receiverOutputSlice.randomize<T>(i, G_sender[i][0]);
|
||||
t1Slice.randomize<T>(i, G_sender[i][1]);
|
||||
}
|
||||
}
|
||||
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
// randomize base receiver output
|
||||
senderOutputSlices[0].randomize<T>(i, G_receiver[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::expand_transposed()
|
||||
{
|
||||
for (int i = 0; i < nbaseOTs; i++)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
receiverOutputMatrix.squares[i/128].randomize<T>(i % 128, G_sender[i][0]);
|
||||
t1.squares[i/128].randomize<T>(i % 128, G_sender[i][1]);
|
||||
}
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
senderOutputMatrices[0].squares[i/128].randomize<T>(i % 128, G_receiver[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::setup_for_correlation(vector<BitMatrix>& baseSenderOutputs, BitMatrix& baseReceiverOutput)
|
||||
{
|
||||
receiverOutputMatrix = baseSenderOutputs[0];
|
||||
t1 = baseSenderOutputs[1];
|
||||
u.resize(t1.size());
|
||||
senderOutputMatrices.resize(2);
|
||||
senderOutputMatrices[0] = baseReceiverOutput;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::correlate(int start, int slice,
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat)
|
||||
{
|
||||
vector<octetStream> os(2);
|
||||
|
||||
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
|
||||
BitMatrixSlice senderOutputSlices[2] = {
|
||||
BitMatrixSlice(senderOutputMatrices[0], start, slice),
|
||||
BitMatrixSlice(senderOutputMatrices[1], start, slice)
|
||||
};
|
||||
BitMatrixSlice t1Slice(t1, start, slice);
|
||||
BitMatrixSlice uSlice(u, start, slice);
|
||||
|
||||
// create correlation
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
t1Slice.rsub<T>(receiverOutputSlice);
|
||||
t1Slice.add<T>(newReceiverInput, repeat);
|
||||
t1Slice.pack(os[0]);
|
||||
|
||||
// t1 = receiverOutputMatrix;
|
||||
// t1 ^= newReceiverInput;
|
||||
// receiverOutputMatrix.print_side_by_side(t1);
|
||||
}
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval commst1, commst2;
|
||||
gettimeofday(&commst1, NULL);
|
||||
#endif
|
||||
// send t0 + t1 + x
|
||||
send_if_ot_receiver(player, os, ot_role);
|
||||
|
||||
// sender adjusts using base receiver bits
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
// u = t0 + t1 + x
|
||||
uSlice.unpack(os[1]);
|
||||
senderOutputSlices[0].conditional_add<T>(baseReceiverInput, u, !useConstantBase);
|
||||
}
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&commst2, NULL);
|
||||
double commstime = timeval_diff(&commst1, &commst2);
|
||||
cout << "\t\tCommunication took time " << commstime/1000000 << endl << flush;
|
||||
times["Communication"] += timeval_diff(&commst1, &commst2);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::transpose(int start, int slice)
|
||||
{
|
||||
BitMatrixSlice receiverOutputSlice(receiverOutputMatrix, start, slice);
|
||||
BitMatrixSlice senderOutputSlices[2] = {
|
||||
BitMatrixSlice(senderOutputMatrices[0], start, slice),
|
||||
BitMatrixSlice(senderOutputMatrices[1], start, slice)
|
||||
};
|
||||
|
||||
// transpose t0[i] onto receiverOutput and tmp (q[i]) onto senderOutput[i][0]
|
||||
|
||||
//cout << "Starting matrix transpose\n" << flush << endl;
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval transt1, transt2;
|
||||
gettimeofday(&transt1, NULL);
|
||||
#endif
|
||||
// transpose in 128-bit chunks
|
||||
if (ot_role & RECEIVER)
|
||||
receiverOutputSlice.transpose();
|
||||
if (ot_role & SENDER)
|
||||
senderOutputSlices[0].transpose();
|
||||
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&transt2, NULL);
|
||||
double transtime = timeval_diff(&transt1, &transt2);
|
||||
cout << "\t\tMatrix transpose took time " << transtime/1000000 << endl << flush;
|
||||
times["Matrix transpose"] += timeval_diff(&transt1, &transt2);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Hash outputs to make into random OT
|
||||
*/
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::hash_outputs(int nOTs)
|
||||
{
|
||||
//cout << "Hashing... " << flush;
|
||||
octetStream os, h_os(HASH_SIZE);
|
||||
square128 tmp;
|
||||
MMO mmo;
|
||||
#ifdef OTEXT_TIMER
|
||||
timeval startv, endv;
|
||||
gettimeofday(&startv, NULL);
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < nOTs / 128; i++)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
tmp = senderOutputMatrices[0].squares[i];
|
||||
tmp ^= baseReceiverInput;
|
||||
senderOutputMatrices[0].squares[i].hash_row_wise<T>(mmo, senderOutputMatrices[0].squares[i]);
|
||||
senderOutputMatrices[1].squares[i].hash_row_wise<T>(mmo, tmp);
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
receiverOutputMatrix.squares[i].hash_row_wise<T>(mmo, receiverOutputMatrix.squares[i]);
|
||||
}
|
||||
}
|
||||
//cout << "done.\n";
|
||||
#ifdef OTEXT_TIMER
|
||||
gettimeofday(&endv, NULL);
|
||||
double elapsed = timeval_diff(&startv, &endv);
|
||||
cout << "\t\tOT ext hashing took time " << elapsed/1000000 << endl << flush;
|
||||
times["Hashing"] += timeval_diff(&startv, &endv);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples, vector<T>& output)
|
||||
{
|
||||
if (receiverOutputMatrix.squares.size() < nTriples)
|
||||
throw invalid_length();
|
||||
output.resize(nTriples);
|
||||
for (unsigned int j = 0; j < nTriples; j++)
|
||||
{
|
||||
T c1, c2;
|
||||
receiverOutputMatrix.squares[j].to(c1);
|
||||
senderOutputMatrices[0].squares[j].to(c2);
|
||||
output[j] = c1 - c2;
|
||||
}
|
||||
}
|
||||
|
||||
octet* OTExtensionWithMatrix::get_receiver_output(int i)
|
||||
{
|
||||
return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128];
|
||||
}
|
||||
|
||||
octet* OTExtensionWithMatrix::get_sender_output(int choice, int i)
|
||||
{
|
||||
return (octet*)&senderOutputMatrices[choice].squares[i/128].rows[i%128];
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::print(BitVector& newReceiverInput, int i)
|
||||
{
|
||||
if (player->my_num() == 0)
|
||||
{
|
||||
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix, i);
|
||||
print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_sender(senderOutputMatrices[0].squares[i], senderOutputMatrices[1].squares[i]);
|
||||
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix, i);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int k, int offset)
|
||||
{
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
if (newReceiverInput.get_bit((offset + k) * 128 + i))
|
||||
{
|
||||
for (int j = 0; j < 33; j++)
|
||||
cout << " ";
|
||||
cout << T(matrix.squares[k].rows[i]);
|
||||
}
|
||||
else
|
||||
cout << int128(matrix.squares[k].rows[i]);
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::print_sender(square128& square0, square128& square1)
|
||||
{
|
||||
if (ot_role & SENDER)
|
||||
{
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
cout << int128(square0.rows[i]) << " ";
|
||||
cout << int128(square1.rows[i]) << " ";
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void OTExtensionWithMatrix::print_post_correlate(BitVector& newReceiverInput, int j, int offset, int sender)
|
||||
{
|
||||
cout << "post correlate, sender" << sender << endl;
|
||||
if (player->my_num() == sender)
|
||||
{
|
||||
T delta = newReceiverInput.get_int128(offset + j);
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
cout << (int128(receiverOutputMatrix.squares[j].rows[i]));
|
||||
cout << " ";
|
||||
cout << (T(receiverOutputMatrix.squares[j].rows[i]) - delta);
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
print_receiver<T>(baseReceiverInput, senderOutputMatrices[0], j);
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::print_pre_correlate(int i)
|
||||
{
|
||||
cout << "pre correlate" << endl;
|
||||
if (player->my_num() == 0)
|
||||
print_sender(receiverOutputMatrix.squares[i], t1.squares[i]);
|
||||
else
|
||||
print_receiver<gf2n>(baseReceiverInput, senderOutputMatrices[0], i);
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::print_post_transpose(BitVector& newReceiverInput, int i, int sender)
|
||||
{
|
||||
cout << "post transpose, sender " << sender << endl;
|
||||
if (player->my_num() == sender)
|
||||
{
|
||||
print_receiver<gf2n>(newReceiverInput, receiverOutputMatrix);
|
||||
}
|
||||
else
|
||||
{
|
||||
square128 tmp = senderOutputMatrices[0].squares[i];
|
||||
tmp ^= baseReceiverInput;
|
||||
print_sender(senderOutputMatrices[0].squares[i], tmp);
|
||||
}
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::print_pre_expand()
|
||||
{
|
||||
cout << "pre expand" << endl;
|
||||
if (player->my_num() == 0)
|
||||
{
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
for (int j = 0; j < 2; j++)
|
||||
cout << int128(_mm_loadu_si128((__m128i*)G_sender[i][j].get_seed())) << " ";
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < 16; i++)
|
||||
{
|
||||
if (baseReceiverInput.get_bit(i))
|
||||
{
|
||||
for (int j = 0; j < 33; j++)
|
||||
cout << " ";
|
||||
}
|
||||
cout << int128(_mm_loadu_si128((__m128i*)G_receiver[i].get_seed())) << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
template void OTExtensionWithMatrix::correlate<gf2n>(int start, int slice,
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat);
|
||||
template void OTExtensionWithMatrix::correlate<gfp>(int start, int slice,
|
||||
BitVector& newReceiverInput, bool useConstantBase, int repeat);
|
||||
template void OTExtensionWithMatrix::print_post_correlate<gf2n>(
|
||||
BitVector& newReceiverInput, int j, int offset, int sender);
|
||||
template void OTExtensionWithMatrix::print_post_correlate<gfp>(
|
||||
BitVector& newReceiverInput, int j, int offset, int sender);
|
||||
template void OTExtensionWithMatrix::extend<gf2n>(int nOTs_requested,
|
||||
BitVector& newReceiverInput);
|
||||
template void OTExtensionWithMatrix::extend<gfp>(int nOTs_requested,
|
||||
BitVector& newReceiverInput);
|
||||
template void OTExtensionWithMatrix::expand<gf2n>(int start, int slice);
|
||||
template void OTExtensionWithMatrix::expand<gfp>(int start, int slice);
|
||||
template void OTExtensionWithMatrix::expand_transposed<gf2n>();
|
||||
template void OTExtensionWithMatrix::expand_transposed<gfp>();
|
||||
template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples,
|
||||
vector<gf2n>& output);
|
||||
template void OTExtensionWithMatrix::reduce_squares(unsigned int nTriples,
|
||||
vector<gfp>& output);
|
||||
71
OT/OTExtensionWithMatrix.h
Normal file
71
OT/OTExtensionWithMatrix.h
Normal file
@@ -0,0 +1,71 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OTExtensionWithMatrix.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_OTEXTENSIONWITHMATRIX_H_
|
||||
#define OT_OTEXTENSIONWITHMATRIX_H_
|
||||
|
||||
#include "OTExtension.h"
|
||||
#include "BitMatrix.h"
|
||||
#include "Math/gf2n.h"
|
||||
|
||||
class OTExtensionWithMatrix : public OTExtension
|
||||
{
|
||||
public:
|
||||
vector<BitMatrix> senderOutputMatrices;
|
||||
BitMatrix receiverOutputMatrix;
|
||||
BitMatrix t1, u;
|
||||
PRNG G;
|
||||
|
||||
OTExtensionWithMatrix(int nbaseOTs, int baseLength,
|
||||
int nloops, int nsubloops,
|
||||
TwoPartyPlayer* player,
|
||||
BitVector& baseReceiverInput,
|
||||
vector< vector<BitVector> >& baseSenderInput,
|
||||
vector<BitVector>& baseReceiverOutput,
|
||||
OT_ROLE role=BOTH,
|
||||
bool passive=false)
|
||||
: OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
|
||||
baseSenderInput, baseReceiverOutput, role, passive) {
|
||||
G.ReSeed();
|
||||
}
|
||||
|
||||
void seed(vector<BitMatrix>& baseSenderInput,
|
||||
BitMatrix& baseReceiverOutput);
|
||||
void transfer(int nOTs, const BitVector& receiverInput);
|
||||
void resize(int nOTs);
|
||||
template <class T>
|
||||
void extend(int nOTs, BitVector& newReceiverInput);
|
||||
template <class T>
|
||||
void expand(int start, int slice);
|
||||
template <class T>
|
||||
void expand_transposed();
|
||||
template <class T>
|
||||
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
|
||||
void transpose(int start, int slice);
|
||||
void setup_for_correlation(vector<BitMatrix>& baseSenderOutputs, BitMatrix& baseReceiverOutput);
|
||||
template <class T>
|
||||
void reduce_squares(unsigned int nTriples, vector<T>& output);
|
||||
|
||||
void print(BitVector& newReceiverInput, int i = 0);
|
||||
template <class T>
|
||||
void print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int i = 0, int offset = 0);
|
||||
void print_sender(square128& square0, square128& square);
|
||||
template <class T>
|
||||
void print_post_correlate(BitVector& newReceiverInput, int i = 0, int offset = 0, int sender = 0);
|
||||
void print_pre_correlate(int i = 0);
|
||||
void print_post_transpose(BitVector& newReceiverInput, int i = 0, int sender = 0);
|
||||
void print_pre_expand();
|
||||
|
||||
octet* get_receiver_output(int i);
|
||||
octet* get_sender_output(int choice, int i);
|
||||
|
||||
protected:
|
||||
template <class T>
|
||||
void hash_outputs(int nOTs);
|
||||
};
|
||||
|
||||
#endif /* OT_OTEXTENSIONWITHMATRIX_H_ */
|
||||
401
OT/OTMachine.cpp
Normal file
401
OT/OTMachine.cpp
Normal file
@@ -0,0 +1,401 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "OT/OTExtension.h"
|
||||
#include "OT/OTExtensionWithMatrix.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include <sys/time.h>
|
||||
|
||||
#include "OutputCheck.h"
|
||||
#include "OTMachine.h"
|
||||
|
||||
//#define BASE_OT_DEBUG
|
||||
|
||||
class OT_thread_info
|
||||
{
|
||||
public:
|
||||
|
||||
int thread_num;
|
||||
bool stop;
|
||||
int other_player_num;
|
||||
OTExtension* ot_ext;
|
||||
int nOTs, nbase;
|
||||
BitVector receiverInput;
|
||||
};
|
||||
|
||||
void* run_otext_thread(void* ptr)
|
||||
{
|
||||
OT_thread_info *tinfo = (OT_thread_info*) ptr;
|
||||
|
||||
//int num = tinfo->thread_num;
|
||||
//int other_player_num = tinfo->other_player_num;
|
||||
printf("\tI am in thread %d\n", tinfo->thread_num);
|
||||
tinfo->ot_ext->transfer(tinfo->nOTs, tinfo->receiverInput);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
OTMachine::OTMachine(int argc, const char** argv)
|
||||
{
|
||||
opt.add(
|
||||
"", // Default.
|
||||
1, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"This player's number, 0/1 (required).", // Help description.
|
||||
"-p", // Flag token.
|
||||
"--player" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"5000", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Base port number (default: 5000).", // Help description.
|
||||
"-pn", // Flag token.
|
||||
"--portnum" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"localhost", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Host name(s) that player 0 is running on (default: localhost). Split with commas.", // Help description.
|
||||
"-h", // Flag token.
|
||||
"--hostname" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"1024",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Number of extended OTs to run (default: 1024).",
|
||||
"-n",
|
||||
"--nOTs"
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"128", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of base OTs to run (default: 128).", // Help description.
|
||||
"-b", // Flag token.
|
||||
"--nbase" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"s",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Mode for OT. a (asymmetric) or s (symmetric, i.e. play both sender/receiver) (default: s).",
|
||||
"-m",
|
||||
"--mode"
|
||||
);
|
||||
opt.add(
|
||||
"1",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Number of threads (default: 1).",
|
||||
"-x",
|
||||
"--nthreads"
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"1",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Number of loops (default: 1).",
|
||||
"-l",
|
||||
"--nloops"
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"1",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Number of subloops (default: 1).",
|
||||
"-s",
|
||||
"--nsubloops"
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Run in passive security mode.", // Help description.
|
||||
"-pas", // Flag token.
|
||||
"--passive" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Write results to files.", // Help description.
|
||||
"-o", // Flag token.
|
||||
"--output" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Real base OT.", // Help description.
|
||||
"-r", // Flag token.
|
||||
"--real" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
|
||||
string hostname, ot_mode, usage;
|
||||
passive = false;
|
||||
opt.get("-p")->getInt(my_num);
|
||||
opt.get("-pn")->getInt(portnum_base);
|
||||
opt.get("-h")->getString(hostname);
|
||||
opt.get("-n")->getLong(nOTs);
|
||||
opt.get("-m")->getString(ot_mode);
|
||||
opt.get("--nthreads")->getInt(nthreads);
|
||||
opt.get("--nloops")->getInt(nloops);
|
||||
opt.get("--nsubloops")->getInt(nsubloops);
|
||||
opt.get("--nbase")->getInt(nbase);
|
||||
if (opt.isSet("-pas"))
|
||||
passive = true;
|
||||
|
||||
if (!opt.isSet("-p"))
|
||||
{
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
cout << "Player 0 host name = " << hostname << endl;
|
||||
cout << "Creating " << nOTs << " extended OTs in " << nthreads << " threads\n";
|
||||
cout << "Running in mode " << ot_mode << endl;
|
||||
|
||||
if (passive)
|
||||
cout << "Running with PASSIVE security only\n";
|
||||
|
||||
if (nbase < 128)
|
||||
cout << "WARNING: only using " << nbase << " seed OTs, using these for OT extensions is insecure.\n";
|
||||
|
||||
if (ot_mode.compare("s") == 0)
|
||||
ot_role = BOTH;
|
||||
else if (ot_mode.compare("a") == 0)
|
||||
{
|
||||
if (my_num == 0)
|
||||
ot_role = SENDER;
|
||||
else
|
||||
ot_role = RECEIVER;
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "Invalid OT mode argument: " << ot_mode << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Several names for multiplexing
|
||||
unsigned int pos = 0;
|
||||
while (pos < hostname.length())
|
||||
{
|
||||
string::size_type new_pos = hostname.find(',', pos);
|
||||
if (new_pos == string::npos)
|
||||
new_pos = hostname.length();
|
||||
int len = new_pos - pos;
|
||||
string name = hostname.substr(pos, len);
|
||||
pos = new_pos + 1;
|
||||
|
||||
vector<string> names(2);
|
||||
names[my_num] = "localhost";
|
||||
names[1-my_num] = name;
|
||||
N.resize(N.size() + 1);
|
||||
N[N.size()-1].init(my_num, portnum_base, names);
|
||||
}
|
||||
|
||||
P = new TwoPartyPlayer(N[0], 1 - my_num, 500);
|
||||
|
||||
timeval baseOTstart, baseOTend;
|
||||
gettimeofday(&baseOTstart, NULL);
|
||||
// swap role for base OTs
|
||||
if (opt.isSet("-r"))
|
||||
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
else
|
||||
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
cout << "real mode " << opt.isSet("-r") << endl;
|
||||
BaseOT& bot = *bot_;
|
||||
bot.exec_base();
|
||||
gettimeofday(&baseOTend, NULL);
|
||||
double basetime = timeval_diff(&baseOTstart, &baseOTend);
|
||||
cout << "\t\tBaseTime (" << role_to_str(ot_role) << "): " << basetime/1000000 << endl << flush;
|
||||
|
||||
// Receiver send something to force synchronization
|
||||
// (since Sender finishes baseOTs before Receiver)
|
||||
int a = 3;
|
||||
vector<octetStream> os(2);
|
||||
os[0].store(a);
|
||||
P->send_receive_player(os);
|
||||
os[1].get(a);
|
||||
cout << a << endl;
|
||||
|
||||
#ifdef BASE_OT_DEBUG
|
||||
// check base OTs
|
||||
bot.check();
|
||||
// check after extending with PRG a few times
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
bot.extend_length();
|
||||
bot.check();
|
||||
}
|
||||
cout << "Verifying base OTs (debugging)\n";
|
||||
#endif
|
||||
|
||||
// convert baseOT selection bits to BitVector
|
||||
// (not already BitVector due to legacy PVW code)
|
||||
baseReceiverInput.resize(nbase);
|
||||
for (int i = 0; i < nbase; i++)
|
||||
{
|
||||
baseReceiverInput.set_bit(i, bot.receiver_inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
OTMachine::~OTMachine()
|
||||
{
|
||||
delete bot_;
|
||||
delete P;
|
||||
}
|
||||
|
||||
|
||||
void OTMachine::run()
|
||||
{
|
||||
// divide nOTs between threads and loops
|
||||
nOTs = DIV_CEIL(nOTs, nthreads * nloops);
|
||||
// round up to multiple of base OTs and subloops
|
||||
// discount for discarded OTs
|
||||
nOTs = DIV_CEIL(nOTs + 2 * 128, nbase * nsubloops) * nbase * nsubloops - 2 * 128;
|
||||
cout << "Running " << nOTs << " OT extensions per thread and loop\n" << flush;
|
||||
|
||||
// PRG for generating inputs etc
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
BitVector receiverInput(nOTs);
|
||||
receiverInput.randomize(G);
|
||||
BaseOT& bot = *bot_;
|
||||
|
||||
cout << "Initialize OT Extension\n";
|
||||
vector<OT_thread_info> tinfos(nthreads);
|
||||
vector<pthread_t> threads(nthreads);
|
||||
timeval OTextstart, OTextend;
|
||||
gettimeofday(&OTextstart, NULL);
|
||||
|
||||
// copy base inputs/outputs for each thread
|
||||
vector<BitVector> base_receiver_input_copy(nthreads);
|
||||
vector<vector< vector<BitVector> > > base_sender_inputs_copy(nthreads, vector<vector<BitVector> >(nbase, vector<BitVector>(2)));
|
||||
vector< vector<BitVector> > base_receiver_outputs_copy(nthreads, vector<BitVector>(nbase));
|
||||
vector<TwoPartyPlayer*> players(nthreads);
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
tinfos[i].receiverInput.assign(receiverInput);
|
||||
|
||||
base_receiver_input_copy[i].assign(baseReceiverInput);
|
||||
for (int j = 0; j < nbase; j++)
|
||||
{
|
||||
base_sender_inputs_copy[i][j][0].assign(bot.sender_inputs[j][0]);
|
||||
base_sender_inputs_copy[i][j][1].assign(bot.sender_inputs[j][1]);
|
||||
base_receiver_outputs_copy[i][j].assign(bot.receiver_outputs[j]);
|
||||
}
|
||||
// now setup resources for each thread
|
||||
// round robin with the names
|
||||
players[i] = new TwoPartyPlayer(N[i%N.size()], 1 - my_num, (i+1) * 1000);
|
||||
tinfos[i].thread_num = i+1;
|
||||
tinfos[i].other_player_num = 1 - my_num;
|
||||
tinfos[i].nOTs = nOTs;
|
||||
tinfos[i].ot_ext = new OTExtensionWithMatrix(nbase, bot.length(),
|
||||
nloops, nsubloops,
|
||||
players[i],
|
||||
base_receiver_input_copy[i],
|
||||
base_sender_inputs_copy[i],
|
||||
base_receiver_outputs_copy[i],
|
||||
ot_role,
|
||||
passive);
|
||||
|
||||
// create the thread
|
||||
pthread_create(&threads[i], NULL, run_otext_thread, &tinfos[i]);
|
||||
|
||||
// extend base OTs with PRG for the next thread
|
||||
bot.extend_length();
|
||||
}
|
||||
// wait for threads to finish
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
pthread_join(threads[i],NULL);
|
||||
cout << "thread " << i+1 << " finished\n" << flush;
|
||||
}
|
||||
|
||||
map<string,long long>& times = tinfos[0].ot_ext->times;
|
||||
for (map<string,long long>::iterator it = times.begin(); it != times.end(); it++)
|
||||
{
|
||||
long long sum = 0;
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
sum += tinfos[i].ot_ext->times[it->first];
|
||||
|
||||
cout << it->first << " on average took time "
|
||||
<< double(sum) / nthreads / 1e6 << endl;
|
||||
}
|
||||
|
||||
gettimeofday(&OTextend, NULL);
|
||||
double totaltime = timeval_diff(&OTextstart, &OTextend);
|
||||
cout << "Time for OTExt threads (" << role_to_str(ot_role) << "): " << totaltime/1000000 << endl << flush;
|
||||
|
||||
if (opt.isSet("-o"))
|
||||
{
|
||||
BitVector receiver_output, sender_output;
|
||||
char filename[1024];
|
||||
sprintf(filename, RECEIVER_INPUT, my_num);
|
||||
ofstream outf(filename);
|
||||
receiverInput.output(outf, false);
|
||||
outf.close();
|
||||
sprintf(filename, RECEIVER_OUTPUT, my_num);
|
||||
outf.open(filename);
|
||||
for (unsigned int i = 0; i < nOTs; i++)
|
||||
{
|
||||
receiver_output.assign_bytes((char*) tinfos[0].ot_ext->get_receiver_output(i), sizeof(__m128i));
|
||||
receiver_output.output(outf, false);
|
||||
}
|
||||
outf.close();
|
||||
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
sprintf(filename, SENDER_OUTPUT, my_num, i);
|
||||
outf.open(filename);
|
||||
for (int j = 0; j < nOTs; j++)
|
||||
{
|
||||
sender_output.assign_bytes((char*) tinfos[0].ot_ext->get_sender_output(i, j), sizeof(__m128i));
|
||||
sender_output.output(outf, false);
|
||||
}
|
||||
outf.close();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
delete players[i];
|
||||
delete tinfos[i].ot_ext;
|
||||
}
|
||||
}
|
||||
33
OT/OTMachine.h
Normal file
33
OT/OTMachine.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OTMachine.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_OTMACHINE_H_
|
||||
#define OT_OTMACHINE_H_
|
||||
|
||||
#include "OT/OTExtension.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
|
||||
class OTMachine
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
OT_ROLE ot_role;
|
||||
|
||||
public:
|
||||
int my_num, portnum_base, nthreads, nloops, nsubloops, nbase;
|
||||
long nOTs;
|
||||
bool passive;
|
||||
TwoPartyPlayer* P;
|
||||
BitVector baseReceiverInput;
|
||||
BaseOT* bot_;
|
||||
vector<Names> N;
|
||||
|
||||
OTMachine(int argc, const char** argv);
|
||||
~OTMachine();
|
||||
void run();
|
||||
};
|
||||
|
||||
#endif /* OT_OTMACHINE_H_ */
|
||||
164
OT/OTMultiplier.cpp
Normal file
164
OT/OTMultiplier.cpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OTMultiplier.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "OT/OTMultiplier.h"
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
template<class T>
|
||||
OTMultiplier<T>::OTMultiplier(NPartyTripleGenerator& generator,
|
||||
int thread_num) :
|
||||
generator(generator), thread_num(thread_num),
|
||||
rot_ext(128, 128, 0, 1,
|
||||
generator.players[thread_num], generator.baseReceiverInput,
|
||||
generator.baseSenderInputs[thread_num],
|
||||
generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check)
|
||||
{
|
||||
c_output.resize(generator.nTriplesPerLoop);
|
||||
pthread_mutex_init(&mutex, 0);
|
||||
pthread_cond_init(&ready, 0);
|
||||
thread = 0;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OTMultiplier<T>::~OTMultiplier()
|
||||
{
|
||||
pthread_mutex_destroy(&mutex);
|
||||
pthread_cond_destroy(&ready);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTMultiplier<T>::multiply()
|
||||
{
|
||||
BitVector keyBits(generator.field_size);
|
||||
keyBits.set_int128(0, generator.machine.get_mac_key<T>().to_m128i());
|
||||
rot_ext.extend<T>(generator.field_size, keyBits);
|
||||
vector< vector<BitVector> > senderOutput(128);
|
||||
vector<BitVector> receiverOutput;
|
||||
for (int j = 0; j < 128; j++)
|
||||
{
|
||||
senderOutput[j].resize(2);
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
senderOutput[j][i].resize(128);
|
||||
senderOutput[j][i].set_int128(0, rot_ext.senderOutputMatrices[i].squares[0].rows[j]);
|
||||
}
|
||||
}
|
||||
rot_ext.receiverOutputMatrix.to(receiverOutput);
|
||||
OTExtensionWithMatrix auth_ot_ext(128, 128, 0, 1,
|
||||
generator.players[thread_num], keyBits, senderOutput,
|
||||
receiverOutput, BOTH, true);
|
||||
|
||||
if (generator.machine.generateBits)
|
||||
multiplyForBits(auth_ot_ext);
|
||||
else
|
||||
multiplyForTriples(auth_ot_ext);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTMultiplier<T>::multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext)
|
||||
{
|
||||
auth_ot_ext.resize(generator.nPreampTriplesPerLoop * generator.field_size);
|
||||
|
||||
// dummy input for OT correlator
|
||||
vector<BitVector> _;
|
||||
vector< vector<BitVector> > __;
|
||||
BitVector ___;
|
||||
|
||||
OTExtensionWithMatrix otCorrelator(0, 0, 0, 0, generator.players[thread_num],
|
||||
___, __, _, BOTH, true);
|
||||
otCorrelator.resize(128 * generator.nPreampTriplesPerLoop);
|
||||
|
||||
rot_ext.resize(generator.field_size * generator.nPreampTriplesPerLoop + 2 * 128);
|
||||
|
||||
pthread_mutex_lock(&mutex);
|
||||
pthread_cond_signal(&ready);
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
|
||||
for (int i = 0; i < generator.nloops; i++)
|
||||
{
|
||||
BitVector aBits = generator.valueBits[0];
|
||||
//timers["Extension"].start();
|
||||
rot_ext.extend<T>(generator.field_size * generator.nPreampTriplesPerLoop, aBits);
|
||||
//timers["Extension"].stop();
|
||||
|
||||
//timers["Correlation"].start();
|
||||
otCorrelator.baseReceiverInput = aBits;
|
||||
otCorrelator.setup_for_correlation(rot_ext.senderOutputMatrices, rot_ext.receiverOutputMatrix);
|
||||
otCorrelator.correlate<T>(0, generator.nPreampTriplesPerLoop, generator.valueBits[1], false, generator.nAmplify);
|
||||
//timers["Correlation"].stop();
|
||||
|
||||
//timers["Triple computation"].start();
|
||||
|
||||
otCorrelator.reduce_squares(generator.nPreampTriplesPerLoop, c_output);
|
||||
|
||||
pthread_cond_signal(&ready);
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
|
||||
if (generator.machine.generateMACs)
|
||||
{
|
||||
macs.resize(3);
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
int nValues = generator.nTriplesPerLoop;
|
||||
if (generator.machine.check && (j % 2 == 0))
|
||||
nValues *= 2;
|
||||
auth_ot_ext.expand<T>(0, nValues);
|
||||
auth_ot_ext.correlate<T>(0, nValues, generator.valueBits[j], true);
|
||||
auth_ot_ext.reduce_squares(nValues, macs[j]);
|
||||
}
|
||||
|
||||
pthread_cond_signal(&ready);
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
}
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&mutex);
|
||||
}
|
||||
|
||||
template<>
|
||||
void OTMultiplier<gfp>::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext)
|
||||
{
|
||||
multiplyForTriples(auth_ot_ext);
|
||||
}
|
||||
|
||||
template<>
|
||||
void OTMultiplier<gf2n>::multiplyForBits(OTExtensionWithMatrix& auth_ot_ext)
|
||||
{
|
||||
int nBits = generator.nTriplesPerLoop + generator.field_size;
|
||||
int nBlocks = ceil(1.0 * nBits / generator.field_size);
|
||||
auth_ot_ext.resize(nBlocks * generator.field_size);
|
||||
macs.resize(1);
|
||||
macs[0].resize(nBits);
|
||||
|
||||
pthread_mutex_lock(&mutex);
|
||||
pthread_cond_signal(&ready);
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
|
||||
for (int i = 0; i < generator.nloops; i++)
|
||||
{
|
||||
auth_ot_ext.expand<gf2n>(0, nBlocks);
|
||||
auth_ot_ext.correlate<gf2n>(0, nBlocks, generator.valueBits[0], true);
|
||||
auth_ot_ext.transpose(0, nBlocks);
|
||||
|
||||
for (int j = 0; j < nBits; j++)
|
||||
{
|
||||
int128 r = auth_ot_ext.receiverOutputMatrix.squares[j/128].rows[j%128];
|
||||
int128 s = auth_ot_ext.senderOutputMatrices[0].squares[j/128].rows[j%128];
|
||||
macs[0][j] = r ^ s;
|
||||
}
|
||||
|
||||
pthread_cond_signal(&ready);
|
||||
pthread_cond_wait(&ready, &mutex);
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&mutex);
|
||||
}
|
||||
|
||||
template class OTMultiplier<gf2n>;
|
||||
template class OTMultiplier<gfp>;
|
||||
41
OT/OTMultiplier.h
Normal file
41
OT/OTMultiplier.h
Normal file
@@ -0,0 +1,41 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OTMultiplier.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_OTMULTIPLIER_H_
|
||||
#define OT_OTMULTIPLIER_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "OT/OTExtensionWithMatrix.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
class NPartyTripleGenerator;
|
||||
|
||||
template <class T>
|
||||
class OTMultiplier
|
||||
{
|
||||
void multiplyForTriples(OTExtensionWithMatrix& auth_ot_ext);
|
||||
void multiplyForBits(OTExtensionWithMatrix& auth_ot_ext);
|
||||
public:
|
||||
NPartyTripleGenerator& generator;
|
||||
int thread_num;
|
||||
OTExtensionWithMatrix rot_ext;
|
||||
//OTExtensionWithMatrix* auth_ot_ext;
|
||||
vector<T> c_output;
|
||||
vector< vector<T> > macs;
|
||||
|
||||
pthread_t thread;
|
||||
pthread_mutex_t mutex;
|
||||
pthread_cond_t ready;
|
||||
|
||||
OTMultiplier(NPartyTripleGenerator& generator, int thread_num);
|
||||
~OTMultiplier();
|
||||
void multiply();
|
||||
};
|
||||
|
||||
#endif /* OT_OTMULTIPLIER_H_ */
|
||||
44
OT/OTTripleSetup.cpp
Normal file
44
OT/OTTripleSetup.cpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "OTTripleSetup.h"
|
||||
|
||||
void OTTripleSetup::setup()
|
||||
{
|
||||
timeval baseOTstart, baseOTend;
|
||||
gettimeofday(&baseOTstart, NULL);
|
||||
|
||||
G.ReSeed();
|
||||
for (int i = 0; i < nbase; i++)
|
||||
{
|
||||
base_receiver_inputs[i] = G.get_uchar() & 1;
|
||||
}
|
||||
//baseReceiverInput.randomize(G);
|
||||
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
{
|
||||
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
|
||||
baseOTs[i]->exec_base(false);
|
||||
}
|
||||
gettimeofday(&baseOTend, NULL);
|
||||
double basetime = timeval_diff(&baseOTstart, &baseOTend);
|
||||
cout << "\t\tBaseTime: " << basetime/1000000 << endl << flush;
|
||||
|
||||
// Receiver send something to force synchronization
|
||||
// (since Sender finishes baseOTs before Receiver)
|
||||
}
|
||||
|
||||
void OTTripleSetup::close_connections()
|
||||
{
|
||||
for (size_t i = 0; i < players.size(); i++)
|
||||
{
|
||||
delete players[i];
|
||||
}
|
||||
}
|
||||
|
||||
OTTripleSetup::~OTTripleSetup()
|
||||
{
|
||||
for (size_t i = 0; i < baseOTs.size(); i++)
|
||||
{
|
||||
delete baseOTs[i];
|
||||
}
|
||||
}
|
||||
91
OT/OTTripleSetup.h
Normal file
91
OT/OTTripleSetup.h
Normal file
@@ -0,0 +1,91 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef OT_TRIPLESETUP_H_
|
||||
#define OT_TRIPLESETUP_H_
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "OT/BaseOT.h"
|
||||
#include "OT/OTMachine.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
/*
|
||||
* Class for creating and storing base OTs between every pair of parties.
|
||||
*/
|
||||
class OTTripleSetup
|
||||
{
|
||||
vector<int> base_receiver_inputs;
|
||||
vector< vector< vector<BitVector> > > baseSenderInputs;
|
||||
vector< vector<BitVector> > baseReceiverOutputs;
|
||||
|
||||
PRNG G;
|
||||
int nparties;
|
||||
int my_num;
|
||||
int nbase;
|
||||
bool real_OTs;
|
||||
|
||||
public:
|
||||
map<string,Timer> timers;
|
||||
vector<BaseOT*> baseOTs;
|
||||
vector<TwoPartyPlayer*> players;
|
||||
|
||||
int get_nparties() { return nparties; }
|
||||
int get_nbase() { return nbase; }
|
||||
int get_my_num() { return my_num; }
|
||||
int get_base_receiver_input(int i) { return base_receiver_inputs[i]; }
|
||||
|
||||
OTTripleSetup(Names& N, bool real_OTs)
|
||||
: nparties(N.num_players()), my_num(N.my_num()), nbase(128), real_OTs(real_OTs)
|
||||
{
|
||||
base_receiver_inputs.resize(nbase);
|
||||
players.resize(nparties - 1);
|
||||
baseOTs.resize(nparties - 1);
|
||||
baseSenderInputs.resize(nparties - 1);
|
||||
baseReceiverOutputs.resize(nparties - 1);
|
||||
|
||||
if (real_OTs)
|
||||
cout << "Doing real base OTs\n";
|
||||
else
|
||||
cout << "Doing fake base OTs\n";
|
||||
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
{
|
||||
int other_player, id;
|
||||
// i for indexing, other_player is actual number
|
||||
if (i >= my_num)
|
||||
other_player = i + 1;
|
||||
else
|
||||
other_player = i;
|
||||
// unique id per pair of parties (to assign port no.)
|
||||
if (my_num < other_player)
|
||||
id = my_num*nparties + other_player;
|
||||
else
|
||||
id = other_player*nparties + my_num;
|
||||
|
||||
players[i] = new TwoPartyPlayer(N, other_player, id);
|
||||
|
||||
// sets up a pair of base OTs, playing both roles
|
||||
if (real_OTs)
|
||||
{
|
||||
baseOTs[i] = new BaseOT(nbase, 128, players[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
baseOTs[i] = new FakeOT(nbase, 128, players[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
~OTTripleSetup();
|
||||
|
||||
// run the Base OTs
|
||||
void setup();
|
||||
// close down the sockets
|
||||
void close_connections();
|
||||
|
||||
//template <class T>
|
||||
//T get_mac_key();
|
||||
};
|
||||
|
||||
|
||||
#endif
|
||||
13
OT/OText_main.cpp
Normal file
13
OT/OText_main.cpp
Normal file
@@ -0,0 +1,13 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* OText_main.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "OTMachine.h"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
OTMachine(argc, argv).run();
|
||||
}
|
||||
15
OT/OutputCheck.h
Normal file
15
OT/OutputCheck.h
Normal file
@@ -0,0 +1,15 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* check.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_OUTPUTCHECK_H_
|
||||
#define OT_OUTPUTCHECK_H_
|
||||
|
||||
#define RECEIVER_INPUT "Player-Data/OT-receiver%d-input"
|
||||
#define RECEIVER_OUTPUT "Player-Data/OT-receiver%d-output"
|
||||
#define SENDER_OUTPUT "Player-Data/OT-sender%d-output%d"
|
||||
|
||||
#endif /* OT_OUTPUTCHECK_H_ */
|
||||
107
OT/Tools.cpp
Normal file
107
OT/Tools.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Tools.h"
|
||||
#include "Math/gf2nlong.h"
|
||||
|
||||
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
vector<octetStream> seed_strm(2);
|
||||
vector<octetStream> Comm_seed(2);
|
||||
vector<octetStream> Open_seed(2);
|
||||
G.get_octetStream(seed_strm[0], len);
|
||||
|
||||
Commit(Comm_seed[0], Open_seed[0], seed_strm[0], player.my_num());
|
||||
player.send_receive_player(Comm_seed);
|
||||
player.send_receive_player(Open_seed);
|
||||
|
||||
memset(seed, 0, len*sizeof(octet));
|
||||
|
||||
if (!Open(seed_strm[1], Comm_seed[1], Open_seed[1], player.other_player_num()))
|
||||
{
|
||||
throw invalid_commitment();
|
||||
}
|
||||
for (int i = 0; i < len; i++)
|
||||
{
|
||||
seed[i] = seed_strm[0].get_data()[i] ^ seed_strm[1].get_data()[i];
|
||||
}
|
||||
}
|
||||
|
||||
void shiftl128(word x1, word x2, word& res1, word& res2, size_t k)
|
||||
{
|
||||
if (k > 128)
|
||||
throw invalid_length();
|
||||
if (k >= 64) // shifting a 64-bit integer by more than 63 bits is "undefined"
|
||||
{
|
||||
x1 = x2;
|
||||
x2 = 0;
|
||||
shiftl128(x1, x2, res1, res2, k - 64);
|
||||
}
|
||||
else
|
||||
{
|
||||
res1 = (x1 << k) | (x2 >> (64-k));
|
||||
res2 = (x2 << k);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce modulo x^128 + x^7 + x^2 + x + 1
|
||||
// NB this is incorrect as it bit-reflects the result as required for
|
||||
// GCM mode
|
||||
void gfred128(__m128i tmp3, __m128i tmp6, __m128i *res)
|
||||
{
|
||||
__m128i tmp2, tmp4, tmp5, tmp7, tmp8, tmp9;
|
||||
tmp7 = _mm_srli_epi32(tmp3, 31);
|
||||
tmp8 = _mm_srli_epi32(tmp6, 31);
|
||||
|
||||
tmp3 = _mm_slli_epi32(tmp3, 1);
|
||||
tmp6 = _mm_slli_epi32(tmp6, 1);
|
||||
|
||||
tmp9 = _mm_srli_si128(tmp7, 12);
|
||||
tmp8 = _mm_slli_si128(tmp8, 4);
|
||||
tmp7 = _mm_slli_si128(tmp7, 4);
|
||||
tmp3 = _mm_or_si128(tmp3, tmp7);
|
||||
tmp6 = _mm_or_si128(tmp6, tmp8);
|
||||
tmp6 = _mm_or_si128(tmp6, tmp9);
|
||||
|
||||
tmp7 = _mm_slli_epi32(tmp3, 31);
|
||||
tmp8 = _mm_slli_epi32(tmp3, 30);
|
||||
tmp9 = _mm_slli_epi32(tmp3, 25);
|
||||
|
||||
tmp7 = _mm_xor_si128(tmp7, tmp8);
|
||||
tmp7 = _mm_xor_si128(tmp7, tmp9);
|
||||
tmp8 = _mm_srli_si128(tmp7, 4);
|
||||
tmp7 = _mm_slli_si128(tmp7, 12);
|
||||
tmp3 = _mm_xor_si128(tmp3, tmp7);
|
||||
|
||||
tmp2 = _mm_srli_epi32(tmp3, 1);
|
||||
tmp4 = _mm_srli_epi32(tmp3, 2);
|
||||
tmp5 = _mm_srli_epi32(tmp3, 7);
|
||||
tmp2 = _mm_xor_si128(tmp2, tmp4);
|
||||
tmp2 = _mm_xor_si128(tmp2, tmp5);
|
||||
tmp2 = _mm_xor_si128(tmp2, tmp8);
|
||||
tmp3 = _mm_xor_si128(tmp3, tmp2);
|
||||
|
||||
tmp6 = _mm_xor_si128(tmp6, tmp3);
|
||||
*res = tmp6;
|
||||
}
|
||||
|
||||
// Based on Intel's code for GF(2^128) mul, with reduction
|
||||
void gfmul128 (__m128i a, __m128i b, __m128i *res)
|
||||
{
|
||||
__m128i tmp3, tmp6;
|
||||
mul128(a, b, &tmp3, &tmp6);
|
||||
// Now do the reduction
|
||||
gfred128(tmp3, tmp6, res);
|
||||
}
|
||||
|
||||
string word_to_bytes(const word w)
|
||||
{
|
||||
stringstream ss;
|
||||
octet* bytes = (octet*) &w;
|
||||
ss << hex;
|
||||
for (unsigned int i = 0; i < sizeof(word); i++)
|
||||
ss << (int)bytes[i] << " ";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
51
OT/Tools.h
Normal file
51
OT/Tools.h
Normal file
@@ -0,0 +1,51 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _OTTOOLS
|
||||
#define _OTTOOLS
|
||||
|
||||
#include "Networking/Player.h"
|
||||
#include "Tools/Commit.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
#define SEED_SIZE_BYTES SEED_SIZE
|
||||
|
||||
/*
|
||||
* Generate a secure, random seed between 2 parties via commitment
|
||||
*/
|
||||
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len);
|
||||
|
||||
/*
|
||||
* GF(2^128) multiplication using Intel instructions
|
||||
* (should this go in gf2n class???)
|
||||
*/
|
||||
void gfmul128(__m128i a, __m128i b, __m128i *res);
|
||||
void gfred128(__m128i a1, __m128i a2, __m128i *res);
|
||||
|
||||
//#if defined(__SSE2__)
|
||||
/*
|
||||
* Convert __m128i to string of type T
|
||||
*/
|
||||
template <typename T>
|
||||
string __m128i_toString(const __m128i var) {
|
||||
stringstream sstr;
|
||||
sstr << hex;
|
||||
const T* values = (const T*) &var;
|
||||
if (sizeof(T) == 1) {
|
||||
for (unsigned int i = 0; i < sizeof(__m128i); i++) {
|
||||
sstr << (int) values[i] << " ";
|
||||
}
|
||||
} else {
|
||||
for (unsigned int i = 0; i < sizeof(__m128i) / sizeof(T); i++) {
|
||||
sstr << values[i] << " ";
|
||||
}
|
||||
}
|
||||
return sstr.str();
|
||||
}
|
||||
//#endif
|
||||
|
||||
string word_to_bytes(const word w);
|
||||
|
||||
void shiftl128(word x1, word x2, word& res1, word& res2, size_t k);
|
||||
|
||||
|
||||
#endif
|
||||
270
OT/TripleMachine.cpp
Normal file
270
OT/TripleMachine.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* TripleMachine.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include <OT/TripleMachine.h>
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/OTMachine.h"
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
template <class T>
|
||||
void* run_ngenerator_thread(void* ptr)
|
||||
{
|
||||
((NPartyTripleGenerator*)ptr)->generate<T>();
|
||||
return 0;
|
||||
}
|
||||
|
||||
TripleMachine::TripleMachine(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
opt.add(
|
||||
"2", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of parties (default: 2).", // Help description.
|
||||
"-N", // Flag token.
|
||||
"--nparties" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
1, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"This player's number, 0/1 (required).", // Help description.
|
||||
"-p", // Flag token.
|
||||
"--player" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"1",
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
"Number of threads (default: 1).",
|
||||
"-x",
|
||||
"--nthreads"
|
||||
);
|
||||
opt.add(
|
||||
"1000", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of triples (default: 1000).", // Help description.
|
||||
"-n", // Flag token.
|
||||
"--ntriples" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"1", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of loops (default: 1).", // Help description.
|
||||
"-l", // Flag token.
|
||||
"--nloops" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Generate MACs (implies -a).", // Help description.
|
||||
"-m", // Flag token.
|
||||
"--macs" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Amplify triples.", // Help description.
|
||||
"-a", // Flag token.
|
||||
"--amplify" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Check triples (implies -m).", // Help description.
|
||||
"-c", // Flag token.
|
||||
"--check" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"GF(p) triples", // Help description.
|
||||
"-P", // Flag token.
|
||||
"--prime-field" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Channel bonding", // Help description.
|
||||
"-b", // Flag token.
|
||||
"--bonding" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Generate bits", // Help description.
|
||||
"-B", // Flag token.
|
||||
"--bits" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Write results to files.", // Help description.
|
||||
"-o", // Flag token.
|
||||
"--output" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
opt.get("-p")->getInt(my_num);
|
||||
opt.get("-N")->getInt(nplayers);
|
||||
opt.get("-x")->getInt(nthreads);
|
||||
opt.get("-n")->getInt(ntriples);
|
||||
opt.get("-l")->getInt(nloops);
|
||||
generateBits = opt.get("-B")->isSet;
|
||||
check = opt.get("-c")->isSet || generateBits;
|
||||
generateMACs = opt.get("-m")->isSet || check;
|
||||
amplify = opt.get("-a")->isSet || generateMACs;
|
||||
primeField = opt.get("-P")->isSet;
|
||||
bonding = opt.get("-b")->isSet;
|
||||
output = opt.get("-o")->isSet;
|
||||
|
||||
if (!opt.isSet("-p"))
|
||||
{
|
||||
string usage;
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
nTriplesPerThread = DIV_CEIL(ntriples, nthreads);
|
||||
|
||||
prep_data_dir = get_prep_dir(nplayers, 128, 128);
|
||||
ofstream outf;
|
||||
bigint p;
|
||||
generate_online_setup(outf, prep_data_dir, p, 128, 128);
|
||||
// doesn't work with Montgomery multiplication
|
||||
gfp::init_field(p, false);
|
||||
gf2n::init_field(128);
|
||||
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
mac_key2.randomize(G);
|
||||
mac_keyp.randomize(G);
|
||||
}
|
||||
|
||||
void TripleMachine::run()
|
||||
{
|
||||
cout << "my_num: " << my_num << endl;
|
||||
Names N[2];
|
||||
N[0].init(my_num, nplayers, 10000, "HOSTS");
|
||||
int nConnections = 1;
|
||||
if (bonding)
|
||||
{
|
||||
N[1].init(my_num, nplayers, 11000, "HOSTS2");
|
||||
nConnections = 2;
|
||||
}
|
||||
// do the base OTs
|
||||
OTTripleSetup setup(N[0], false);
|
||||
setup.setup();
|
||||
setup.close_connections();
|
||||
|
||||
vector<NPartyTripleGenerator*> generators(nthreads);
|
||||
vector<pthread_t> threads(nthreads);
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
generators[i] = new NPartyTripleGenerator(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this);
|
||||
}
|
||||
ntriples = generators[0]->nTriples * nthreads;
|
||||
cout <<"Setup generators\n";
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
// lock before starting thread to avoid race condition
|
||||
generators[i]->lock();
|
||||
if (primeField)
|
||||
pthread_create(&threads[i], 0, run_ngenerator_thread<gfp>, generators[i]);
|
||||
else
|
||||
pthread_create(&threads[i], 0, run_ngenerator_thread<gf2n>, generators[i]);
|
||||
}
|
||||
|
||||
// wait for initialization, then start clock and computation
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
generators[i]->wait();
|
||||
cout << "Starting computation" << endl;
|
||||
gettimeofday(&start, 0);
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
generators[i]->signal();
|
||||
generators[i]->unlock();
|
||||
}
|
||||
|
||||
// wait for threads to finish
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
pthread_join(threads[i],NULL);
|
||||
cout << "thread " << i+1 << " finished\n" << flush;
|
||||
}
|
||||
|
||||
map<string,Timer>& timers = generators[0]->timers;
|
||||
for (map<string,Timer>::iterator it = timers.begin(); it != timers.end(); it++)
|
||||
{
|
||||
double sum = 0;
|
||||
for (size_t i = 0; i < generators.size(); i++)
|
||||
sum += generators[i]->timers[it->first].elapsed();
|
||||
cout << it->first << " on average took time "
|
||||
<< sum / generators.size() << endl;
|
||||
}
|
||||
|
||||
gettimeofday(&stop, 0);
|
||||
double time = timeval_diff_in_seconds(&start, &stop);
|
||||
cout << "Time: " << time << endl;
|
||||
cout << "Throughput: " << ntriples / time << endl;
|
||||
|
||||
for (size_t i = 0; i < generators.size(); i++)
|
||||
delete generators[i];
|
||||
|
||||
output_mac_keys();
|
||||
}
|
||||
|
||||
void TripleMachine::output_mac_keys()
|
||||
{
|
||||
stringstream ss;
|
||||
ss << prep_data_dir << "Player-MAC-Keys-P" << my_num;
|
||||
cout << "Writing MAC key to " << ss.str() << endl;
|
||||
ofstream outputFile(ss.str().c_str());
|
||||
outputFile << nplayers << endl;
|
||||
outputFile << mac_keyp << " " << mac_key2 << endl;
|
||||
}
|
||||
|
||||
template<> gf2n TripleMachine::get_mac_key()
|
||||
{
|
||||
return mac_key2;
|
||||
}
|
||||
|
||||
template<> gfp TripleMachine::get_mac_key()
|
||||
{
|
||||
return mac_keyp;
|
||||
}
|
||||
40
OT/TripleMachine.h
Normal file
40
OT/TripleMachine.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* TripleMachine.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef OT_TRIPLEMACHINE_H_
|
||||
#define OT_TRIPLEMACHINE_H_
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
class TripleMachine
|
||||
{
|
||||
gf2n mac_key2;
|
||||
gfp mac_keyp;
|
||||
|
||||
public:
|
||||
int my_num, nplayers, nthreads, ntriples, nloops;
|
||||
int nTriplesPerThread;
|
||||
string prep_data_dir;
|
||||
bool generateMACs;
|
||||
bool amplify;
|
||||
bool check;
|
||||
bool primeField;
|
||||
bool bonding;
|
||||
bool generateBits;
|
||||
bool output;
|
||||
struct timeval start, stop;
|
||||
|
||||
TripleMachine(int argc, const char** argv);
|
||||
void run();
|
||||
|
||||
template <class T>
|
||||
T get_mac_key();
|
||||
void output_mac_keys();
|
||||
};
|
||||
|
||||
#endif /* OT_TRIPLEMACHINE_H_ */
|
||||
179
Player-Online.cpp
Normal file
179
Player-Online.cpp
Normal file
@@ -0,0 +1,179 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Processor/Machine.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
|
||||
opt.syntax = "./Player-Online.x [OPTIONS] <playernum> <progname>\n";
|
||||
opt.example = "./Player-Online.x -lgp 64 -lg2 128 -m new 0 sample-prog\n./Player-Online.x -pn 13000 -h localhost 1 sample-prog\n";
|
||||
|
||||
opt.add(
|
||||
"128", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(p) field (default: 128)", // Help description.
|
||||
"-lgp", // Flag token.
|
||||
"--lgp" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"40", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(2^n) field (default: 40)", // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"5000", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Port number base to attempt to start connections from (default: 5000)", // Help description.
|
||||
"-pn", // Flag token.
|
||||
"--portnumbase" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"localhost", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Host where Server.x is running (default: localhost)", // Help description.
|
||||
"-h", // Flag token.
|
||||
"--hostname" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"empty", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Where to obtain memory, new|old|empty (default: empty)\n\t"
|
||||
"new: copy from Player-Memory-P<i> file\n\t"
|
||||
"old: reuse previous memory in Memory-P<i>\n\t"
|
||||
"empty: create new empty memory", // Help description.
|
||||
"-m", // Flag token.
|
||||
"--memory" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Direct communication instead of star-shaped", // Help description.
|
||||
"-d", // Flag token.
|
||||
"--direct" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Star-shaped communication handled by background threads", // Help description.
|
||||
"-P", // Flag token.
|
||||
"--parallel" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"0", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Sum at most n shares at once when using indirect communication", // Help description.
|
||||
"-s", // Flag token.
|
||||
"--opening-sum" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Use player-specific threads for communication", // Help description.
|
||||
"-t", // Flag token.
|
||||
"--threads" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"0", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Maximum number of parties to send to at once", // Help description.
|
||||
"-b", // Flag token.
|
||||
"--max-broadcast" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
|
||||
vector<string*> allArgs(opt.firstArgs);
|
||||
allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end());
|
||||
string progname;
|
||||
int playerno;
|
||||
string usage;
|
||||
vector<string> badOptions;
|
||||
unsigned int i;
|
||||
|
||||
if (allArgs.size() != 3)
|
||||
{
|
||||
cerr << "ERROR: incorrect number of arguments to Player-Online.x\n";
|
||||
cerr << "Arguments given were:\n";
|
||||
for (unsigned int j = 1; j < allArgs.size(); j++)
|
||||
cout << "'" << *allArgs[j] << "'" << endl;
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
playerno = atoi(allArgs[1]->c_str());
|
||||
progname = *allArgs[2];
|
||||
|
||||
}
|
||||
|
||||
if(!opt.gotRequired(badOptions))
|
||||
{
|
||||
for (i=0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Missing required option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if(!opt.gotExpected(badOptions))
|
||||
{
|
||||
for(i=0; i < badOptions.size(); ++i)
|
||||
cerr << "ERROR: Got unexpected number of arguments for option " << badOptions[i] << ".";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
string memtype, hostname;
|
||||
int lg2, lgp, pnbase, opening_sum, max_broadcast;
|
||||
|
||||
opt.get("--portnumbase")->getInt(pnbase);
|
||||
opt.get("--lgp")->getInt(lgp);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
opt.get("--memory")->getString(memtype);
|
||||
opt.get("--hostname")->getString(hostname);
|
||||
opt.get("--opening-sum")->getInt(opening_sum);
|
||||
opt.get("--max-broadcast")->getInt(max_broadcast);
|
||||
|
||||
|
||||
Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2,
|
||||
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
|
||||
opt.get("--threads")->isSet, max_broadcast).run();
|
||||
|
||||
cerr << "Command line:";
|
||||
for (int i = 0; i < argc; i++)
|
||||
cerr << " " << argv[i];
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
|
||||
145
Processor/Buffer.cpp
Normal file
145
Processor/Buffer.cpp
Normal file
@@ -0,0 +1,145 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Buffer.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Buffer.h"
|
||||
#include "Processor/InputTuple.h"
|
||||
|
||||
bool BufferBase::rewind = false;
|
||||
|
||||
|
||||
void BufferBase::setup(ifstream* f, int length, const char* type)
|
||||
{
|
||||
file = f;
|
||||
tuple_length = length;
|
||||
data_type = type;
|
||||
}
|
||||
|
||||
void BufferBase::seekg(int pos)
|
||||
{
|
||||
file->seekg(pos * tuple_length);
|
||||
if (file->eof() || file->fail())
|
||||
{
|
||||
file->clear();
|
||||
file->seekg(0);
|
||||
if (!rewind)
|
||||
cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl;
|
||||
rewind = true;
|
||||
}
|
||||
next = BUFFER_SIZE;
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
Buffer<T, U>::~Buffer()
|
||||
{
|
||||
if (timer.elapsed() && data_type)
|
||||
cerr << T::type_string() << " " << data_type << " reading: "
|
||||
<< timer.elapsed() << endl;
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
void Buffer<T, U>::fill_buffer()
|
||||
{
|
||||
if (T::size() == sizeof(T))
|
||||
{
|
||||
// read directly
|
||||
read((char*)buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
char read_buffer[sizeof(buffer)];
|
||||
read(read_buffer);
|
||||
//memset(buffer, 0, sizeof(buffer));
|
||||
for (int i = 0; i < BUFFER_SIZE; i++)
|
||||
buffer[i].assign(&read_buffer[i*T::size()]);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
void Buffer<T, U>::read(char* read_buffer)
|
||||
{
|
||||
int size_in_bytes = T::size() * BUFFER_SIZE;
|
||||
int n_read = 0;
|
||||
timer.start();
|
||||
do
|
||||
{
|
||||
file->read(read_buffer + n_read, size_in_bytes - n_read);
|
||||
n_read += file->gcount();
|
||||
if (file->eof())
|
||||
{
|
||||
file->clear(); // unset EOF flag
|
||||
file->seekg(0);
|
||||
if (!rewind)
|
||||
cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl;
|
||||
rewind = true;
|
||||
eof = true;
|
||||
}
|
||||
if (file->fail())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "IO problem when buffering " << T::type_string();
|
||||
if (data_type)
|
||||
ss << " " << data_type;
|
||||
throw file_error(ss.str());
|
||||
}
|
||||
}
|
||||
while (n_read < size_in_bytes);
|
||||
timer.stop();
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
void Buffer<T,U>::input(U& a)
|
||||
{
|
||||
if (next == BUFFER_SIZE)
|
||||
{
|
||||
fill_buffer();
|
||||
next = 0;
|
||||
}
|
||||
|
||||
a = buffer[next];
|
||||
next++;
|
||||
}
|
||||
|
||||
template < template<class T> class U, template<class T> class V >
|
||||
BufferBase& BufferHelper<U,V>::get_buffer(DataFieldType field_type)
|
||||
{
|
||||
if (field_type == DATA_MODP)
|
||||
return bufferp;
|
||||
else if (field_type == DATA_GF2N)
|
||||
return buffer2;
|
||||
else
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
template < template<class T> class U, template<class T> class V >
|
||||
void BufferHelper<U,V>::setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type)
|
||||
{
|
||||
files[field_type] = new ifstream(filename.c_str(), ios::in | ios::binary);
|
||||
if (files[field_type]->fail())
|
||||
throw file_error(filename);
|
||||
get_buffer(field_type).setup(files[field_type], tuple_length, data_type);
|
||||
}
|
||||
|
||||
template<template<class T> class U, template<class T> class V>
|
||||
void BufferHelper<U,V>::close()
|
||||
{
|
||||
for (int i = 0; i < N_DATA_FIELD_TYPE; i++)
|
||||
if (files[i])
|
||||
{
|
||||
files[i]->close();
|
||||
delete files[i];
|
||||
}
|
||||
}
|
||||
|
||||
template class Buffer< Share<gfp>, Share<gfp> >;
|
||||
template class Buffer< Share<gf2n>, Share<gf2n> >;
|
||||
template class Buffer< InputTuple<gfp>, RefInputTuple<gfp> >;
|
||||
template class Buffer< InputTuple<gf2n>, RefInputTuple<gf2n> >;
|
||||
template class Buffer< gfp, gfp >;
|
||||
template class Buffer< gf2n, gf2n >;
|
||||
|
||||
template class BufferHelper<Share, Share>;
|
||||
template class BufferHelper<InputTuple, RefInputTuple>;
|
||||
74
Processor/Buffer.h
Normal file
74
Processor/Buffer.h
Normal file
@@ -0,0 +1,74 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Buffer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_BUFFER_H_
|
||||
#define PROCESSOR_BUFFER_H_
|
||||
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
#include "Math/Share.h"
|
||||
#include "Math/field_types.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
#ifndef BUFFER_SIZE
|
||||
#define BUFFER_SIZE 101
|
||||
#endif
|
||||
|
||||
|
||||
class BufferBase
|
||||
{
|
||||
protected:
|
||||
static bool rewind;
|
||||
|
||||
ifstream* file;
|
||||
int next;
|
||||
const char* data_type;
|
||||
Timer timer;
|
||||
int tuple_length;
|
||||
|
||||
public:
|
||||
bool eof;
|
||||
|
||||
BufferBase() : file(0), next(BUFFER_SIZE), data_type(0), tuple_length(-1), eof(false) {};
|
||||
void setup(ifstream* f, int length, const char* type = 0);
|
||||
void seekg(int pos);
|
||||
bool is_up() { return file != 0; }
|
||||
};
|
||||
|
||||
|
||||
template <class T, class U>
|
||||
class Buffer : public BufferBase
|
||||
{
|
||||
T buffer[BUFFER_SIZE];
|
||||
|
||||
void read(char* read_buffer);
|
||||
|
||||
public:
|
||||
~Buffer();
|
||||
void input(U& a);
|
||||
void fill_buffer();
|
||||
};
|
||||
|
||||
|
||||
template < template<class T> class U, template<class T> class V >
|
||||
class BufferHelper
|
||||
{
|
||||
public:
|
||||
Buffer< U<gfp>, V<gfp> > bufferp;
|
||||
Buffer< U<gf2n>, V<gf2n> > buffer2;
|
||||
ifstream* files[N_DATA_FIELD_TYPE];
|
||||
|
||||
BufferHelper() { memset(files, 0, sizeof(files)); }
|
||||
void input(V<gfp>& a) { bufferp.input(a); }
|
||||
void input(V<gf2n>& a) { buffer2.input(a); }
|
||||
BufferBase& get_buffer(DataFieldType field_type);
|
||||
void setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type = 0);
|
||||
void close();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_BUFFER_H_ */
|
||||
218
Processor/Data_Files.cpp
Normal file
218
Processor/Data_Files.cpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Processor/Processor.h"
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
const char* Data_Files::dtype_names[N_DTYPE] = { "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples" };
|
||||
const char* Data_Files::field_names[] = { "p", "2" };
|
||||
const char* Data_Files::long_field_names[] = { "gfp", "gf2n" };
|
||||
const bool Data_Files::implemented[N_DATA_FIELD_TYPE][N_DTYPE] = {
|
||||
{ true, true, true, true, false, false },
|
||||
{ true, true, true, true, true, true },
|
||||
};
|
||||
const int Data_Files::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 };
|
||||
|
||||
Lock Data_Files::tuple_lengths_lock;
|
||||
map<DataTag, int> Data_Files::tuple_lengths;
|
||||
|
||||
|
||||
void DataPositions::set_num_players(int num_players)
|
||||
{
|
||||
files.resize(N_DATA_FIELD_TYPE, vector<int>(N_DTYPE));
|
||||
inputs.resize(num_players, vector<int>(N_DATA_FIELD_TYPE));
|
||||
}
|
||||
|
||||
void DataPositions::increase(const DataPositions& delta)
|
||||
{
|
||||
if (inputs.size() != delta.inputs.size())
|
||||
throw invalid_length();
|
||||
for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
files[field_type][dtype] += delta.files[field_type][dtype];
|
||||
for (unsigned int j = 0; j < inputs.size(); j++)
|
||||
inputs[j][field_type] += delta.inputs[j][field_type];
|
||||
|
||||
map<DataTag, int>::const_iterator it;
|
||||
const map<DataTag, int>& delta_ext = delta.extended[field_type];
|
||||
for (it = delta_ext.begin(); it != delta_ext.end(); it++)
|
||||
extended[field_type][it->first] += it->second;
|
||||
}
|
||||
}
|
||||
|
||||
void DataPositions::print_cost() const
|
||||
{
|
||||
ifstream file("cost");
|
||||
double total_cost = 0;
|
||||
for (int i = 0; i < N_DATA_FIELD_TYPE; i++)
|
||||
{
|
||||
cerr << " Type " << Data_Files::field_names[i] << endl;
|
||||
for (int j = 0; j < N_DTYPE; j++)
|
||||
{
|
||||
double cost_per_item = 0;
|
||||
file >> cost_per_item;
|
||||
if (cost_per_item < 0)
|
||||
break;
|
||||
int items_used = files[i][j];
|
||||
double cost = items_used * cost_per_item;
|
||||
total_cost += cost;
|
||||
cerr.fill(' ');
|
||||
cerr << " " << setw(10) << cost << " = " << setw(10) << items_used
|
||||
<< " " << setw(14) << Data_Files::dtype_names[j] << " à " << setw(11)
|
||||
<< cost_per_item << endl;
|
||||
}
|
||||
for (map<DataTag, int>::const_iterator it = extended[i].begin();
|
||||
it != extended[i].end(); it++)
|
||||
{
|
||||
cerr.fill(' ');
|
||||
cerr << setw(27) << it->second << " " << setw(14) << it->first.get_string() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "Total cost: " << total_cost << endl;
|
||||
}
|
||||
|
||||
|
||||
int Data_Files::share_length(int field_type)
|
||||
{
|
||||
switch (field_type)
|
||||
{
|
||||
case DATA_MODP:
|
||||
return 2 * gfp::t() * sizeof(mp_limb_t);
|
||||
case DATA_GF2N:
|
||||
return 2 * sizeof(word);
|
||||
default:
|
||||
throw invalid_params();
|
||||
}
|
||||
}
|
||||
|
||||
int Data_Files::tuple_length(int field_type, int dtype)
|
||||
{
|
||||
return tuple_size[dtype] * share_length(field_type);
|
||||
}
|
||||
|
||||
Data_Files::Data_Files(int myn, int n, const string& prep_data_dir) :
|
||||
usage(n), prep_data_dir(prep_data_dir)
|
||||
{
|
||||
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
|
||||
num_players=n;
|
||||
my_num=myn;
|
||||
char filename[1024];
|
||||
input_buffers = new BufferHelper<Share, Share>[num_players];
|
||||
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
{
|
||||
if (implemented[field_type][dtype])
|
||||
{
|
||||
sprintf(filename,(prep_data_dir + "%s-%s-P%d").c_str(),dtype_names[dtype],
|
||||
field_names[field_type],my_num);
|
||||
buffers[dtype].setup(DataFieldType(field_type), filename,
|
||||
tuple_length(field_type, dtype), dtype_names[dtype]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i=0; i<num_players; i++)
|
||||
{
|
||||
sprintf(filename,(prep_data_dir + "Inputs-%s-P%d-%d").c_str(),
|
||||
field_names[field_type],my_num,i);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(DataFieldType(field_type), filename,
|
||||
share_length(field_type) * 3 / 2);
|
||||
else
|
||||
input_buffers[i].setup(DataFieldType(field_type), filename,
|
||||
share_length(field_type));
|
||||
}
|
||||
}
|
||||
|
||||
cerr << "done\n";
|
||||
}
|
||||
|
||||
Data_Files::~Data_Files()
|
||||
{
|
||||
for (int i = 0; i < N_DTYPE; i++)
|
||||
buffers[i].close();
|
||||
for (int i = 0; i < num_players; i++)
|
||||
input_buffers[i].close();
|
||||
delete[] input_buffers;
|
||||
my_input_buffers.close();
|
||||
for (map<DataTag, BufferHelper<Share, Share> >::iterator it =
|
||||
extended.begin(); it != extended.end(); it++)
|
||||
it->second.close();
|
||||
}
|
||||
|
||||
void Data_Files::seekg(DataPositions& pos)
|
||||
{
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
if (implemented[field_type][dtype])
|
||||
buffers[dtype].get_buffer(DataFieldType(field_type)).seekg(pos.files[field_type][dtype]);
|
||||
for (int j = 0; j < num_players; j++)
|
||||
if (j == my_num)
|
||||
my_input_buffers.get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]);
|
||||
else
|
||||
input_buffers[j].get_buffer(DataFieldType(field_type)).seekg(pos.inputs[j][field_type]);
|
||||
for (map<DataTag, int>::const_iterator it = pos.extended[field_type].begin();
|
||||
it != pos.extended[field_type].end(); it++)
|
||||
{
|
||||
setup_extended(DataFieldType(field_type), it->first);
|
||||
extended[it->first].get_buffer(DataFieldType(field_type)).seekg(it->second);
|
||||
}
|
||||
}
|
||||
|
||||
usage = pos;
|
||||
}
|
||||
|
||||
void Data_Files::skip(const DataPositions& pos)
|
||||
{
|
||||
DataPositions new_pos = usage;
|
||||
new_pos.increase(pos);
|
||||
seekg(new_pos);
|
||||
}
|
||||
|
||||
void Data_Files::setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size)
|
||||
{
|
||||
BufferBase& buffer = extended[tag].get_buffer(field_type);
|
||||
tuple_lengths_lock.lock();
|
||||
int tuple_length = tuple_lengths[tag];
|
||||
int my_tuple_length = tuple_size * share_length(field_type);
|
||||
if (tuple_length > 0)
|
||||
{
|
||||
if (tuple_size > 0 && my_tuple_length != tuple_length)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Inconsistent size of " << field_names[field_type] << " "
|
||||
<< tag.get_string() << ": " << my_tuple_length << " vs "
|
||||
<< tuple_length;
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
tuple_lengths[tag] = my_tuple_length;
|
||||
tuple_lengths_lock.unlock();
|
||||
|
||||
if (!buffer.is_up())
|
||||
{
|
||||
stringstream ss;
|
||||
ss << prep_data_dir << tag.get_string() << "-" << field_names[field_type] << "-P" << my_num;
|
||||
extended[tag].setup(field_type, ss.str(), tuple_length);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Data_Files::get(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size)
|
||||
{
|
||||
usage.extended[T::field_type()][tag] += vector_size;
|
||||
setup_extended(T::field_type(), tag, regs.size());
|
||||
for (int j = 0; j < vector_size; j++)
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
extended[tag].input(proc.get_S_ref<T>(regs[i] + j));
|
||||
}
|
||||
|
||||
template void Data_Files::get<gfp>(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
template void Data_Files::get<gf2n>(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
158
Processor/Data_Files.h
Normal file
158
Processor/Data_Files.h
Normal file
@@ -0,0 +1,158 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Data_Files
|
||||
#define _Data_Files
|
||||
|
||||
/* This class holds the Online data files all in one place
|
||||
* so the streams are easy to pass around and access
|
||||
*/
|
||||
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/Share.h"
|
||||
#include "Math/field_types.h"
|
||||
#include "Processor/Buffer.h"
|
||||
#include "Processor/InputTuple.h"
|
||||
#include "Tools/Lock.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
using namespace std;
|
||||
|
||||
enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE };
|
||||
|
||||
class DataTag
|
||||
{
|
||||
int t[4];
|
||||
|
||||
public:
|
||||
// assume that tag is three integers
|
||||
DataTag(const int* tag)
|
||||
{
|
||||
strncpy((char*)t, (char*)tag, 3 * sizeof(int));
|
||||
t[3] = 0;
|
||||
}
|
||||
string get_string() const
|
||||
{
|
||||
return string((char*)t);
|
||||
}
|
||||
bool operator<(const DataTag& other) const
|
||||
{
|
||||
for (int i = 0; i < 3; i++)
|
||||
if (t[i] != other.t[i])
|
||||
return t[i] < other.t[i];
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct DataPositions
|
||||
{
|
||||
vector< vector<int> > files;
|
||||
vector< vector<int> > inputs;
|
||||
map<DataTag, int> extended[N_DATA_FIELD_TYPE];
|
||||
|
||||
DataPositions(int num_players = 0) { set_num_players(num_players); }
|
||||
void set_num_players(int num_players);
|
||||
void increase(const DataPositions& delta);
|
||||
void print_cost() const;
|
||||
};
|
||||
|
||||
class Processor;
|
||||
|
||||
class Data_Files
|
||||
{
|
||||
static map<DataTag, int> tuple_lengths;
|
||||
static Lock tuple_lengths_lock;
|
||||
|
||||
BufferHelper<Share, Share> buffers[N_DTYPE];
|
||||
BufferHelper<Share, Share>* input_buffers;
|
||||
BufferHelper<InputTuple, RefInputTuple> my_input_buffers;
|
||||
map<DataTag, BufferHelper<Share, Share> > extended;
|
||||
|
||||
int my_num,num_players;
|
||||
|
||||
DataPositions usage;
|
||||
|
||||
const string prep_data_dir;
|
||||
|
||||
public:
|
||||
|
||||
static const char* dtype_names[N_DTYPE];
|
||||
static const char* field_names[N_DATA_FIELD_TYPE];
|
||||
static const char* long_field_names[N_DATA_FIELD_TYPE];
|
||||
static const bool implemented[N_DATA_FIELD_TYPE][N_DTYPE];
|
||||
static const int tuple_size[N_DTYPE];
|
||||
|
||||
static int share_length(int field_type);
|
||||
static int tuple_length(int field_type, int dtype);
|
||||
|
||||
Data_Files(int my_num,int n,const string& prep_data_dir);
|
||||
~Data_Files();
|
||||
|
||||
DataPositions tellg();
|
||||
void seekg(DataPositions& pos);
|
||||
void skip(const DataPositions& pos);
|
||||
template<class T>
|
||||
bool eof(Dtype dtype);
|
||||
template<class T>
|
||||
bool input_eof(int player);
|
||||
|
||||
void setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size = 0);
|
||||
template<class T>
|
||||
void get(Processor& proc, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
|
||||
DataPositions get_usage()
|
||||
{
|
||||
return usage;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void get_three(DataFieldType field_type, Dtype dtype, Share<T>& a, Share<T>& b, Share<T>& c)
|
||||
{
|
||||
usage.files[field_type][dtype]++;
|
||||
buffers[dtype].input(a);
|
||||
buffers[dtype].input(b);
|
||||
buffers[dtype].input(c);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void get_two(DataFieldType field_type, Dtype dtype, Share<T>& a, Share<T>& b)
|
||||
{
|
||||
usage.files[field_type][dtype]++;
|
||||
buffers[dtype].input(a);
|
||||
buffers[dtype].input(b);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void get_one(DataFieldType field_type, Dtype dtype, Share<T>& a)
|
||||
{
|
||||
usage.files[field_type][dtype]++;
|
||||
buffers[dtype].input(a);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void get_input(Share<T>& a,T& x,int i)
|
||||
{
|
||||
usage.inputs[i][T::field_type()]++;
|
||||
RefInputTuple<T> tuple(a, x);
|
||||
if (i==my_num)
|
||||
my_input_buffers.input(tuple);
|
||||
else
|
||||
input_buffers[i].input(a);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
bool Data_Files::eof(Dtype dtype)
|
||||
{ return buffers[dtype].get_buffer(T::field_type()).eof; }
|
||||
|
||||
template<class T> inline
|
||||
bool Data_Files::input_eof(int player)
|
||||
{
|
||||
if (player == my_num)
|
||||
return my_input_buffers.get_buffer(T::field_type()).eof;
|
||||
else
|
||||
return input_buffers[player].get_buffer(T::field_type()).eof;
|
||||
}
|
||||
|
||||
#endif
|
||||
89
Processor/Input.cpp
Normal file
89
Processor/Input.cpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Input.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Input.h"
|
||||
#include "Processor.h"
|
||||
|
||||
template<class T>
|
||||
Input<T>::Input(Processor& proc, MAC_Check<T>& mc) :
|
||||
proc(proc), MC(mc), values_input(0)
|
||||
{
|
||||
buffer.setup(&proc.private_input, -1, "private input");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Input<T>::~Input()
|
||||
{
|
||||
if (timer.elapsed() > 0)
|
||||
cerr << T::type_string() << " inputs: " << timer.elapsed() << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::adjust_mac(Share<T>& share, T& value)
|
||||
{
|
||||
T tmp;
|
||||
tmp.mul(MC.get_alphai(), value);
|
||||
tmp.add(share.get_mac(),tmp);
|
||||
share.set_mac(tmp);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::start(int player, int n_inputs)
|
||||
{
|
||||
if (player == proc.P.my_num())
|
||||
{
|
||||
octetStream o;
|
||||
shares.resize(n_inputs);
|
||||
|
||||
for (int i = 0; i < n_inputs; i++)
|
||||
{
|
||||
T rr, t;
|
||||
Share<T>& share = shares[i];
|
||||
proc.DataF.get_input(share, rr, player);
|
||||
T xi;
|
||||
buffer.input(t);
|
||||
t.sub(t, rr);
|
||||
t.pack(o);
|
||||
xi.add(t, share.get_share());
|
||||
share.set_share(xi);
|
||||
adjust_mac(share, t);
|
||||
}
|
||||
|
||||
proc.P.send_all(o, true);
|
||||
values_input += n_inputs;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::stop(int player, vector<int> targets)
|
||||
{
|
||||
T tmp;
|
||||
|
||||
if (player == proc.P.my_num())
|
||||
{
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
proc.get_S_ref<T>(targets[i]) = shares[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
T t;
|
||||
octetStream o;
|
||||
timer.start();
|
||||
proc.P.receive_player(player, o, true);
|
||||
timer.stop();
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
{
|
||||
Share<T>& share = proc.get_S_ref<T>(targets[i]);
|
||||
proc.DataF.get_input(share, t, player);
|
||||
t.unpack(o);
|
||||
adjust_mac(share, t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template class Input<gf2n>;
|
||||
template class Input<gfp>;
|
||||
43
Processor/Input.h
Normal file
43
Processor/Input.h
Normal file
@@ -0,0 +1,43 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Input.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_INPUT_H_
|
||||
#define PROCESSOR_INPUT_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "Math/Share.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
#include "Processor/Buffer.h"
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
class Processor;
|
||||
|
||||
template<class T>
|
||||
class Input
|
||||
{
|
||||
Processor& proc;
|
||||
MAC_Check<T>& MC;
|
||||
vector< Share<T> > shares;
|
||||
Buffer<T,T> buffer;
|
||||
Timer timer;
|
||||
|
||||
void adjust_mac(Share<T>& share, T& value);
|
||||
|
||||
public:
|
||||
int values_input;
|
||||
|
||||
Input(Processor& proc, MAC_Check<T>& mc);
|
||||
~Input();
|
||||
|
||||
void start(int player, int n_inputs);
|
||||
void stop(int player, vector<int> targets);
|
||||
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_INPUT_H_ */
|
||||
42
Processor/InputTuple.h
Normal file
42
Processor/InputTuple.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* InputTuple.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_INPUTTUPLE_H_
|
||||
#define PROCESSOR_INPUTTUPLE_H_
|
||||
|
||||
|
||||
template <class T>
|
||||
struct InputTuple
|
||||
{
|
||||
Share<T> share;
|
||||
T value;
|
||||
|
||||
static int size()
|
||||
{ return Share<T>::size() + T::size(); }
|
||||
|
||||
static string type_string()
|
||||
{ return T::type_string(); }
|
||||
|
||||
void assign(const char* buffer)
|
||||
{
|
||||
share.assign(buffer);
|
||||
value.assign(buffer + Share<T>::size());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <class T>
|
||||
struct RefInputTuple
|
||||
{
|
||||
Share<T>& share;
|
||||
T& value;
|
||||
RefInputTuple(Share<T>& share, T& value) : share(share), value(value) {}
|
||||
void operator=(InputTuple<T>& other) { share = other.share; value = other.value; }
|
||||
};
|
||||
|
||||
|
||||
#endif /* PROCESSOR_INPUTTUPLE_H_ */
|
||||
1558
Processor/Instruction.cpp
Normal file
1558
Processor/Instruction.cpp
Normal file
File diff suppressed because it is too large
Load Diff
311
Processor/Instruction.h
Normal file
311
Processor/Instruction.h
Normal file
@@ -0,0 +1,311 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#ifndef _Instruction
|
||||
#define _Instruction
|
||||
|
||||
/* Class to read and decode an instruction
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "Processor/Memory.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Math/Integer.h"
|
||||
#include "Auth/MAC_Check.h"
|
||||
|
||||
class Machine;
|
||||
class Processor;
|
||||
|
||||
/*
|
||||
* Opcode constants
|
||||
*
|
||||
* Whenever these are changed the corresponding dict in Compiler/instructions.py
|
||||
* MUST also be changed. (+ the documentation)
|
||||
*/
|
||||
enum
|
||||
{
|
||||
// Load/store
|
||||
LDI = 0x1,
|
||||
LDSI = 0x2,
|
||||
LDMC = 0x3,
|
||||
LDMS = 0x4,
|
||||
STMC = 0x5,
|
||||
STMS = 0x6,
|
||||
LDMCI = 0x7,
|
||||
LDMSI = 0x8,
|
||||
STMCI = 0x9,
|
||||
STMSI = 0xA,
|
||||
MOVC = 0xB,
|
||||
MOVS = 0xC,
|
||||
PROTECTMEMS = 0xD,
|
||||
PROTECTMEMC = 0xE,
|
||||
PROTECTMEMINT = 0xF,
|
||||
LDMINT = 0xCA,
|
||||
STMINT = 0xCB,
|
||||
LDMINTI = 0xCC,
|
||||
STMINTI = 0xCD,
|
||||
PUSHINT = 0xCE,
|
||||
POPINT = 0xCF,
|
||||
MOVINT = 0xD0,
|
||||
// Machine
|
||||
LDTN = 0x10,
|
||||
LDARG = 0x11,
|
||||
REQBL = 0x12,
|
||||
STARG = 0x13,
|
||||
TIME = 0x14,
|
||||
START = 0x15,
|
||||
STOP = 0x16,
|
||||
USE = 0x17,
|
||||
USE_INP = 0x18,
|
||||
RUN_TAPE = 0x19,
|
||||
JOIN_TAPE = 0x1A,
|
||||
CRASH = 0x1B,
|
||||
USE_PREP = 0x1C,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
ADDM = 0x22,
|
||||
ADDCI = 0x23,
|
||||
ADDSI = 0x24,
|
||||
SUBC = 0x25,
|
||||
SUBS = 0x26,
|
||||
SUBML = 0x27,
|
||||
SUBMR = 0x28,
|
||||
SUBCI = 0x29,
|
||||
SUBSI = 0x2A,
|
||||
SUBCFI = 0x2B,
|
||||
SUBSFI = 0x2C,
|
||||
// Multiplication/division/other arithmetic
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
MULCI = 0x32,
|
||||
MULSI = 0x33,
|
||||
DIVC = 0x34,
|
||||
DIVCI = 0x35,
|
||||
MODC = 0x36,
|
||||
MODCI = 0x37,
|
||||
LEGENDREC = 0x38,
|
||||
// Open
|
||||
STARTOPEN = 0xA0,
|
||||
STOPOPEN = 0xA1,
|
||||
// Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
SQUARE = 0x52,
|
||||
INV = 0x53,
|
||||
INPUTMASK = 0x56,
|
||||
PREP = 0x57,
|
||||
// Input
|
||||
INPUT = 0x60,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
READSOCKETS = 0x64,
|
||||
WRITESOCKETC = 0x65,
|
||||
WRITESOCKETS = 0x66,
|
||||
OPENSOCKET = 0x67,
|
||||
CLOSESOCKET = 0x68,
|
||||
// Bitwise logic
|
||||
ANDC = 0x70,
|
||||
XORC = 0x71,
|
||||
ORC = 0x72,
|
||||
ANDCI = 0x73,
|
||||
XORCI = 0x74,
|
||||
ORCI = 0x75,
|
||||
NOTC = 0x76,
|
||||
// Bitwise shifts
|
||||
SHLC = 0x80,
|
||||
SHRC = 0x81,
|
||||
SHLCI = 0x82,
|
||||
SHRCI = 0x83,
|
||||
// Branching and comparison
|
||||
JMP = 0x90,
|
||||
JMPNZ = 0x91,
|
||||
JMPEQZ = 0x92,
|
||||
EQZC = 0x93,
|
||||
LTZC = 0x94,
|
||||
LTC = 0x95,
|
||||
GTC = 0x96,
|
||||
EQC = 0x97,
|
||||
JMPI = 0x98,
|
||||
// Integers
|
||||
LDINT = 0x9A,
|
||||
ADDINT = 0x9B,
|
||||
SUBINT = 0x9C,
|
||||
MULINT = 0x9D,
|
||||
DIVINT = 0x9E,
|
||||
// Conversion
|
||||
CONVINT = 0xC0,
|
||||
CONVMODP = 0xC1,
|
||||
|
||||
// IO
|
||||
PRINTMEM = 0xB0,
|
||||
PRINTREG = 0XB1,
|
||||
RAND = 0xB2,
|
||||
PRINTREGPLAIN = 0xB3,
|
||||
PRINTCHR = 0xB4,
|
||||
PRINTSTR = 0xB5,
|
||||
PUBINPUT = 0xB6,
|
||||
RAWOUTPUT = 0xB7,
|
||||
STARTPRIVATEOUTPUT = 0xB8,
|
||||
STOPPRIVATEOUTPUT = 0xB9,
|
||||
PRINTCHRINT = 0xBA,
|
||||
PRINTSTRINT = 0xBB,
|
||||
|
||||
// GF(2^n) versions
|
||||
|
||||
// Load/store
|
||||
GLDI = 0x101,
|
||||
GLDSI = 0x102,
|
||||
GLDMC = 0x103,
|
||||
GLDMS = 0x104,
|
||||
GSTMC = 0x105,
|
||||
GSTMS = 0x106,
|
||||
GLDMCI = 0x107,
|
||||
GLDMSI = 0x108,
|
||||
GSTMCI = 0x109,
|
||||
GSTMSI = 0x10A,
|
||||
GMOVC = 0x10B,
|
||||
GMOVS = 0x10C,
|
||||
GPROTECTMEMS = 0x10D,
|
||||
GPROTECTMEMC = 0x10E,
|
||||
// Machine
|
||||
GREQBL = 0x112,
|
||||
GUSE_PREP = 0x11C,
|
||||
// Addition
|
||||
GADDC = 0x120,
|
||||
GADDS = 0x121,
|
||||
GADDM = 0x122,
|
||||
GADDCI = 0x123,
|
||||
GADDSI = 0x124,
|
||||
GSUBC = 0x125,
|
||||
GSUBS = 0x126,
|
||||
GSUBML = 0x127,
|
||||
GSUBMR = 0x128,
|
||||
GSUBCI = 0x129,
|
||||
GSUBSI = 0x12A,
|
||||
GSUBCFI = 0x12B,
|
||||
GSUBSFI = 0x12C,
|
||||
// Multiplication/division
|
||||
GMULC = 0x130,
|
||||
GMULM = 0x131,
|
||||
GMULCI = 0x132,
|
||||
GMULSI = 0x133,
|
||||
GDIVC = 0x134,
|
||||
GDIVCI = 0x135,
|
||||
GMULBITC = 0x136,
|
||||
GMULBITM = 0x137,
|
||||
// Open
|
||||
GSTARTOPEN = 0x1A0,
|
||||
GSTOPOPEN = 0x1A1,
|
||||
// Data access
|
||||
GTRIPLE = 0x150,
|
||||
GBIT = 0x151,
|
||||
GSQUARE = 0x152,
|
||||
GINV = 0x153,
|
||||
GBITTRIPLE = 0x154,
|
||||
GBITGF2NTRIPLE = 0x155,
|
||||
GINPUTMASK = 0x156,
|
||||
GPREP = 0x157,
|
||||
// Input
|
||||
GINPUT = 0x160,
|
||||
GSTARTINPUT = 0x161,
|
||||
GSTOPINPUT = 0x162,
|
||||
GREADSOCKETS = 0x164,
|
||||
GWRITESOCKETS = 0x166,
|
||||
// Bitwise logic
|
||||
GANDC = 0x170,
|
||||
GXORC = 0x171,
|
||||
GORC = 0x172,
|
||||
GANDCI = 0x173,
|
||||
GXORCI = 0x174,
|
||||
GORCI = 0x175,
|
||||
GNOTC = 0x176,
|
||||
// Bitwise shifts
|
||||
GSHLCI = 0x182,
|
||||
GSHRCI = 0x183,
|
||||
GBITDEC = 0x184,
|
||||
GBITCOM = 0x185,
|
||||
// Conversion
|
||||
GCONVINT = 0x1C0,
|
||||
GCONVGF2N = 0x1C1,
|
||||
// IO
|
||||
GPRINTMEM = 0x1B0,
|
||||
GPRINTREG = 0X1B1,
|
||||
GPRINTREGPLAIN = 0x1B3,
|
||||
GRAWOUTPUT = 0x1B7,
|
||||
GSTARTPRIVATEOUTPUT = 0x1B8,
|
||||
GSTOPPRIVATEOUTPUT = 0x1B9,
|
||||
};
|
||||
|
||||
|
||||
// Register types
|
||||
enum RegType {
|
||||
MODP,
|
||||
GF2N,
|
||||
INT,
|
||||
MAX_REG_TYPE,
|
||||
NONE
|
||||
};
|
||||
|
||||
enum SecrecyType {
|
||||
SECRET,
|
||||
CLEAR,
|
||||
MAX_SECRECY_TYPE
|
||||
};
|
||||
|
||||
|
||||
struct TempVars {
|
||||
gf2n ans2; Share<gf2n> Sans2;
|
||||
gfp ansp; Share<gfp> Sansp;
|
||||
bigint aa,aa2;
|
||||
// INPUT and LDSI
|
||||
gfp rrp,tp,tmpp;
|
||||
gfp xip;
|
||||
// GINPUT and GLDSI
|
||||
gf2n rr2,t2,tmp2;
|
||||
gf2n xi2;
|
||||
};
|
||||
|
||||
|
||||
class Instruction
|
||||
{
|
||||
int opcode; // The code
|
||||
int size; // Vector size
|
||||
int r[3]; // Three possible registers
|
||||
unsigned int n; // Possible immediate value
|
||||
vector<int> start; // Values for a start/stop open
|
||||
|
||||
public:
|
||||
|
||||
// Reads a single instruction from the istream
|
||||
void parse(istream& s);
|
||||
|
||||
// Return whether usage is known
|
||||
bool get_offline_data_usage(DataPositions& usage);
|
||||
|
||||
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
|
||||
RegType get_reg_type() const;
|
||||
|
||||
bool is_direct_memory_access(SecrecyType sec_type) const;
|
||||
|
||||
// Returns the maximal register used
|
||||
int get_max_reg(RegType reg_type) const;
|
||||
|
||||
// Returns the memory size used if applicable and known
|
||||
int get_mem(RegType reg_type, SecrecyType sec_type) const;
|
||||
|
||||
friend ostream& operator<<(ostream& s,const Instruction& instr);
|
||||
|
||||
// Execute this instruction, updateing the processor and memory
|
||||
// and streams pointing to the triples etc
|
||||
void execute(Processor& Proc) const;
|
||||
};
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
362
Processor/Machine.cpp
Normal file
362
Processor/Machine.cpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Machine.h"
|
||||
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include <sys/time.h>
|
||||
|
||||
#include "Math/Setup.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <pthread.h>
|
||||
using namespace std;
|
||||
|
||||
Machine::Machine(int my_number, int PortnumBase, string hostname,
|
||||
string progname_str, string memtype, int lgp, int lg2, bool direct,
|
||||
int opening_sum, bool parallel, bool receive_threads, int max_broadcast)
|
||||
: my_number(my_number), nthreads(0), tn(0), numt(0), usage_unknown(false),
|
||||
progname(progname_str), direct(direct), opening_sum(opening_sum), parallel(parallel),
|
||||
receive_threads(receive_threads), max_broadcast(max_broadcast)
|
||||
{
|
||||
N.init(my_number,PortnumBase,hostname.c_str());
|
||||
|
||||
if (opening_sum < 2)
|
||||
this->opening_sum = N.num_players();
|
||||
if (max_broadcast < 2)
|
||||
this->max_broadcast = N.num_players();
|
||||
|
||||
// Set up the fields
|
||||
prep_dir_prefix = get_prep_dir(N.num_players(), lgp, lg2);
|
||||
read_setup(prep_dir_prefix);
|
||||
|
||||
char filename[1024];
|
||||
int nn;
|
||||
|
||||
sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number);
|
||||
inpf.open(filename);
|
||||
if (inpf.fail())
|
||||
{
|
||||
cerr << "Could not open MAC key file. Perhaps it needs to be generated?\n";
|
||||
throw file_error(filename);
|
||||
}
|
||||
inpf >> nn;
|
||||
if (nn!=N.num_players())
|
||||
{ cerr << "KeyGen was last run with " << nn << " players." << endl;
|
||||
cerr << " - You are running Online with " << N.num_players() << " players." << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
alphapi.input(inpf,true);
|
||||
alpha2i.input(inpf,true);
|
||||
cerr << "MAC Key p = " << alphapi << endl;
|
||||
cerr << "MAC Key 2 = " << alpha2i << endl;
|
||||
inpf.close();
|
||||
|
||||
|
||||
// Initialize the global memory
|
||||
if (memtype.compare("new")==0)
|
||||
{sprintf(filename, "Player-Data/Player-Memory-P%d", my_number);
|
||||
ifstream memfile(filename);
|
||||
if (memfile.fail()) { throw file_error(filename); }
|
||||
Load_Memory(M2,memfile);
|
||||
Load_Memory(Mp,memfile);
|
||||
Load_Memory(Mi,memfile);
|
||||
memfile.close();
|
||||
}
|
||||
else if (memtype.compare("old")==0)
|
||||
{
|
||||
sprintf(filename, "Player-Data/Memory-P%d", my_number);
|
||||
inpf.open(filename,ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(); }
|
||||
inpf >> M2 >> Mp >> Mi;
|
||||
inpf.close();
|
||||
}
|
||||
else if (!(memtype.compare("empty")==0))
|
||||
{ cerr << "Invalid memory argument" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
sprintf(filename, "Programs/Schedules/%s.sch",progname.c_str());
|
||||
cerr << "Opening file " << filename << endl;
|
||||
inpf.open(filename);
|
||||
if (inpf.fail()) { throw file_error("Missing '" + string(filename) + "'. Did you compile '" + progname + "'?"); }
|
||||
|
||||
int nprogs;
|
||||
inpf >> nthreads;
|
||||
inpf >> nprogs;
|
||||
|
||||
// Keep record of used offline data
|
||||
pos.set_num_players(N.num_players());
|
||||
|
||||
cerr << "Number of threads I will run in parallel = " << nthreads << endl;
|
||||
cerr << "Number of program sequences I need to load = " << nprogs << endl;
|
||||
|
||||
// Load in the programs
|
||||
progs.resize(nprogs,N.num_players());
|
||||
char threadname[1024];
|
||||
for (int i=0; i<nprogs; i++)
|
||||
{ inpf >> threadname;
|
||||
sprintf(filename,"Programs/Bytecode/%s.bc",threadname);
|
||||
cerr << "Loading program " << i << " from " << filename << endl;
|
||||
ifstream pinp(filename);
|
||||
if (pinp.fail()) { throw file_error(filename); }
|
||||
progs[i].parse(pinp);
|
||||
pinp.close();
|
||||
if (progs[i].direct_mem2_s() > M2.size_s())
|
||||
{
|
||||
cerr << threadname << " needs more secret mod2 memory, resizing to "
|
||||
<< progs[i].direct_mem2_s() << endl;
|
||||
M2.resize_s(progs[i].direct_mem2_s());
|
||||
}
|
||||
if (progs[i].direct_memp_s() > Mp.size_s())
|
||||
{
|
||||
cerr << threadname << " needs more secret modp memory, resizing to "
|
||||
<< progs[i].direct_memp_s() << endl;
|
||||
Mp.resize_s(progs[i].direct_memp_s());
|
||||
}
|
||||
if (progs[i].direct_mem2_c() > M2.size_c())
|
||||
{
|
||||
cerr << threadname << " needs more clear mod2 memory, resizing to "
|
||||
<< progs[i].direct_mem2_c() << endl;
|
||||
M2.resize_c(progs[i].direct_mem2_c());
|
||||
}
|
||||
if (progs[i].direct_memp_c() > Mp.size_c())
|
||||
{
|
||||
cerr << threadname << " needs more clear modp memory, resizing to "
|
||||
<< progs[i].direct_memp_c() << endl;
|
||||
Mp.resize_c(progs[i].direct_memp_c());
|
||||
}
|
||||
if (progs[i].direct_memi_c() > Mi.size_c())
|
||||
{
|
||||
cerr << threadname << " needs more clear integer memory, resizing to "
|
||||
<< progs[i].direct_memi_c() << endl;
|
||||
Mi.resize_c(progs[i].direct_memi_c());
|
||||
}
|
||||
}
|
||||
|
||||
progs[0].print_offline_cost();
|
||||
|
||||
/* Set up the threads */
|
||||
tinfo.resize(nthreads);
|
||||
threads.resize(nthreads);
|
||||
t_mutex.resize(nthreads);
|
||||
client_ready.resize(nthreads);
|
||||
server_ready.resize(nthreads);
|
||||
join_timer.resize(nthreads);
|
||||
|
||||
for (int i=0; i<nthreads; i++)
|
||||
{ pthread_mutex_init(&t_mutex[i],NULL);
|
||||
pthread_cond_init(&client_ready[i],NULL);
|
||||
pthread_cond_init(&server_ready[i],NULL);
|
||||
tinfo[i].thread_num=i;
|
||||
tinfo[i].Nms=&N;
|
||||
tinfo[i].alphapi=&alphapi;
|
||||
tinfo[i].alpha2i=&alpha2i;
|
||||
tinfo[i].prognum=-2; // Dont do anything until we are ready
|
||||
tinfo[i].finished=true;
|
||||
tinfo[i].ready=false;
|
||||
tinfo[i].machine=this;
|
||||
// lock for synchronization
|
||||
pthread_mutex_lock(&t_mutex[i]);
|
||||
pthread_create(&threads[i],NULL,Main_Func,&tinfo[i]);
|
||||
}
|
||||
|
||||
// synchronize with clients before starting timer
|
||||
for (int i=0; i<nthreads; i++)
|
||||
{
|
||||
while (!tinfo[i].ready)
|
||||
{
|
||||
cerr << "Waiting for thread " << i << " to be ready" << endl;
|
||||
pthread_cond_wait(&client_ready[i],&t_mutex[i]);
|
||||
}
|
||||
pthread_mutex_unlock(&t_mutex[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number)
|
||||
{
|
||||
pthread_mutex_lock(&t_mutex[thread_number]);
|
||||
tinfo[thread_number].prognum=tape_number;
|
||||
tinfo[thread_number].arg=arg;
|
||||
tinfo[thread_number].pos=pos;
|
||||
tinfo[thread_number].finished=false;
|
||||
//printf("Send signal to run program %d in thread %d\n",tape_number,thread_number);
|
||||
pthread_cond_signal(&server_ready[thread_number]);
|
||||
pthread_mutex_unlock(&t_mutex[thread_number]);
|
||||
//printf("Running line %d\n",exec);
|
||||
if (progs[tape_number].usage_unknown())
|
||||
{ // only one thread allowed
|
||||
if (numt>1)
|
||||
{ cerr << "Line " << line_number << " has " <<
|
||||
numt << " threads but tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
throw invalid_program();
|
||||
}
|
||||
else if (line_number == -1)
|
||||
{
|
||||
cerr << "Internally called tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
throw invalid_program();
|
||||
}
|
||||
usage_unknown = true;
|
||||
return DataPositions(N.num_players());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Bits, Triples, Squares, and Inverses skipping
|
||||
return progs[tape_number].get_offline_data_used();
|
||||
}
|
||||
}
|
||||
|
||||
void Machine::join_tape(int i)
|
||||
{
|
||||
join_timer[i].start();
|
||||
pthread_mutex_lock(&t_mutex[i]);
|
||||
//printf("Waiting for client to terminate\n");
|
||||
if ((tinfo[i].finished)==false)
|
||||
{ pthread_cond_wait(&client_ready[i],&t_mutex[i]); }
|
||||
pthread_mutex_unlock(&t_mutex[i]);
|
||||
join_timer[i].stop();
|
||||
}
|
||||
|
||||
void Machine::run()
|
||||
{
|
||||
Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID);
|
||||
proc_timer.start();
|
||||
timer[0].start();
|
||||
|
||||
bool flag=true;
|
||||
usage_unknown=false;
|
||||
int exec=0;
|
||||
while (flag)
|
||||
{ inpf >> numt;
|
||||
if (numt==0)
|
||||
{ flag=false; }
|
||||
else
|
||||
{ for (int i=0; i<numt; i++)
|
||||
{
|
||||
// Now load up data
|
||||
inpf >> tn;
|
||||
|
||||
// Cope with passing an integer parameter to a tape
|
||||
int arg;
|
||||
if (inpf.get() == ':')
|
||||
inpf >> arg;
|
||||
else
|
||||
arg = 0;
|
||||
|
||||
//cerr << "Run scheduled tape " << tn << " in thread " << i << endl;
|
||||
pos.increase(run_tape(i, tn, arg, exec));
|
||||
}
|
||||
// Make sure all terminate before we continue
|
||||
for (int i=0; i<numt; i++)
|
||||
{ join_tape(i);
|
||||
}
|
||||
if (usage_unknown)
|
||||
{ // synchronize files
|
||||
pos = tinfo[0].pos;
|
||||
usage_unknown = false;
|
||||
}
|
||||
//printf("Finished running line %d\n",exec);
|
||||
exec++;
|
||||
}
|
||||
}
|
||||
|
||||
char compiler[1000];
|
||||
inpf.get();
|
||||
inpf.getline(compiler, 1000);
|
||||
if (compiler[0] != 0)
|
||||
cerr << "Compiler: " << compiler << endl;
|
||||
inpf.close();
|
||||
|
||||
finish_timer.start();
|
||||
// Tell all C-threads to stop
|
||||
for (int i=0; i<nthreads; i++)
|
||||
{ pthread_mutex_lock(&t_mutex[i]);
|
||||
//printf("Send kill signal to client\n");
|
||||
tinfo[i].prognum=-1;
|
||||
tinfo[i].ready = false;
|
||||
pthread_cond_signal(&server_ready[i]);
|
||||
pthread_mutex_unlock(&t_mutex[i]);
|
||||
}
|
||||
|
||||
cerr << "Waiting for all clients to finish" << endl;
|
||||
// Wait until all clients have signed out
|
||||
for (int i=0; i<nthreads; i++)
|
||||
{
|
||||
pthread_mutex_lock(&t_mutex[i]);
|
||||
tinfo[i].ready = true;
|
||||
pthread_cond_signal(&server_ready[i]);
|
||||
pthread_mutex_unlock(&t_mutex[i]);
|
||||
pthread_join(threads[i],NULL);
|
||||
pthread_mutex_destroy(&t_mutex[i]);
|
||||
pthread_cond_destroy(&client_ready[i]);
|
||||
pthread_cond_destroy(&server_ready[i]);
|
||||
}
|
||||
finish_timer.stop();
|
||||
|
||||
for (unsigned int i = 0; i < join_timer.size(); i++)
|
||||
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
|
||||
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
|
||||
cerr << "Process timer: " << proc_timer.elapsed() << endl;
|
||||
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
|
||||
timer.erase(0);
|
||||
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
|
||||
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
|
||||
|
||||
if (opening_sum < N.num_players() && !direct)
|
||||
cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl;
|
||||
else
|
||||
cerr << "Summed all shares at once" << endl;
|
||||
|
||||
if (max_broadcast < N.num_players() && !direct)
|
||||
cerr << "Send to at most " << max_broadcast << " parties at once" << endl;
|
||||
else
|
||||
cerr << "Full broadcast" << endl;
|
||||
|
||||
// Reduce memory size to speed up
|
||||
int max_size = 1 << 20;
|
||||
if (M2.size_s() > max_size)
|
||||
M2.resize_s(max_size);
|
||||
if (Mp.size_s() > max_size)
|
||||
Mp.resize_s(max_size);
|
||||
|
||||
// Write out the memory to use next time
|
||||
char filename[1024];
|
||||
sprintf(filename,"Player-Data/Memory-P%d",my_number);
|
||||
ofstream outf(filename,ios::out | ios::binary);
|
||||
outf << M2 << Mp << Mi;
|
||||
outf.close();
|
||||
|
||||
extern unsigned long long sent_amount, sent_counter;
|
||||
cerr << "Data sent = " << sent_amount << " bytes in "
|
||||
<< sent_counter << " calls,";
|
||||
cerr << sent_amount / sent_counter / N.num_players()
|
||||
<< " bytes per call" << endl;
|
||||
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
{
|
||||
cerr << "Num " << Data_Files::dtype_names[dtype] << "\t=";
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
cerr << " " << pos.files[field_type][dtype];
|
||||
cerr << endl;
|
||||
}
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
cerr << "Num " << Data_Files::long_field_names[field_type] << " Inputs\t=";
|
||||
for (int i = 0; i < N.num_players(); i++)
|
||||
cerr << " " << pos.inputs[i][field_type];
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
cerr << "Total cost of program:" << endl;
|
||||
pos.print_cost();
|
||||
|
||||
cerr << "End of prog" << endl;
|
||||
}
|
||||
|
||||
|
||||
83
Processor/Machine.h
Normal file
83
Processor/Machine.h
Normal file
@@ -0,0 +1,83 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
/*
|
||||
* Machine.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MACHINE_H_
|
||||
#define MACHINE_H_
|
||||
|
||||
#include "Processor/Memory.h"
|
||||
#include "Processor/Program.h"
|
||||
|
||||
#include "Processor/Online-Thread.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Math/gfp.h"
|
||||
|
||||
#include "Tools/time-func.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
using namespace std;
|
||||
|
||||
class Machine
|
||||
{
|
||||
/* The mutex's lock the C-threads and then only release
|
||||
* then we an MPC thread is ready to run on the C-thread.
|
||||
* Control is passed back to the main loop when the
|
||||
* MPC thread releases the mutex
|
||||
*/
|
||||
|
||||
vector<thread_info> tinfo;
|
||||
vector<pthread_t> threads;
|
||||
|
||||
int my_number;
|
||||
Names N;
|
||||
gfp alphapi;
|
||||
gf2n alpha2i;
|
||||
|
||||
int nthreads;
|
||||
|
||||
ifstream inpf;
|
||||
|
||||
// Keep record of used offline data
|
||||
DataPositions pos;
|
||||
|
||||
int tn,numt;
|
||||
bool usage_unknown;
|
||||
|
||||
public:
|
||||
|
||||
vector<pthread_mutex_t> t_mutex;
|
||||
vector<pthread_cond_t> client_ready;
|
||||
vector<pthread_cond_t> server_ready;
|
||||
vector<Program> progs;
|
||||
|
||||
Memory<gf2n> M2;
|
||||
Memory<gfp> Mp;
|
||||
Memory<Integer> Mi;
|
||||
|
||||
std::map<int,Timer> timer;
|
||||
vector<Timer> join_timer;
|
||||
Timer finish_timer;
|
||||
|
||||
string prep_dir_prefix;
|
||||
string progname;
|
||||
|
||||
bool direct;
|
||||
int opening_sum;
|
||||
bool parallel;
|
||||
bool receive_threads;
|
||||
int max_broadcast;
|
||||
|
||||
Machine(int my_number, int PortnumBase, string hostname, string progname,
|
||||
string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel,
|
||||
bool receive_threads, int max_broadcast);
|
||||
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number);
|
||||
void join_tape(int thread_number);
|
||||
void run();
|
||||
};
|
||||
|
||||
#endif /* MACHINE_H_ */
|
||||
147
Processor/Memory.cpp
Normal file
147
Processor/Memory.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// (C) 2016 University of Bristol. See License.txt
|
||||
|
||||
#include "Processor/Memory.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/Integer.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#ifdef MEMPROTECT
|
||||
template<class T>
|
||||
void Memory<T>::protect_s(unsigned int start, unsigned int end)
|
||||
{
|
||||
protected_s.insert(pair<unsigned int,unsigned int>(start, end));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Memory<T>::protect_c(unsigned int start, unsigned int end)
|
||||
{
|
||||
protected_c.insert(pair<unsigned int,unsigned int>(start, end));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
bool Memory<T>::is_protected_s(unsigned int index)
|
||||
{
|
||||
for (set< pair<unsigned int,unsigned int> >::iterator it = protected_s.begin();
|
||||
it != protected_s.end(); it++)
|
||||
if (it->first <= index and it->second > index)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
bool Memory<T>::is_protected_c(unsigned int index)
|
||||
{
|
||||
for (set< pair<unsigned int,unsigned int> >::iterator it = protected_c.begin();
|
||||
it != protected_c.end(); it++)
|
||||
if (it->first <= index and it->second > index)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
template<class T>
|
||||
ostream& operator<<(ostream& s,const Memory<T>& M)
|
||||
{
|
||||
s << M.MS.size() << endl;
|
||||
s << M.MC.size() << endl;
|
||||
|
||||
#ifdef DEBUG
|
||||
for (unsigned int i=0; i<M.MS.size(); i++)
|
||||
{ M.MS[i].output(s,true); s << endl; }
|
||||
s << endl;
|
||||
|
||||
for (unsigned int i=0; i<M.MC.size(); i++)
|
||||
{ M.MC[i].output(s,true); s << endl; }
|
||||
s << endl;
|
||||
#else
|
||||
for (unsigned int i=0; i<M.MS.size(); i++)
|
||||
{ M.MS[i].output(s,false); }
|
||||
|
||||
for (unsigned int i=0; i<M.MC.size(); i++)
|
||||
{ M.MC[i].output(s,false); }
|
||||
#endif
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
istream& operator>>(istream& s,Memory<T>& M)
|
||||
{
|
||||
int len;
|
||||
|
||||
s >> len;
|
||||
M.resize_s(len);
|
||||
s >> len;
|
||||
M.resize_c(len);
|
||||
s.seekg(1, istream::cur);
|
||||
|
||||
for (unsigned int i=0; i<M.MS.size(); i++)
|
||||
{ M.MS[i].input(s,false); }
|
||||
|
||||
for (unsigned int i=0; i<M.MC.size(); i++)
|
||||
{ M.MC[i].input(s,false); }
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Load_Memory(Memory<T>& M,ifstream& inpf)
|
||||
{
|
||||
int a;
|
||||
T val;
|
||||
Share<T> S;
|
||||
|
||||
inpf >> a;
|
||||
M.resize_s(a);
|
||||
inpf >> a;
|
||||
M.resize_c(a);
|
||||
|
||||
cerr << "Reading Clear Memory" << endl;
|
||||
|
||||
// Read clear memory
|
||||
inpf >> a;
|
||||
val.input(inpf,true);
|
||||
while (a!=-1)
|
||||
{ M.write_C(a,val);
|
||||
inpf >> a;
|
||||
val.input(inpf,true);
|
||||
}
|
||||
cerr << "Reading Shared Memory" << endl;
|
||||
|
||||
// Read shared memory
|
||||
inpf >> a;
|
||||
S.input(inpf,true);
|
||||
while (a!=-1)
|
||||
{ M.write_S(a,S);
|
||||
inpf >> a;
|
||||
S.input(inpf,true);
|
||||
}
|
||||
}
|
||||
|
||||
template class Memory<gfp>;
|
||||
template class Memory<gf2n>;
|
||||
template class Memory<Integer>;
|
||||
|
||||
template istream& operator>>(istream& s,Memory<gfp>& M);
|
||||
template istream& operator>>(istream& s,Memory<gf2n>& M);
|
||||
template istream& operator>>(istream& s,Memory<Integer>& M);
|
||||
|
||||
template ostream& operator<<(ostream& s,const Memory<gfp>& M);
|
||||
template ostream& operator<<(ostream& s,const Memory<gf2n>& M);
|
||||
template ostream& operator<<(ostream& s,const Memory<Integer>& M);
|
||||
|
||||
template void Load_Memory(Memory<gfp>& M,ifstream& inpf);
|
||||
template void Load_Memory(Memory<gf2n>& M,ifstream& inpf);
|
||||
template void Load_Memory(Memory<Integer>& M,ifstream& inpf);
|
||||
|
||||
#ifdef USE_GF2N_LONG
|
||||
template class Memory<gf2n_short>;
|
||||
template istream& operator>>(istream& s,Memory<gf2n_short>& M);
|
||||
template ostream& operator<<(ostream& s,const Memory<gf2n_short>& M);
|
||||
template void Load_Memory(Memory<gf2n_short>& M,ifstream& inpf);
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user