SPDZ2k offline phase.

This commit is contained in:
Nikolaj Volgushev
2019-03-26 18:36:07 +11:00
committed by Marcel Keller
parent 5a0413de7a
commit 881b4403ac
49 changed files with 2125 additions and 380 deletions

1
.gitignore vendored
View File

@@ -11,6 +11,7 @@ keys/*
##############################
CONFIG.mine
config_mine.py
HOSTS
# Temporary files #
###################

View File

@@ -43,14 +43,20 @@ MAC_Check<T>::~MAC_Check()
{
}
template<class T>
void MAC_Check<T>::PrepareSending(vector<T>& values, const vector<Share<T> >& S)
{
values.resize(S.size());
for (unsigned int i=0; i<S.size(); i++)
{ values[i]=S[i].get_share(); }
}
template<class T>
void MAC_Check<T>::POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P)
{
AddToMacs(S);
values.resize(S.size());
for (unsigned int i=0; i<S.size(); i++)
{ values[i]=S[i].get_share(); }
PrepareSending(values, S);
this->start(values, P);
@@ -115,9 +121,9 @@ void MAC_Check<T>::CheckIfNeeded(const Player& P)
template <class T>
void MAC_Check<T>::AddToCheck(const T& mac, const T& value, const Player& P)
void MAC_Check<T>::AddToCheck(const Share<T>& share, const T& value, const Player& P)
{
macs.push_back(mac);
macs.push_back(share.get_mac());
vals.push_back(value);
popen_cnt++;
CheckIfNeeded(P);
@@ -179,6 +185,115 @@ int mc_base_id(int function_id, int thread_num)
return (function_id << 28) + ((T::field_type() + 1) << 24) + (thread_num << 16);
}
template<class T, class U, class V>
MAC_Check_Z2k<T, U, V>::MAC_Check_Z2k(const T& ai, const Share<T>& dummy_element, int opening_sum, int max_broadcast, int send_player) :
MAC_Check<T>(ai, opening_sum, max_broadcast, send_player),
dummy_element(dummy_element)
{
}
template<class T, class U, class V>
void MAC_Check_Z2k<T, U, V>::AddToCheck(const Share<T>& share, const T& value, const Player& P)
{
shares.push_back(share.get_share());
MAC_Check<T>::AddToCheck(share, value, P);
}
template<class T, class U, class V>
void MAC_Check_Z2k<T, U, V>::AddToMacs(const vector<Share<T> >& shares)
{
for (auto& share : shares)
this->shares.push_back(share.get_share());
MAC_Check<T>::AddToMacs(shares);
}
template<class T, class U, class V>
void MAC_Check_Z2k<T, U, V>::PrepareSending(vector<T>& values,
const vector<Share<T> >& S)
{
values.clear();
values.reserve(S.size());
for (auto& share : S)
values.push_back(V(share.get_share()));
}
template<class T, class U, class V>
Share<T> MAC_Check_Z2k<T, U, V>::get_random_element() {
return dummy_element;
}
template<class T, class U, class V>
void MAC_Check_Z2k<T, U, V>::set_random_element(const Share<T>& random_element) {
this->dummy_element = random_element;
}
template<class T, class U, class V>
void MAC_Check_Z2k<T, U, V>::Check(const Player& P)
{
if (this->WaitingForCheck() == 0)
return;
int k = V::N_BITS;
octet seed[SEED_SIZE];
Create_Random_Seed(seed,P,SEED_SIZE);
PRNG G;
G.SetSeed(seed);
T y, mj;
y.assign_zero();
mj.assign_zero();
vector<U> chi;
for (int i = 0; i < this->popen_cnt; ++i)
{
U temp_chi;
temp_chi.randomize(G);
T xi = this->vals[i];
y += xi * temp_chi;
T mji = this->macs[i];
mj += temp_chi * mji;
chi.push_back(temp_chi);
}
Share<T> r = get_random_element();
T lj = r.get_mac();
U pj;
pj.assign_zero();
for (int i = 0; i < this->popen_cnt; ++i)
{
T xji = shares[i];
V xbarji = xji;
U pji = U((xji - xbarji) >> k);
pj += chi[i] * pji;
}
pj += U(r.get_share());
U pbar(pj);
vector<octetStream> pj_stream(P.num_players());
pj.pack(pj_stream[P.my_num()]);
P.Broadcast_Receive(pj_stream, true);
for (int j=0; j<P.num_players(); j++) {
if (j!=P.my_num()) {
pbar += pj_stream[j].consume(U::size());
}
}
T zj = mj - (this->alphai * y) - (((this->alphai * pbar)) << k) + (lj << k);
vector<T> zjs(P.num_players());
zjs[P.my_num()] = zj;
Commit_And_Open(zjs, P);
T zj_sum;
zj_sum.assign_zero();
for (int i = 0; i < P.num_players(); ++i)
zj_sum += zjs[i];
this->vals.erase(this->vals.begin(), this->vals.begin() + this->popen_cnt);
this->macs.erase(this->macs.begin(), this->macs.begin() + this->popen_cnt);
this->shares.erase(this->shares.begin(), this->shares.begin() + this->popen_cnt);
this->popen_cnt=0;
if (!zj_sum.is_zero()) { throw mac_fail(); }
}
template<class T>
Separate_MAC_Check<T>::Separate_MAC_Check(const T& ai, Names& Nms,
int thread_num, int opening_sum, int max_broadcast, int send_player) :
@@ -381,4 +496,13 @@ template class Direct_MAC_Check<gf2n_short>;
template class Parallel_MAC_Check<gf2n_short>;
#endif
template class MAC_Check_Z2k<Z2<64>, Z2<32>, Z2<32> >;
template class MAC_Check_Z2k<Z2<128>, Z2<64>, Z2<64> >;
template class MAC_Check_Z2k<Z2<96>, Z2<32>, Z2<64> >;
template class MAC_Check_Z2k<Z2<160>, Z2<96>, Z2<64> >;
template class MAC_Check_Z2k<Z2<192>, Z2<64>, Z2<128> >;
template class MAC_Check_Z2k<Z2<256>, Z2<96>, Z2<160> >;
template class MAC_Check<Z2<96> >;
template class MAC_Check<Z2<160> >;
template class MAC_Check<Z2<192> >;
template class MAC_Check<Z2<256> >;

View File

@@ -76,7 +76,8 @@ class MAC_Check : public TreeSum<T>
/* MAC Share */
T alphai;
void AddToMacs(const vector< Share<T> >& shares);
virtual void AddToMacs(const vector< Share<T> >& shares);
virtual void PrepareSending(vector<T>& values,const vector<Share<T> >& S);
void AddToValues(vector<T>& values);
void GetValues(vector<T>& values);
void CheckIfNeeded(const Player& P);
@@ -99,7 +100,7 @@ class MAC_Check : public TreeSum<T>
*/
virtual void POpen_Begin(vector<T>& values,const vector<Share<T> >& S,const Player& P);
virtual void POpen_End(vector<T>& values,const vector<Share<T> >& S,const Player& P);
void AddToCheck(const T& mac, const T& value, const Player& P);
virtual void AddToCheck(const Share<T>& share, const T& value, const Player& P);
virtual void Check(const Player& P);
int number() const { return values_opened; }
@@ -107,6 +108,25 @@ class MAC_Check : public TreeSum<T>
const T& get_alphai() const { return alphai; }
};
template<class T, class U, class V>
class MAC_Check_Z2k : public MAC_Check<T>
{
protected:
vector<T> shares;
Share<T> dummy_element;
Share<T> get_random_element();
void AddToMacs(const vector< Share<T> >& shares);
void PrepareSending(vector<T>& values,const vector<Share<T> >& S);
public:
void AddToCheck(const Share<T>& share, const T& value, const Player& P);
MAC_Check_Z2k(const T& ai, const Share<T>& dummy_element = {}, int opening_sum=10, int max_broadcast=10, int send_player=0);
virtual void Check(const Player& P);
void set_random_element(const Share<T>& random_element);
virtual ~MAC_Check_Z2k() {};
};
template <class T, class U, class V>
using MAC_Check_ = MAC_Check<T>;

View File

@@ -252,4 +252,9 @@ template void Create_Random(gf2n_short& ans,const Player& P);
template void Commit_And_Open(vector<gfp>& data,const Player& P);
template void Create_Random(gfp& ans,const Player& P);
template void Commit_And_Open(vector<Z2<64> >& data,const Player& P);
template void Commit_And_Open(vector<Z2<96> >& data,const Player& P);
template void Commit_And_Open(vector<Z2<128> >& data,const Player& P);
template void Commit_And_Open(vector<Z2<160> >& data,const Player& P);
template void Commit_And_Open(vector<Z2<192> >& data,const Player& P);
template void Commit_And_Open(vector<Z2<256> >& data,const Player& P);

View File

@@ -55,6 +55,34 @@ void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key)
}
}
template<class T, class V>
void check_share(vector<Share<T> >& Sa,
V& value,
T& mac,
int N,
const T& key)
{
value.assign(0);
mac.assign(0);
for (int i=0; i<N; i++)
{
value.add(Sa[i].get_share());
mac.add(Sa[i].get_mac());
}
V res;
res.mul(value, key);
if (res != mac)
{
cout << "Value: " << value << endl;
cout << "Input MAC: " << mac << endl;
cout << "Actual MAC: " << res << endl;
cout << "MAC key: " << key << endl;
throw mac_fail();
}
}
template void make_share(vector<Share<gf2n> >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G);
template void make_share(vector<Share<gfp> >& Sa,const gfp& a,int N,const gfp& key,PRNG& G);
@@ -66,7 +94,26 @@ template void make_share(vector<Share<gf2n_short> >& Sa,const gf2n_short& a,int
template void check_share(vector<Share<gf2n_short> >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key);
#endif
template void check_share(vector<Share<Z2<64> > >& Sa,Z2<64>& value,Z2<64>& mac,int N,const Z2<64>& key);
template void check_share(
vector<Share<Z2<160> > >& Sa,
Z2<64>& value,
Z2<160>& mac,
int N,
const Z2<160>& key);
template void check_share(
vector<Share<Z2<128> > >& Sa,
Z2<64>& value,
Z2<128>& mac,
int N,
const Z2<128>& key);
template void check_share(
vector<Share<Z2<64> > >& Sa,
Z2<32>& value,
Z2<64>& mac,
int N,
const Z2<64>& key);
// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40)
void expand_byte(gf2n_short& a,int b)

View File

@@ -6,6 +6,7 @@
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/Z2k.h"
#include "Math/Share.h"
#include <fstream>
@@ -17,6 +18,13 @@ void make_share(vector<Share<T> >& Sa,const T& a,int N,const T& key,PRNG& G);
template<class T>
void check_share(vector<Share<T> >& Sa,T& value,T& mac,int N,const T& key);
template<class T, class V>
void check_share(vector<Share<T> >& Sa,
V& value,
T& mac,
int N,
const T& key);
void expand_byte(gf2n_short& a,int b);
void collapse_byte(int& b,const gf2n_short& a);

18
CONFIG
View File

@@ -13,17 +13,29 @@ PREP_DIR = '-DPREP_DIR="Player-Data/"'
# set for 128-bit GF(2^n) and/or OT preprocessing
USE_GF2N_LONG = 0
# set SPDZ_2K bit length parameters K and S
SPDZ2K_K = -DSPDZ2K_K=64
SPDZ2K_S = -DSPDZ2K_S=64
# use additional optimizations in vole protocol
USE_OPT_VOLE = 1
NUM_VOLE_CHALLENGES = -DNUM_VOLE_CHALLENGES=3
# set to -march=<architecture> for optimization
# AVX2 support (Haswell or later) changes the bit matrix transpose
ARCH = -mtune=native -mavx
ARCH = -mtune=native -mavx -march=native
#use CONFIG.mine to overwrite DIR settings
# use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine
ifeq ($(USE_GF2N_LONG),1)
GF2N_LONG = -DUSE_GF2N_LONG
endif
ifeq ($(USE_OPT_VOLE),1)
OPT_VOLE = -DUSE_OPT_VOLE
endif
# MAX_MOD_SZ must be at least ceil(len(p)/len(word))
# Default is 2, which suffices for 128-bit p
# MOD = -DMAX_MOD_SZ=2
@@ -40,7 +52,7 @@ LDLIBS += -lrt
endif
CXX = g++
CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 --std=c++11 -Werror
CFLAGS = $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MOD) $(SPDZ2K_K) $(SPDZ2K_S) $(OPT_VOLE) $(NUM_VOLE_CHALLENGES) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 -mbmi2 --std=c++11 -Werror -no-pie
CPPFLAGS = $(CFLAGS)
LD = g++

89
Check-Offline-Z2k.cpp Normal file
View File

@@ -0,0 +1,89 @@
// (C) 2018 University of Bristol. See License.txt
#include "Math/Z2k.h"
#include "Math/Share.h"
#include "Math/Setup.h"
#include "Auth/fake-stuff.h"
#include <fstream>
#include <vector>
#include <numeric>
template <class T, class U, class V>
void check_triples_Z2k(int n_players, string type_char = "")
{
T keyp; keyp.assign_zero();
U pp;
ifstream inpf;
for (int i= 0; i < n_players; i++)
{
stringstream ss;
ss << get_prep_dir(n_players, 128, 128) << "Player-MAC-Key-";
if (type_char.size())
ss << type_char;
else
ss << U::type_char();
ss << "-P" << i;
cout << "Opening file " << ss.str() << endl;
inpf.open(ss.str().c_str());
if (inpf.fail()) { throw file_error(ss.str()); }
pp.input(inpf,false);
cout << " Key " << i << "\t p: " << pp << endl;
keyp.add(pp);
inpf.close();
}
cout << "--------------\n";
cout << "Final Keys :\t p: " << keyp << endl;
ifstream* inputFiles = new ifstream[n_players];
for (int i = 0; i < n_players; i++)
{
stringstream ss;
ss << get_prep_dir(n_players, 128, 128) << "Triples-";
if (type_char.size())
ss << type_char;
else
ss << T::type_char();
ss << "-P" << i;
inputFiles[i].open(ss.str().c_str());
cout << "Opening file " << ss.str() << endl;
}
int j = 0;
while (inputFiles[0].peek() != EOF)
{
V a,b,c,prod;
T mac;
vector<Share<T>> as(n_players), bs(n_players), cs(n_players);
for (int i = 0; i < n_players; i++)
{
as[i].input(inputFiles[i], false);
bs[i].input(inputFiles[i], false);
cs[i].input(inputFiles[i], false);
}
check_share<T, V>(as, a, mac, n_players, keyp);
check_share<T, V>(bs, b, mac, n_players, keyp);
check_share<T, V>(cs, c, mac, n_players, keyp);
prod.mul(a, b);
if (prod != c)
{
cout << j << ": " << c << " != " << a << " * " << b << endl;
throw bad_value();
}
j++;
}
cout << dec << j << " correct triples of type " << T::type_string() << endl;
delete[] inputFiles;
}
int main(int argc, char** argv)
{
int n_players = 2;
if (argc > 1)
n_players = atoi(argv[1]);
check_triples_Z2k<Z2<SPDZ2K_K + SPDZ2K_S>, Z2<SPDZ2K_S>, Z2<SPDZ2K_K>>(n_players);
}

View File

@@ -297,11 +297,11 @@ int main(int argc, const char** argv)
check_tuples(key2, N, dataF, DATA_INVERSE);
check_tuples(keyp, N, dataF, DATA_INVERSE);
Z2<64> keyz2k;
for (int i = 0; i < N; i++)
keyz2k += read_mac_key<Z2<96> >(PREP_DATA_PREFIX, i);
// Z2<160> keyz2k;
// for (int i = 0; i < N; i++)
// keyz2k += read_mac_key<Z2<96> >(PREP_DATA_PREFIX, i);
check_mult_triples(keyz2k, N, dataF, DATA_Z2K);
// check_mult_triples(keyz2k, N, dataF, DATA_Z2K);
for (int i = 0; i < N; i++)
delete dataF[i];

View File

@@ -94,6 +94,10 @@ class mac_fail: public exception
{ virtual const char* what() const throw()
{ return "MacCheck Failure"; }
};
class consistency_check_fail: public exception
{ virtual const char* what() const throw()
{ return "OT consistency check failed"; }
};
class invalid_program: public exception
{ virtual const char* what() const throw()
{ return "Invalid Program"; }

View File

@@ -291,7 +291,6 @@ void MachineBase::run()
<< " kbit per " << item_type().substr(0, item_type().length() - 1) << endl;
cout << "Produced " << total << " " << item_type() << " in "
<< timer.elapsed() << " seconds" << endl;
cout << "Throughput: " << total / timer.elapsed() << tradeoff() << endl;
cout << "CPU time: " << cpu_timer.elapsed() << endl;
extern unsigned long long sent_amount, sent_counter;
@@ -300,6 +299,8 @@ void MachineBase::run()
cout << sent_amount / sent_counter / N.num_players() << " bytes per call"
<< endl;
cout << "Time: " << timer.elapsed() << endl;
cout << "Throughput: " << total / timer.elapsed() << endl;
mult_performance();
}

View File

@@ -27,14 +27,14 @@ COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH)
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(OT)
LIB = libSPDZ.a
LIBSIMPLEOT = SimpleOT/libsimpleot.a
LIBSIMPLEOT = -no-pie SimpleOT/libsimpleot.a
# used for dependency generation
OBJS = $(COMPLETE)
DEPS := $(OBJS:.o=.d)
all: gen_input online offline externalIO
all: gen_input online offline externalIO check-passive.x Check-Offline-Z2k.x
ifeq ($(USE_NTL),1)
all: overdrive she-offline
@@ -63,6 +63,9 @@ Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR)
Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS)
Check-Offline-Z2k.x: Check-Offline-Z2k.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) Check-Offline-Z2k.cpp -o Check-Offline-Z2k.x $(COMMON) $(PROCESSOR) $(LDLIBS)
Server.x: Server.cpp $(COMMON)
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)

View File

@@ -2,9 +2,9 @@
#include "Share.h"
//#include "Tools/random.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Math/Z2k.h"
#include "Math/operators.h"
@@ -141,3 +141,8 @@ template class Share<gf2n_short>;
template gf2n_short combine(const vector< Share<gf2n_short> >& S);
template bool check_macs(const vector< Share<gf2n_short> >& S,const gf2n_short& key);
#endif
template class Share<Z2<96> >;
template class Share<Z2<160> >;
template class Share<Z2<192> >;
template class Share<Z2<256> >;

View File

@@ -9,7 +9,7 @@ template<int K>
Z2<K>::Z2(const bigint& x) : Z2()
{
auto mp = x.get_mpz_t();
memcpy(a, mp->_mp_d, sizeof(mp_limb_t) * min(N_WORDS, abs(mp->_mp_size)));
memcpy(a, mp->_mp_d, min((size_t)N_BYTES, sizeof(mp_limb_t) * abs(mp->_mp_size)));
if (mp->_mp_size < 0)
*this = Z2<K>() - *this;
}
@@ -20,31 +20,6 @@ bool Z2<K>::get_bit(int i) const
return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS));
}
template<int K>
Z2<K> Z2<K>::operator+(const Z2<K>& other) const
{
Z2<K> res;
mpn_add(res.a, a, N_WORDS, other.a, N_WORDS);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template<int K>
Z2<K> Z2<K>::operator-(const Z2<K>& other) const
{
Z2<K> res;
mpn_sub(res.a, a, N_WORDS, other.a, N_WORDS);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template <int K>
Z2<K>& Z2<K>::operator+=(const Z2<K>& other)
{
*this = *this + other;
return *this;
}
template <int K>
Z2<K> Z2<K>::operator<<(int i) const
{
@@ -52,7 +27,23 @@ Z2<K> Z2<K>::operator<<(int i) const
int n_limb_shift = i / N_LIMB_BITS;
for (int j = n_limb_shift; j < N_WORDS; j++)
res.a[j] = a[j - n_limb_shift];
mpn_lshift(res.a, res.a, N_WORDS, i % N_LIMB_BITS);
int n_inside_shift = i % N_LIMB_BITS;
if (n_inside_shift > 0)
mpn_lshift(res.a, res.a, N_WORDS, n_inside_shift);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template <int K>
Z2<K> Z2<K>::operator>>(int i) const
{
Z2<K> res;
int n_limb_shift = i / N_LIMB_BITS;
for (int j = 0; j < N_WORDS - n_limb_shift; j++)
res.a[j] = a[j + n_limb_shift];
int n_inside_shift = i % N_LIMB_BITS;
if (n_inside_shift > 0)
mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
@@ -60,13 +51,17 @@ Z2<K> Z2<K>::operator<<(int i) const
template<int K>
bool Z2<K>::operator==(const Z2<K>& other) const
{
#ifdef DEBUG_MPN
for (int i = 0; i < N_WORDS; i++)
cout << "cmp " << hex << a[i] << " " << other.a[i] << endl;
#endif
return mpn_cmp(a, other.a, N_WORDS) == 0;
}
template<int K>
void Z2<K>::randomize(PRNG& G)
{
G.get_octets((octet*)a, N_BYTES);
G.get_octets_<N_BYTES>((octet*)a);
}
template<int K>
@@ -98,6 +93,26 @@ void Z2<K>::output(ostream& s, bool human) const
s.write((char*)a, N_BYTES);
}
#define NS X(32) X(64) X(96) X(128) X(160) X(192) X(224) X(256) X(320) X(672)
#define X(N) template class Z2<N>;
NS
template <int K>
ostream& operator<<(ostream& o, const Z2<K>& x)
{
bool printing = false;
o << "0x" << noshowbase;
o.width(0);
for (int i = x.N_WORDS - 1; i >= 0; i--)
if (x.a[i] or printing or i == 0)
{
o << hex << x.a[i];
printing = true;
o.width(16);
o.fill('0');
}
o.width(0);
return o;
}
#define X(N) \
template class Z2<N>; \
template ostream& operator<<(ostream& o, const Z2<N>& x);
X(32) X(64) X(96) X(128) X(160) X(192) X(224) X(256) X(288) X(320) X(352) X(384) X(416) X(448) X(512) X(672)

View File

@@ -13,6 +13,7 @@ using namespace std;
#include "Tools/avx_memcpy.h"
#include "bigint.h"
#include "field_types.h"
#include "mpn_fixed.h"
template <int K>
class Z2
@@ -28,6 +29,7 @@ class Z2
mp_limb_t a[N_WORDS];
public:
static const int N_BITS = K;
static const int N_BYTES = (K + 7) / 8;
@@ -40,12 +42,15 @@ public:
static DataFieldType field_type() { return DATA_Z2K; }
template <int L, int M>
static Z2<K> Mul(const Z2<L>& x, const Z2<M>& y);
typedef Z2<K> value_type;
Z2() { assign_zero(); }
Z2(uint64_t x) : Z2() { a[0] = x; }
Z2(__m128i x) : Z2() { avx_memcpy(a, &x, min(N_BYTES, 16)); }
Z2(int x) : Z2(x < 0 ? bigint(x) : uint64_t(x)) {}
Z2(int x) : Z2() { if (x < 0) *this = bigint(x); else *this = uint64_t(x); }
Z2(const bigint& x);
Z2(const void* buffer) : Z2() { assign(buffer); }
template <int L>
@@ -61,17 +66,23 @@ public:
const void* get_ptr() const { return a; }
void negate() {
throw not_implemented();
}
Z2<K> operator+(const Z2<K>& other) const;
Z2<K> operator-(const Z2<K>& other) const;
template <int L>
Z2<K+L> operator*(const Z2<L>& other) const;
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(0); }
Z2<K> operator*(bool other) const { return other ? *this : Z2<K>(); }
Z2<K>& operator+=(const Z2<K>& other);
Z2<K>& operator-=(const Z2<K>& other);
Z2<K> operator<<(int i) const;
Z2<K> operator>>(int i) const;
bool operator==(const Z2<K>& other) const;
bool operator!=(const Z2<K>& other) const { return not (*this == other); }
@@ -79,8 +90,8 @@ public:
void add(const Z2<K>& a, const Z2<K>& b) { *this = a + b; }
void add(const Z2<K>& a) { *this += a; }
void sub(const Z2<K>& a, const Z2<K>& b) { *this = a - b; }
template <int L>
void mul(const Z2<K>& a, const Z2<L>& b) { *this = a * b; }
template <int M, int L>
void mul(const Z2<M>& a, const Z2<L>& b) { *this = Z2<K>::Mul(a, b); }
template <int t>
void add(octetStream& os) { add(os.consume(size())); }
@@ -100,26 +111,54 @@ public:
friend ostream& operator<<(ostream& o, const Z2<J>& x);
};
template<int K>
inline Z2<K> Z2<K>::operator+(const Z2<K>& other) const
{
Z2<K> res;
mpn_add_fixed_n<N_WORDS>(res.a, a, other.a);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template<int K>
Z2<K> Z2<K>::operator-(const Z2<K>& other) const
{
Z2<K> res;
mpn_sub_fixed_n<N_WORDS>(res.a, a, other.a);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template <int K>
inline Z2<K>& Z2<K>::operator+=(const Z2<K>& other)
{
mpn_add_fixed_n<N_WORDS>(a, other.a, a);
a[N_WORDS - 1] &= UPPER_MASK;
return *this;
}
template <int K>
Z2<K>& Z2<K>::operator-=(const Z2<K>& other)
{
*this = *this - other;
return *this;
}
template <int K>
template <int L, int M>
inline Z2<K> Z2<K>::Mul(const Z2<L>& x, const Z2<M>& y)
{
Z2<K> res;
mpn_mul_fixed_<N_WORDS, x.N_WORDS, y.N_WORDS>(res.a, x.a, y.a);
res.a[N_WORDS - 1] &= UPPER_MASK;
return res;
}
template <int K>
template <int L>
inline Z2<K+L> Z2<K>::operator*(const Z2<L>& other) const
{
mp_limb_t product[N_WORDS + other.N_WORDS];
if (K < L)
mpn_mul(product, other.a, other.N_WORDS, a, N_WORDS);
else
mpn_mul(product, a, N_WORDS, other.a, other.N_WORDS);
Z2<K+L> res;
avx_memcpy(res.a, product, res.N_BYTES);
return res;
}
template <int K>
inline ostream& operator<<(ostream& o, const Z2<K>& x)
{
for (int i = 0; i < x.N_WORDS; i++)
o << hex << x.a[i] << " ";
return o;
return Z2<K+L>::Mul(*this, other);
}
#endif /* MATH_Z2K_H_ */

View File

@@ -14,7 +14,7 @@ bool modp::rewind = false;
void modp::randomize(PRNG& G, const Zp_Data& ZpD)
{
bigint x=G.randomBnd(ZpD.pr);
memcpy(this->x, x.get_mpz_t()->_mp_d, ZpD.get_t());
memcpy(this->x, x.get_mpz_t()->_mp_d, ZpD.get_t() * sizeof(mp_limb_t));
}
void modp::pack(octetStream& o,const Zp_Data& ZpD) const

219
Math/mpn_fixed.h Normal file
View File

@@ -0,0 +1,219 @@
/*
* mpn_fixed.h
*
*/
#ifndef MATH_MPN_FIXED_H_
#define MATH_MPN_FIXED_H_
#include <mpir.h>
#include <string.h>
#include <assert.h>
inline void debug_print(const char* name, const mp_limb_t* x, int n)
{
(void)name, (void)x, (void)n;
#ifdef DEBUG_MPN
cout << name << " ";
for (int i = 0; i < n; i++)
cout << hex << x[n-i-1] << " ";
cout << endl;
#endif
}
template <int N>
inline void mpn_add_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
mpn_add(res, x, N, y, N);
}
template <>
inline void mpn_add_fixed_n<1>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
*res = *x + *y;
}
template <>
inline void mpn_add_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, y, 2 * sizeof(mp_limb_t));
debug_print("x", x, 2);
debug_print("y", y, 2);
debug_print("res", res, 2);
__asm__ (
"add %2, %0 \n"
"adc %3, %1 \n"
: "+&r"(res[0]), "+r"(res[1])
: "rm"(x[0]), "rm"(x[1])
: "cc"
);
debug_print("res", res, 2);
}
template <>
inline void mpn_add_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, y, 3 * sizeof(mp_limb_t));
debug_print("x", x, 3);
debug_print("y", y, 3);
debug_print("res", res, 3);
__asm__ (
"add %3, %0 \n"
"adc %4, %1 \n"
"adc %5, %2 \n"
: "+&r"(res[0]), "+&r"(res[1]), "+r"(res[2])
: "rm"(x[0]), "rm"(x[1]), "rm"(x[2])
: "cc"
);
debug_print("res", res, 3);
}
template <>
inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, y, 4 * sizeof(mp_limb_t));
__asm__ (
"add %4, %0 \n"
"adc %5, %1 \n"
"adc %6, %2 \n"
"adc %7, %3 \n"
: "+&r"(res[0]), "+&r"(res[1]), "+&r"(res[2]), "+r"(res[3])
: "rm"(x[0]), "rm"(x[1]), "rm"(x[2]), "rm"(x[3])
: "cc"
);
}
template <int N>
inline void mpn_sub_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
mpn_add(res, x, N, y, N);
}
template <>
inline void mpn_sub_fixed_n<1>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
*res = *x - *y;
}
template <>
inline void mpn_sub_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, x, 2 * sizeof(mp_limb_t));
__asm__ (
"sub %2, %0 \n"
"sbb %3, %1 \n"
: "+r"(res[0]), "+r"(res[1])
: "rm"(y[0]), "rm"(y[1])
: "cc"
);
}
template <>
inline void mpn_sub_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, x, 3 * sizeof(mp_limb_t));
__asm__ volatile (
"sub %3, %0 \n"
"sbb %4, %1 \n"
"sbb %5, %2 \n"
: "+r"(res[0]), "+r"(res[1]), "+r"(res[2])
: "rm"(y[0]), "rm"(y[1]), "rm"(y[2])
: "cc"
);
}
template <>
inline void mpn_sub_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, x, 4 * sizeof(mp_limb_t));
__asm__ volatile (
"sub %4, %0 \n"
"sbb %5, %1 \n"
"sbb %6, %2 \n"
"sbb %7, %3 \n"
: "+r"(res[0]), "+r"(res[1]), "+r"(res[2]), "+r"(res[3])
: "rm"(y[0]), "rm"(y[1]), "rm"(y[2]), "rm"(y[3])
: "cc"
);
}
inline void mpn_add_n_use_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y, mp_size_t n)
{
switch (n)
{
#define CASE(N) \
case N: \
mpn_add_fixed_n<N>(res, x, y); \
break;
CASE(1);
CASE(2);
CASE(3);
CASE(4);
default:
mpn_add_n(res, x, y, n);
break;
}
}
template <int L, int M>
inline void mpn_addmul_1_fixed_(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x)
{
mp_limb_t lower[L], higher[L];
lower[L - 1] = 0;
higher[L - 1] = 0;
for (int j = 0; j < M; j++)
lower[j] = _mulx_u64(x, y[j], (long long unsigned*)higher + j);
debug_print("lower", lower, L);
debug_print("higher", higher, L);
debug_print("before add", res, L + 1);
mpn_add_fixed_n<L>(res, lower, res);
debug_print("first add", res, L + 1);
mpn_add_fixed_n<L - 1>(res + 1, higher, res + 1);
debug_print("second add", res, L + 1);
}
template <int M>
inline void mpn_addmul_1_fixed(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x)
{
mpn_addmul_1_fixed_<M + 1, M>(res, y, x);
}
template <int L, int N, int M>
inline void mpn_mul_fixed_(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
assert(L <= N + M + 2);
mp_limb_t tmp[N + M + 2];
avx_memzero(tmp, sizeof(tmp));
for (int i = 0; i < N; i++)
mpn_addmul_1_fixed<M>(tmp + i, y, x[i]);
inline_mpn_copyi(res, tmp, L);
}
template <>
inline void mpn_mul_fixed_<3,3,3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
inline_mpn_zero(res, 3);
mp_limb_t* tmp = res;
mpn_addmul_1_fixed_<3,3>(tmp, y, x[0]);
mpn_addmul_1_fixed_<2,2>(tmp + 1, y, x[1]);
mpn_addmul_1_fixed_<1,1>(tmp + 2, y, x[2]);
}
template <>
inline void mpn_mul_fixed_<4,4,2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
inline_mpn_zero(res, 4);
mp_limb_t* tmp = res;
mpn_addmul_1_fixed_<3,2>(tmp, y, x[0]);
mpn_addmul_1_fixed_<3,2>(tmp + 1, y, x[1]);
mpn_addmul_1_fixed_<2,2>(tmp + 2, y, x[2]);
mpn_addmul_1_fixed_<1,1>(tmp + 3, y, x[3]);
}
template <int N, int M>
inline void mpn_mul_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
mpn_mul_fixed_<N + M, N, M>(res, x, y);
}
#endif /* MATH_MPN_FIXED_H_ */

View File

@@ -192,8 +192,7 @@ void square128::transpose()
} \
}
#ifdef __AVX2__
#define SIXTEENTOSIXTYFOUR Y(16) Y(32) Y(64)
#define Y(I) { \
#define Z(I) { \
const int J = I / 8; \
for (int i = 0; i < 16 / J; i++) \
{ \
@@ -216,7 +215,7 @@ void square128::transpose()
base += 16;
X(8)
base = k * 16;
SIXTEENTOSIXTYFOUR
Z(16) Z(32) Z(64)
for (int i = 0; i < 8; i++)
{
int a = base + i;
@@ -335,7 +334,10 @@ void square128::to(gfp& result)
for (int i = 0; i < 128; i++)
{
memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i]));
mpn_lshift(product, tmp[i/64], 4, i % 64);
if (i % 64 == 0)
memcpy(product, tmp[i/64], sizeof(product));
else
mpn_lshift(product, tmp[i/64], 4, i % 64);
mpn_add_n(sum, product, sum, 4);
}
mp_limb_t q[4], ans[4];
@@ -567,7 +569,7 @@ Matrix<U>& Matrix<U>::operator=(const Matrix<V>& other)
avx_memcpy(seed, &source, min(SEED_SIZE, (int)sizeof(source)));
prng.SetSeed(seed);
dest = 0;
prng.get_octets((octet*)dest.get_ptr(), U::N_ROW_BYTES);
prng.get_octets_<U::N_ROW_BYTES>((octet*)dest.get_ptr());
}
}
return *this;
@@ -718,7 +720,7 @@ Slice<U>& Slice<U>::rsub(Slice<U>& other)
if (bm.squares.size() < other.end)
throw invalid_length();
for (size_t i = other.start; i < other.end; i++)
bm.squares[i].rsub<T>(other.bm.squares[i]);
bm.squares[i].template rsub<T>(other.bm.squares[i]);
return *this;
}
@@ -727,7 +729,7 @@ template <class T>
Slice<U>& Slice<U>::sub(BitVector& other, int repeat)
{
if (end * U::PartType::N_COLUMNS > other.size() * repeat)
throw invalid_length();
throw invalid_length(to_string(U::PartType::N_COLUMNS));
for (size_t i = start; i < end; i++)
{
bm.squares[i].template sub<T>(other.get_ptr_to_byte(i / repeat,
@@ -741,7 +743,7 @@ template <class T>
void Slice<U>::randomize(int row, PRNG& G)
{
for (size_t i = start; i < end; i++)
bm.squares[i].randomize<T>(row, G);
bm.squares[i].template randomize<T>(row, G);
}
template <class U>
@@ -749,7 +751,7 @@ template <class T>
void Slice<U>::conditional_add(BitVector& conditions, U& other, bool useOffset)
{
for (size_t i = start; i < end; i++)
bm.squares[i].conditional_add<T>(conditions, other.squares[i], useOffset * i);
bm.squares[i].template conditional_add<T>(conditions, other.squares[i], useOffset * i);
}
template <>
@@ -774,6 +776,7 @@ void Slice<U>::print()
template <class U>
void Slice<U>::pack(octetStream& os) const
{
os.reserve(U::PartType::size() * (end - start));
for (size_t i = start; i < end; i++)
bm.squares[i].pack(os);
}
@@ -801,9 +804,11 @@ template void Slice<Matrix<Rectangle<Z2<N>, Z2<L> > > >::conditional_add< \
Z2<L> >(BitVector& conditions, \
Matrix<Rectangle<Z2<N>, Z2<L> > >& other, bool useOffset); \
X(96, 160)
//X(96, 160)
Y(64, 96)
Y(64, 64)
Y(32, 32)
template class Matrix<square128>;

View File

@@ -20,10 +20,14 @@
using namespace std;
union square128 {
typedef int128 RowType;
const static int N_ROWS = 128;
const static int N_COLUMNS = 128;
const static int N_ROW_BYTES = 128 / 8;
static size_t size() { return N_ROWS * sizeof(__m128i); }
#ifdef __AVX2__
__m256i doublerows[64];
#endif

View File

@@ -10,6 +10,7 @@
using namespace std;
#include <stdlib.h>
#include <pmmintrin.h>
#include <assert.h>
#include "Exceptions/Exceptions.h"
#include "Networking/data.h"
@@ -161,6 +162,7 @@ class BitVector
bool get_bit(int i) const
{
assert(i < (int)nbits);
return (bytes[i/8] >> (i % 8)) & 1;
}
void set_bit(int i,unsigned int a)
@@ -216,7 +218,7 @@ class BitVector
void pack(octetStream& o) const;
void unpack(octetStream& o);
string str(size_t end = SIZE_MAX)
string str(size_t end = SIZE_MAX) const
{
stringstream ss;
ss << hex;

View File

@@ -68,15 +68,19 @@ public:
template <int M>
void amplify(BitVector& a, T& b, Rectangle<Z2<M>, T>& c, PRNG& G)
{
avx_memzero(this, sizeof(*this));
assert(a.size() == M);
this->b = b;
for (int i = 0; i < N; i++)
{
T r;
r.randomize(G);
this->a[i] += r * a.get_bit(i);
this->c[i] += r * c.rows[i];
this->a[i] = 0;
this->c[i] = 0;
for (int j = 0; j < M; j++)
{
T r;
r.randomize(G);
this->a[i] += r * a.get_bit(j);
this->c[i] += r * c.rows[j];
}
}
}
@@ -106,8 +110,45 @@ public:
}
};
template <class T, int N>
class ShareTriple : public Triple<Share<T>, N>
// T is Z2<K + 2S>, U is Z2<K + S>
template <class T, class U, int N>
class PlainTriple_ : public PlainTriple<T,N>
{
public:
template <int M>
void amplify(BitVector& a, U& b, Rectangle<Z2<M>, U>& c, PRNG& G)
{
assert(a.size() == M);
this->b = b;
for (int i = 0; i < N; i++)
{
U aa = 0, cc = 0;
for (int j = 0; j < M; j++)
{
U r;
r.randomize(G);
if (a.get_bit(j))
aa += r;
cc += U::Mul(r, c.rows[j]);
}
this->a[i] = aa;
this->c[i] = cc;
}
}
void to(vector<BitVector>& valueBits, int i)
{
for (int j = 0; j < N; j++)
{
valueBits[0].set_portion(i * N + j, this->a[j]);
valueBits[2].set_portion(i * N + j, this->c[j]);
}
}
};
template <class T, class U, int N>
class ShareTriple_ : public Triple<Share<T>, N>
{
public:
void from(PlainTriple<T,N>& triple, vector<OTMultiplierBase*>& ot_multipliers,
@@ -119,9 +160,10 @@ public:
for (int j = 0; j < repeat; j++)
{
T value = triple.byIndex(l,j);
T mac = value * generator.machine.get_mac_key<T>();
T mac;
mac.mul(value, generator.machine.get_mac_key<U>());
for (int i = 0; i < generator.nparties-1; i++)
mac += ((MascotMultiplier<T>*)ot_multipliers[i])->macs[l][iTriple * repeat + j];
mac += ((OTMultiplierMac<T>*)ot_multipliers[i])->macs.at(l).at(iTriple * repeat + j);
Share<T>& share = this->byIndex(l,j);
share.set_share(value);
share.set_mac(mac);
@@ -129,9 +171,59 @@ public:
}
}
T computeCheckMAC(const T& maskedA)
Share<T> get_check_value(PRNG& G)
{
return this->c[0].get_mac() - maskedA * this->b.get_mac();
Share<T> res;
res += G.get<T>() * this->b;
for (int i = 0; i < N; i++)
{
res += G.get<T>() * this->a[i];
res += G.get<T>() * this->c[i];
}
return res;
}
template<class V>
Triple<Share<V>, 1> reduce() {
Triple<Share<V>, 1> triple;
Share<V> _a;
_a.set_share(V(this->a[0].get_share()));
_a.set_mac(V(this->a[0].get_mac()));
triple.a[0] = _a;
Share<V> _b;
_b.set_share(V(this->b.get_share()));
_b.set_mac(V(this->b.get_mac()));
triple.b = _b;
Share<V> _c;
_c.set_share(V(this->c[0].get_share()));
_c.set_mac(V(this->c[0].get_mac()));
triple.c[0] = _c;
return triple;
}
};
template <class T>
class TripleToSacrifice : public Triple<Share<T>, 1>
{
public:
template <class U>
void prepare_sacrifice(const ShareTriple_<T, U, 2>& uncheckedTriple, PRNG& G)
{
this->b = uncheckedTriple.b;
U t;
t.randomize(G);
this->a[0] = uncheckedTriple.a[0] * t - uncheckedTriple.a[1];
this->c[0] = uncheckedTriple.c[0] * t - uncheckedTriple.c[1];
}
Share<T> computeCheckShare(const T& maskedA)
{
return this->c[0] - maskedA * this->b;
}
};
@@ -221,11 +313,23 @@ OTMultiplierBase* NPartyTripleGenerator::new_multiplier(int i)
}
template<>
OTMultiplierBase* NPartyTripleGenerator::new_multiplier<Z2<64> >(int i)
OTMultiplierBase* NPartyTripleGenerator::new_multiplier<Z2<160> >(int i)
{
return new Spdz2kMultiplier<64, 96>(*this, i);
}
template<>
OTMultiplierBase* NPartyTripleGenerator::new_multiplier<Z2<128> >(int i)
{
return new Spdz2kMultiplier<64, 64>(*this, i);
}
template<>
OTMultiplierBase* NPartyTripleGenerator::new_multiplier<Z2<64> >(int i)
{
return new Spdz2kMultiplier<32, 32>(*this, i);
}
template<class T>
void NPartyTripleGenerator::generate()
{
@@ -300,11 +404,9 @@ void NPartyTripleGenerator::generateBits<gf2n>(vector< OTMultiplierBase* >& ot_m
valueBits[0].randomize_blocks<gf2n>(share_prg);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
signal_multipliers();
timers["Authentication OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
wait_for_multipliers();
timers["Authentication OTs"].stop();
octet seed[SEED_SIZE];
@@ -334,8 +436,7 @@ void NPartyTripleGenerator::generateBits<gf2n>(vector< OTMultiplierBase* >& ot_m
for (int j = 0; j < nTriplesPerLoop; j++)
bits[j].output(outputFile, false);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
signal_multipliers();
}
}
@@ -355,58 +456,107 @@ void NPartyTripleGenerator::generateBits(vector< OTMultiplierBase* >& ot_multipl
throw not_implemented();
}
template<int K, int S>
void NPartyTripleGenerator::generateTriplesZ2k(vector< OTMultiplierBase* >& ot_multipliers,
ofstream& outputFile)
{
(void) outputFile;
const int TAU = Spdz2kMultiplier<K, S>::TAU;
valueBits.resize(3);
for (int i = 0; i < 2; i++)
valueBits[2*i].resize(TAU * nTriplesPerLoop);
valueBits[1].resize((K + S) * (nTriplesPerLoop + 1));
b_padded_bits.resize((K + 2 * S) * (nTriplesPerLoop + 1));
vector< PlainTriple_<Z2<K + 2 * S>, Z2<K + S>, 2> > amplifiedTriples(nTriplesPerLoop);
vector< ShareTriple_<Z2<K + 2 * S>, Z2<S>, 2> > uncheckedTriples(nTriplesPerLoop);
MAC_Check_Z2k<Z2<K + 2 * S>, Z2<S>, Z2<K + S> > MC(machine.get_mac_key<Z2<S> >());
start_progress();
for (int k = 0; k < nloops; k++)
{
print_progress(k);
for (int j = 0; j < 2; j++)
valueBits[j].randomize_blocks<gf2n>(share_prg);
for (int j = 0; j < nTriplesPerLoop + 1; j++)
{
Z2<K + S> b(valueBits[1].get_ptr_to_bit(j, K + S));
b_padded_bits.set_portion(j, Z2<K + 2 * S>(b));
}
timers["OTs"].start();
wait_for_multipliers();
timers["OTs"].stop();
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
PRNG G;
G.SetSeed(seed);
for (int j = 0; j < nTriplesPerLoop; j++)
{
BitVector a(valueBits[0].get_ptr_to_bit(j, TAU), TAU);
Z2<K + S> b(valueBits[1].get_ptr_to_bit(j, K + S));
Z2kRectangle<TAU, K + S> c;
c.mul(a, b);
timers["Triple computation"].start();
for (int i = 0; i < nparties-1; i++)
{
c += ((Spdz2kMultiplier<K, S>*)ot_multipliers[i])->c_output[j];
}
timers["Triple computation"].stop();
amplifiedTriples[j].amplify(a, b, c, G);
amplifiedTriples[j].to(valueBits, j);
}
signal_multipliers();
wait_for_multipliers();
for (int j = 0; j < nTriplesPerLoop; j++)
{
uncheckedTriples[j].from(amplifiedTriples[j], ot_multipliers, j, *this);
}
// we can skip the consistency check since we're doing a mac-check next
// get piggy-backed random value
Z2<K + 2 * S> r_share = b_padded_bits.get_ptr_to_bit(nTriplesPerLoop, K + 2 * S);
Z2<K + 2 * S> r_mac;
r_mac.mul(r_share, this->machine.template get_mac_key<Z2<S>>());
for (int i = 0; i < this->nparties-1; i++)
r_mac += ((OTMultiplierMac<Z2<K + 2 * S>>*)ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop);
Share<Z2<K + 2 * S>> r;
r.set_share(r_share);
r.set_mac(r_mac);
MC.set_random_element(r);
sacrifice<Z2<K + 2 * S>, Z2<S>, Z2<K + S>>(uncheckedTriples, MC, G);
signal_multipliers();
}
}
template<>
void NPartyTripleGenerator::generateTriples<Z2<64> >(vector< OTMultiplierBase* >& ot_multipliers,
ofstream& outputFile)
{
const int K = 64;
const int S = 96;
const int TAU = Spdz2kMultiplier<K, S>::TAU;
valueBits.resize(3);
for (int i = 0; i < 2; i++)
valueBits[2*i].resize(TAU * nTriplesPerLoop);
valueBits[1].resize((K + S) * nTriplesPerLoop);
vector< PlainTriple<Z2<K + S>, 2> > amplifiedTriples(nTriplesPerLoop);
this->template generateTriplesZ2k<32, 32>(ot_multipliers, outputFile);
}
start_progress();
template<>
void NPartyTripleGenerator::generateTriples<Z2<128> >(vector< OTMultiplierBase* >& ot_multipliers,
ofstream& outputFile)
{
this->template generateTriplesZ2k<64, 64>(ot_multipliers, outputFile);
}
for (int k = 0; k < nloops; k++)
{
print_progress(k);
for (int j = 0; j < 2; j++)
valueBits[j].randomize_blocks<gf2n>(share_prg);
signal_multipliers();
timers["OTs"].start();
wait_for_multipliers();
timers["OTs"].stop();
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
PRNG G;
G.SetSeed(seed);
for (int j = 0; j < nTriplesPerLoop; j++)
{
BitVector a(valueBits[0].get_ptr_to_bit(j, TAU), TAU);
Z2<K + S> b(valueBits[1].get_ptr_to_bit(j, K + S));
Z2kRectangle<TAU, K + S> c;
c.mul(a, b);
timers["Triple computation"].start();
for (int i = 0; i < nparties-1; i++)
{
c += ((Spdz2kMultiplier<K, S>*)ot_multipliers[i])->c_output[j];
}
timers["Triple computation"].stop();
PlainTriple<Z2<K + S>, 2> amplifiedTriple;
amplifiedTriple.amplify(a, b, c, G);
if (machine.output)
amplifiedTriple.output(outputFile);
}
signal_multipliers();
}
template<>
void NPartyTripleGenerator::generateTriples<Z2<160> >(vector< OTMultiplierBase* >& ot_multipliers,
ofstream& outputFile)
{
this->template generateTriplesZ2k<64, 96>(ot_multipliers, outputFile);
}
template<class T>
@@ -440,8 +590,7 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult
valueBits[j].randomize_blocks<T>(share_prg);
timers["OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
wait_for_multipliers();
timers["OTs"].stop();
for (int j = 0; j < nPreampTriplesPerLoop; j++)
@@ -497,11 +646,9 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
amplifiedTriples[iTriple].to(valueBits, iTriple);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
signal_multipliers();
timers["Authentication OTs"].start();
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
wait_for_multipliers();
timers["Authentication OTs"].stop();
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
@@ -518,38 +665,88 @@ void NPartyTripleGenerator::generateTriples(vector< OTMultiplierBase* >& ot_mult
if (machine.check)
{
vector< Share<T> > maskedAs(nTriplesPerLoop);
vector< ShareTriple<T,1> > maskedTriples(nTriplesPerLoop);
for (int j = 0; j < nTriplesPerLoop; j++)
{
maskedTriples[j].amplify(uncheckedTriples[j], G);
maskedAs[j] = maskedTriples[j].a[0];
}
vector<T> openedAs(nTriplesPerLoop);
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
MC.POpen_End(openedAs, maskedAs, globalPlayer);
for (int j = 0; j < nTriplesPerLoop; j++)
MC.AddToCheck(maskedTriples[j].computeCheckMAC(openedAs[j]), 0, globalPlayer);
MC.Check(globalPlayer);
if (machine.generateBits)
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
else
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
uncheckedTriples[j].output(outputFile, 1);
sacrifice(uncheckedTriples, MC, G);
}
}
}
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
signal_multipliers();
}
}
template<class T, class U>
void NPartyTripleGenerator::sacrifice(
vector<ShareTriple_<T, U, 2> > uncheckedTriples, MAC_Check<T>& MC, PRNG& G)
{
vector< Share<T> > maskedAs(nTriplesPerLoop);
vector<TripleToSacrifice<T> > maskedTriples(nTriplesPerLoop);
for (int j = 0; j < nTriplesPerLoop; j++)
{
maskedTriples[j].prepare_sacrifice(uncheckedTriples[j], G);
maskedAs[j] = maskedTriples[j].a[0];
}
vector<T> openedAs(nTriplesPerLoop);
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
MC.POpen_End(openedAs, maskedAs, globalPlayer);
for (int j = 0; j < nTriplesPerLoop; j++) {
MC.AddToCheck(maskedTriples[j].computeCheckShare(openedAs[j]), 0,
globalPlayer);
}
MC.Check(globalPlayer);
if (machine.generateBits)
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
else
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
uncheckedTriples[j].output(outputFile, 1);
}
template<class T, class U, class V>
void NPartyTripleGenerator::sacrifice(
vector<ShareTriple_<T, U, 2> > uncheckedTriples, MAC_Check_Z2k<T, U, V>& MC, PRNG& G)
{
vector< Share<T> > maskedAs(nTriplesPerLoop);
vector<TripleToSacrifice<T> > maskedTriples(nTriplesPerLoop);
for (int j = 0; j < nTriplesPerLoop; j++)
{
// compute [p] = t * [a] - [ahat]
// and first part of [sigma], i.e., t * [c] - [chat]
maskedTriples[j].prepare_sacrifice(uncheckedTriples[j], G);
maskedAs[j] = maskedTriples[j].a[0];
}
vector<T> openedAs(nTriplesPerLoop);
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
MC.POpen_End(openedAs, maskedAs, globalPlayer);
vector<Share<T>> sigmas;
for (int j = 0; j < nTriplesPerLoop; j++) {
// compute t * [c] - [chat] - [b] * p
sigmas.push_back(maskedTriples[j].computeCheckShare(V(openedAs[j])));
}
vector<T> open_sigmas;
MC.POpen_Begin(open_sigmas, sigmas, globalPlayer);
MC.POpen_End(open_sigmas, sigmas, globalPlayer);
MC.Check(globalPlayer);
for (int j = 0; j < nTriplesPerLoop; j++) {
if (V(open_sigmas[j]) != 0)
throw mac_fail();
}
if (machine.generateBits)
generateBitsFromTriples(uncheckedTriples, MC, outputFile);
else
if (machine.output)
for (int j = 0; j < nTriplesPerLoop; j++)
uncheckedTriples[j].template reduce<V>().output(outputFile, 1);
}
template<>
void NPartyTripleGenerator::generateBitsFromTriples(
vector< ShareTriple<gfp,2> >& triples, MAC_Check<gfp>& MC, ofstream& outputFile)
@@ -576,9 +773,9 @@ void NPartyTripleGenerator::generateBitsFromTriples(
}
}
template<>
template<class T, class U>
void NPartyTripleGenerator::generateBitsFromTriples(
vector< ShareTriple<gf2n,2> >& triples, MAC_Check<gf2n>& MC, ofstream& outputFile)
vector< ShareTriple_<T, U, 2> >& triples, MAC_Check<T>& MC, ofstream& outputFile)
{
throw how_would_that_work();
// warning gymnastics
@@ -589,14 +786,12 @@ void NPartyTripleGenerator::generateBitsFromTriples(
void NPartyTripleGenerator::start_progress()
{
for (int i = 0; i < nparties-1; i++)
pthread_cond_wait(&ot_multipliers[i]->ready, &ot_multipliers[i]->mutex);
wait_for_multipliers();
lock();
signal();
wait();
gettimeofday(&last_lap, 0);
for (int i = 0; i < nparties-1; i++)
pthread_cond_signal(&ot_multipliers[i]->ready);
signal_multipliers();
}
void NPartyTripleGenerator::print_progress(int k)
@@ -655,3 +850,5 @@ template void NPartyTripleGenerator::generate<gf2n>();
template void NPartyTripleGenerator::generate<gfp>();
template void NPartyTripleGenerator::generate<Z2<64> >();
template void NPartyTripleGenerator::generate<Z2<128> >();
template void NPartyTripleGenerator::generate<Z2<160> >();

View File

@@ -19,8 +19,11 @@
#define N_AMPLIFY 3
template <class T, class U, int N>
class ShareTriple_;
template <class T, int N>
class ShareTriple;
using ShareTriple = ShareTriple_<T, T, N>;
class NPartyTripleGenerator
{
@@ -28,7 +31,6 @@ class NPartyTripleGenerator
Player globalPlayer;
int thread_num;
int my_num;
int nbase;
struct timeval last_lap;
@@ -40,14 +42,25 @@ class NPartyTripleGenerator
ofstream outputFile;
PRNG share_prg;
template <int K, int S>
void generateTriplesZ2k(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile);
template <class T>
void generateTriples(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile);
template <class T>
void generateBits(vector< OTMultiplierBase* >& ot_multipliers, ofstream& outputFile);
template <class T, int N>
void generateBitsFromTriples(vector<ShareTriple<T, N> >& triples,
template <class T, class U>
void generateBitsFromTriples(vector<ShareTriple_<T, U, 2> >& triples,
MAC_Check<T>& MC, ofstream& outputFile);
template <class T, class U>
void sacrifice(vector<ShareTriple_<T, U, 2> > uncheckedTriples,
MAC_Check<T>& MC, PRNG& G);
template <class T, class U, class V>
void sacrifice(vector<ShareTriple_<T, U, 2> > uncheckedTriples,
MAC_Check_Z2k<T, U, V>& MC, PRNG& G);
void start_progress();
void print_progress(int k);
@@ -65,6 +78,10 @@ public:
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
vector<BitVector> valueBits;
BitVector b_padded_bits;
int my_num;
int nTriples;
int nTriplesPerLoop;

View File

@@ -299,7 +299,6 @@ void OTExtension::transfer(int nOTs,
// randomize last 128 + 128 bits that will be discarded
for (int i = 0; i < 4; i++)
newReceiverInput.set_word(nOTs/64 - i, G.get_word());
// expand with PRG and create correlation
if (ot_role & RECEIVER)
{

View File

@@ -76,16 +76,15 @@ void OTCorrelator<U>::resize(int nOTs)
receiverOutputMatrix.resize_vertical(nOTs);
}
template<>
void OTExtensionWithMatrix::extend<Z2<160> >(int nOTs_requested,
BitVector& newReceiverInput)
{
extend<gf2n>(nOTs_requested, newReceiverInput);
}
// the template is used to denote the field of the hash output
template <class T>
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
{
extend_correlated(nOTs_requested, newReceiverInput);
hash_outputs<T>(nOTs_requested);
}
void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& newReceiverInput)
{
// if (nOTs % nbaseOTs != 0)
// throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n");
@@ -133,8 +132,6 @@ void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInp
#endif
}
hash_outputs<T>(nOTs);
receiverOutputMatrix.resize(nOTs_requested_rounded);
senderOutputMatrices[0].resize(nOTs_requested_rounded);
senderOutputMatrices[1].resize(nOTs_requested_rounded);
@@ -190,8 +187,8 @@ void OTExtensionWithMatrix::expand_transposed()
template <class U>
void OTCorrelator<U>::setup_for_correlation(BitVector& baseReceiverInput,
vector<BitMatrix>& baseSenderOutputs,
BitMatrix& baseReceiverOutput)
vector<U>& baseSenderOutputs,
U& baseReceiverOutput)
{
this->baseReceiverInput = baseReceiverInput;
receiverOutputMatrix = baseSenderOutputs[0];
@@ -282,28 +279,55 @@ void OTExtensionWithMatrix::transpose(int start, int slice)
*/
template <class T>
void OTExtensionWithMatrix::hash_outputs(int nOTs)
{
hash_outputs<T>(nOTs, senderOutputMatrices, receiverOutputMatrix);
}
template <class T, class V>
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput)
{
//cout << "Hashing... " << flush;
octetStream os, h_os(HASH_SIZE);
square128 tmp;
MMO mmo;
#ifdef OTEXT_TIMER
timeval startv, endv;
gettimeofday(&startv, NULL);
#endif
for (int i = 0; i < nOTs / 128; i++)
for (int i = 0; i < 2; i++)
senderOutput[i].resize_vertical(nOTs);
receiverOutput.resize_vertical(nOTs);
int n_rows = V::PartType::N_ROWS;
if (V::PartType::N_COLUMNS != T::size() * 8)
throw runtime_error("length mismatch for MMO hash");
if (nOTs % 8 != 0)
throw runtime_error("number of OTs must be divisible by 8");
for (int i = 0; i < nOTs; i += 8)
{
int i_outer_input = i / 128;
int i_inner_input = i % 128;
int i_outer_output = i / n_rows;
int i_inner_output = i % n_rows;
if (ot_role & SENDER)
{
tmp = senderOutputMatrices[0].squares[i];
tmp ^= baseReceiverInput;
senderOutputMatrices[0].squares[i].hash_row_wise<T>(mmo, senderOutputMatrices[0].squares[i]);
senderOutputMatrices[1].squares[i].hash_row_wise<T>(mmo, tmp);
int128 tmp[2][8];
for (int j = 0; j < 8; j++)
{
tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j];
tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0);
}
for (int j = 0; j < 2; j++)
mmo.hashBlocks<T, 8>(
&senderOutput[j].squares[i_outer_output].rows[i_inner_output],
&tmp[j]);
}
if (ot_role & RECEIVER)
{
receiverOutputMatrix.squares[i].hash_row_wise<T>(mmo, receiverOutputMatrix.squares[i]);
mmo.hashBlocks<T, 8>(
&receiverOutput.squares[i_outer_output].rows[i_inner_output],
&receiverOutputMatrix.squares[i_outer_input].rows[i_inner_input]);
}
}
//cout << "done.\n";
@@ -324,10 +348,7 @@ void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output)
output.resize(nTriples);
for (unsigned int j = 0; j < nTriples; j++)
{
T c1, c2;
receiverOutputMatrix.squares[j].to(c1);
senderOutputMatrices[0].squares[j].to(c2);
output[j] = c1 - c2;
receiverOutputMatrix.squares[j].template sub<T>(senderOutputMatrices[0].squares[j]).to(output[j]);
}
}
@@ -467,10 +488,17 @@ void OTExtensionWithMatrix::print_pre_expand()
template class OTCorrelator<BitMatrix>;
template class OTCorrelator<Matrix<square128> >;
template void OTCorrelator<Matrix<square128> >::correlate<gf2n>(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat);
template void OTCorrelator<Matrix<square128> >::correlate<gfp>(int start, int slice,
BitVector& newReceiverInput, bool useConstantBase, int repeat);
#define Z(BM,GF) \
template void OTCorrelator<BM>::correlate<GF>(int start, int slice, \
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
template void OTCorrelator<BM>::expand<GF>(int start, int slice); \
template void OTCorrelator<BM>::reduce_squares<GF>(unsigned int nTriples, \
vector<GF>& output);
#define ZZ(BM) Z(BM, gfp) Z(BM, gf2n)
ZZ(BitMatrix)
ZZ(Matrix<square128> )
template void OTExtensionWithMatrix::print_post_correlate<gf2n>(
BitVector& newReceiverInput, int j, int offset, int sender);
template void OTExtensionWithMatrix::print_post_correlate<gfp>(
@@ -479,27 +507,35 @@ template void OTExtensionWithMatrix::extend<gf2n>(int nOTs_requested,
BitVector& newReceiverInput);
template void OTExtensionWithMatrix::extend<gfp>(int nOTs_requested,
BitVector& newReceiverInput);
template void OTCorrelator<Matrix<square128> >::expand<gf2n>(int start, int slice);
template void OTCorrelator<Matrix<square128> >::expand<gfp>(int start, int slice);
template void OTExtensionWithMatrix::expand_transposed<gf2n>();
template void OTExtensionWithMatrix::expand_transposed<gfp>();
template void OTCorrelator<Matrix<square128> >::reduce_squares(unsigned int nTriples,
vector<gf2n>& output);
template void OTCorrelator<Matrix<square128> >::reduce_squares(unsigned int nTriples,
vector<gfp>& output);
#define ZZZ(GF, M) \
template void OTExtensionWithMatrix::hash_outputs<GF, M >(int, vector<M >&, M&);
#define MM Matrix<Rectangle<Z2<512>, Z2<160> > >
ZZZ(gfp, Matrix<square128>)
ZZZ(gf2n, Matrix<square128>)
ZZZ(Z2<160>, MM)
#undef X
#define X(N,L) \
template class OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >; \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::correlate<Z2<L> >(int start, int slice, \
BitVector& newReceiverInput, bool useConstantBase, int repeat); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::expand<Z2<L> >(int start, int slice); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<Z2<L> >& output); \
vector<Z2<N> >& output); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<Z2<N + L> >& output); \
vector<Z2<L> >& output); \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<Z2kRectangle<N, L> >& output); \
X(96, 160)
//X(96, 160)
Y(64, 96)
Y(64, 64)
Y(32, 32)
template void OTExtensionWithMatrix::hash_outputs<Z2<128>, Matrix<Rectangle<Z2<384>, Z2<128> > > >(int, std::vector<Matrix<Rectangle<Z2<384>, Z2<128> > >, std::allocator<Matrix<Rectangle<Z2<384>, Z2<128> > > > >&, Matrix<Rectangle<Z2<384>, Z2<128> > >&);
template void OTExtensionWithMatrix::hash_outputs<Z2<64>, Matrix<Rectangle<Z2<192>, Z2<64> > > >(int, std::vector<Matrix<Rectangle<Z2<192>, Z2<64> > >, std::allocator<Matrix<Rectangle<Z2<192>, Z2<64> > > > >&, Matrix<Rectangle<Z2<192>, Z2<64> > >&);

View File

@@ -17,8 +17,10 @@ class OTCorrelator : public OTExtension
{
public:
vector<U> senderOutputMatrices;
U receiverOutputMatrix;
U t1, u;
vector<U> matrices;
U& receiverOutputMatrix;
U& t1;
U u;
OTCorrelator(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
@@ -29,19 +31,20 @@ public:
OT_ROLE role=BOTH,
bool passive=false)
: OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
baseSenderInput, baseReceiverOutput, role, passive) {}
baseSenderInput, baseReceiverOutput, role, passive),
senderOutputMatrices(2), matrices(2),
receiverOutputMatrix(matrices[0]), t1(matrices[1]) {}
void resize(int nOTs);
template <class T>
void expand(int start, int slice);
void setup_for_correlation(BitVector& baseReceiverInput,
vector<BitMatrix>& baseSenderOutputs,
BitMatrix& baseReceiverOutput);
vector<U>& baseSenderOutputs,
U& baseReceiverOutput);
template <class T>
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
template <class T>
void reduce_squares(unsigned int nTriples, vector<T>& output);
};
class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
@@ -52,9 +55,9 @@ public:
OTExtensionWithMatrix(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
TwoPartyPlayer* player,
BitVector& baseReceiverInput,
vector< vector<BitVector> >& baseSenderInput,
vector<BitVector>& baseReceiverOutput,
const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput,
OT_ROLE role=BOTH,
bool passive=false)
: OTCorrelator<BitMatrix>(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
@@ -67,9 +70,12 @@ public:
void transfer(int nOTs, const BitVector& receiverInput);
template <class T>
void extend(int nOTs, BitVector& newReceiverInput);
void extend_correlated(int nOTs, BitVector& newReceiverInput);
template <class T>
void expand_transposed();
void transpose(int start, int slice);
template <class T, class V>
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput);
void print(BitVector& newReceiverInput, int i = 0);
template <class T>

View File

@@ -12,42 +12,59 @@
#include <math.h>
template<class T, class U, class V, class W, class X>
OTMultiplier<T, U, V, W, X>::OTMultiplier(NPartyTripleGenerator& generator,
//#define OTCORR_TIMER
template<class T, class U, class V, class X>
OTMultiplier<T, U, V, X>::OTMultiplier(NPartyTripleGenerator& generator,
int thread_num) :
generator(generator), thread_num(thread_num),
rot_ext(128, 128, 0, 1,
generator.players[thread_num], generator.baseReceiverInput,
generator.baseSenderInputs[thread_num],
generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check),
auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true),
otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true)
{
pthread_mutex_init(&mutex, 0);
pthread_cond_init(&ready, 0);
thread = 0;
pthread_mutex_init(&this->mutex, 0);
pthread_cond_init(&this->ready, 0);
this->thread = 0;
}
template<class T>
MascotMultiplier<T>::MascotMultiplier(NPartyTripleGenerator& generator,
int thread_num) :
OTMultiplier<T, T, T, square128, square128>(generator, thread_num)
OTMultiplier<T, T, T, square128>(generator, thread_num),
auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true)
{
c_output.resize(generator.nTriplesPerLoop);
}
template<class T, class U, class V, class W, class X>
OTMultiplier<T, U, V, W, X>::~OTMultiplier()
template <int K, int S>
Spdz2kMultiplier<K, S>::Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num) :
OTMultiplier<Z2<PASSIVE_MULT_BITS>, // bit length used when computing shares
Z2<S>, // bit length of key share
Z2<MAC_BITS>, // bit length used when computing mac shares
Z2kRectangle<TAU, PASSIVE_MULT_BITS> > // mult-rectangle
(generator, thread_num)
{
pthread_mutex_destroy(&mutex);
pthread_cond_destroy(&ready);
#ifdef USE_OPT_VOLE
mac_vole = new OTVole<Z2<MAC_BITS>, Z2<S>>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false);
#else
mac_vole = new OTVoleBase<Z2<MAC_BITS>, Z2<S>>(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, false);
#endif
}
template<class T, class U, class V, class W, class X>
void OTMultiplier<T, U, V, W, X>::multiply()
template<class T, class U, class V, class X>
OTMultiplier<T, U, V, X>::~OTMultiplier()
{
pthread_mutex_destroy(&this->mutex);
pthread_cond_destroy(&this->ready);
}
template<class T, class U, class V, class X>
void OTMultiplier<T, U, V, X>::multiply()
{
keyBits.set(generator.machine.get_mac_key<U>());
rot_ext.extend<T>(keyBits.size(), keyBits);
rot_ext.extend<gf2n>(keyBits.size(), keyBits);
senderOutput.resize(keyBits.size());
for (size_t j = 0; j < keyBits.size(); j++)
{
@@ -60,19 +77,17 @@ void OTMultiplier<T, U, V, W, X>::multiply()
}
rot_ext.receiverOutputMatrix.to(receiverOutput);
receiverOutput.resize(keyBits.size());
auth_ot_ext.init(keyBits, senderOutput, receiverOutput);
init_authenticator(keyBits, senderOutput, receiverOutput);
if (generator.machine.generateBits)
multiplyForBits();
multiplyForBits();
else
multiplyForTriples();
multiplyForTriples();
}
template<class T, class U, class V, class W, class X>
void OTMultiplier<T, U, V, W, X>::multiplyForTriples()
template<class T, class U, class V, class X>
void OTMultiplier<T, U, V, X>::multiplyForTriples()
{
auth_ot_ext.resize(generator.nPreampTriplesPerLoop * W::N_COLUMNS);
// dummy input for OT correlator
vector<BitVector> _;
vector< vector<BitVector> > __;
@@ -82,20 +97,24 @@ void OTMultiplier<T, U, V, W, X>::multiplyForTriples()
rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128);
pthread_mutex_lock(&mutex);
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
vector<Matrix<X> >& baseSenderOutputs = otCorrelator.matrices;
Matrix<X>& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
pthread_mutex_lock(&this->mutex);
this->signal_generator();
this->wait_for_generator();
for (int i = 0; i < generator.nloops; i++)
{
BitVector aBits = generator.valueBits[0];
//timers["Extension"].start();
rot_ext.extend<T>(X::N_ROWS * generator.nPreampTriplesPerLoop, aBits);
rot_ext.extend_correlated(X::N_ROWS * generator.nPreampTriplesPerLoop, aBits);
rot_ext.hash_outputs<T>(aBits.size(), baseSenderOutputs, baseReceiverOutput);
//timers["Extension"].stop();
//timers["Correlation"].start();
otCorrelator.setup_for_correlation(aBits, rot_ext.senderOutputMatrices,
rot_ext.receiverOutputMatrix);
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
baseReceiverOutput);
otCorrelator.template correlate<T>(0, generator.nPreampTriplesPerLoop,
generator.valueBits[1], false, generator.nAmplify);
//timers["Correlation"].stop();
@@ -104,16 +123,32 @@ void OTMultiplier<T, U, V, W, X>::multiplyForTriples()
this->after_correlation();
pthread_cond_signal(&this->ready);
pthread_cond_wait(&this->ready, &this->mutex);
this->signal_generator();
this->wait_for_generator();
}
pthread_mutex_unlock(&mutex);
pthread_mutex_unlock(&this->mutex);
}
template <class T>
void MascotMultiplier<T>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
const vector<BitVector>& receiverOutput) {
this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput);
}
template <int K, int S>
void Spdz2kMultiplier<K, S>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
const vector<BitVector>& receiverOutput) {
this->mac_vole->init(keyBits, senderOutput, receiverOutput);
}
template <class T>
void MascotMultiplier<T>::after_correlation()
{
this->auth_ot_ext.resize(this->generator.nPreampTriplesPerLoop * square128::N_COLUMNS);
this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop,
this->c_output);
@@ -141,16 +176,46 @@ void Spdz2kMultiplier<K, S>::after_correlation()
{
this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop,
this->c_output);
this->signal_generator();
this->wait_for_generator();
this->macs.resize(3);
#ifdef OTCORR_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
for (int j = 0; j < 3; j++)
{
int nValues = this->generator.nTriplesPerLoop;
BitVector* bits;
if (//this->generator.machine.check &&
(j % 2 == 0)){
nValues *= 2;
bits = &(this->generator.valueBits[j]);
}
else {
// piggy-backing mask after the b's
nValues++;
bits = &(this->generator.b_padded_bits);
}
this->mac_vole->evaluate(this->macs[j], nValues, *bits);
}
#ifdef OTCORR_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tCorrelated OT time: " << elapsed/1000000 << endl << flush;
#endif
}
template<>
void OTMultiplier<gfp, gfp, gfp, square128, square128>::multiplyForBits()
void OTMultiplier<gfp, gfp, gfp, square128>::multiplyForBits()
{
multiplyForTriples();
multiplyForTriples();
}
template<>
void OTMultiplier<gf2n, gf2n, gf2n, square128, square128>::multiplyForBits()
void OTMultiplier<gf2n, gf2n, gf2n, square128>::multiplyForBits()
{
int nBits = generator.nTriplesPerLoop + generator.field_size;
int nBlocks = ceil(1.0 * nBits / generator.field_size);
@@ -162,8 +227,8 @@ void OTMultiplier<gf2n, gf2n, gf2n, square128, square128>::multiplyForBits()
macs[0].resize(nBits);
pthread_mutex_lock(&mutex);
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
signal_generator();
wait_for_generator();
for (int i = 0; i < generator.nloops; i++)
{
@@ -178,25 +243,27 @@ void OTMultiplier<gf2n, gf2n, gf2n, square128, square128>::multiplyForBits()
macs[0][j] = r ^ s;
}
pthread_cond_signal(&ready);
pthread_cond_wait(&ready, &mutex);
signal_generator();
wait_for_generator();
}
pthread_mutex_unlock(&mutex);
}
template<class T, class U, class V, class W, class X>
void OTMultiplier<T, U, V, W, X>::multiplyForBits()
template<class T, class U, class V, class X>
void OTMultiplier<T, U, V, X>::multiplyForBits()
{
throw runtime_error("bit generation not implemented in this case");
}
template class OTMultiplier<gf2n, gf2n, gf2n, square128, square128>;
template class OTMultiplier<gfp, gfp, gfp, square128, square128>;
template class OTMultiplier<gf2n, gf2n, gf2n, square128>;
template class OTMultiplier<gfp, gfp, gfp, square128>;
template class MascotMultiplier<gf2n>;
template class MascotMultiplier<gfp>;
#define X(K, S) \
template class Spdz2kMultiplier<K, S>; \
template class OTMultiplier<Z2<K+S>, Z2<S>, Z2<K+S>, Z2kRectangle<K+S,K+S>, Z2kRectangle<TAU(K,S),K+S> >;
template class OTMultiplier<Z2<K+S>, Z2<S>, Z2<K+2*S>, Z2kRectangle<TAU(K,S),K+S> >;
X(64, 96)
X(64, 64)
X(32, 32)

View File

@@ -12,6 +12,7 @@
using namespace std;
#include "OT/OTExtensionWithMatrix.h"
#include "OT/OTVole.h"
#include "OT/Rectangle.h"
#include "Tools/random.h"
@@ -26,11 +27,22 @@ public:
virtual ~OTMultiplierBase() {}
virtual void multiply() = 0;
void signal_generator() { pthread_cond_signal(&ready); }
void wait_for_generator() { pthread_cond_wait(&ready, &mutex); }
};
template <class T, class U, class V, class W, class X>
class OTMultiplier : public OTMultiplierBase
template <class V>
class OTMultiplierMac : public OTMultiplierBase
{
public:
vector< vector<V> > macs;
};
template <class T, class U, class V, class X>
class OTMultiplier : public OTMultiplierMac<V>
{
protected:
BitVector keyBits;
vector< vector<BitVector> > senderOutput;
vector<BitVector> receiverOutput;
@@ -39,14 +51,15 @@ class OTMultiplier : public OTMultiplierBase
void multiplyForBits();
virtual void after_correlation() = 0;
virtual void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput) = 0;
public:
NPartyTripleGenerator& generator;
int thread_num;
OTExtensionWithMatrix rot_ext;
OTCorrelator<Matrix<W> > auth_ot_ext;
OTCorrelator<Matrix<X> > otCorrelator;
vector< vector<V> > macs;
OTMultiplier(NPartyTripleGenerator& generator, int thread_num);
virtual ~OTMultiplier();
@@ -54,9 +67,13 @@ public:
};
template <class T>
class MascotMultiplier : public OTMultiplier<T, T, T, square128, square128>
class MascotMultiplier : public OTMultiplier<T, T, T, square128>
{
OTCorrelator<Matrix<square128> > auth_ot_ext;
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
public:
vector<T> c_output;
@@ -64,24 +81,25 @@ public:
MascotMultiplier(NPartyTripleGenerator& generator, int thread_num);
};
// values, key, mac, mult-rectangle
template <int K, int S>
class Spdz2kMultiplier: public OTMultiplier<Z2<K + S>, Z2<S>, Z2<K + S>,
Z2kRectangle<K + S, K + S>, Z2kRectangle<TAU(K, S), K + S> >
class Spdz2kMultiplier: public OTMultiplier<Z2<K + S>, Z2<S>, Z2<K + 2 * S>,
Z2kRectangle<TAU(K, S), K + S> >
{
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
public:
static const int TAU = TAU(K, S);
static const int MAC_BITS = K + S;
static const int PASSIVE_MULT_BITS = K + S;
static const int MAC_BITS = K + 2 * S;
vector<Z2kRectangle<TAU, K + S> > c_output;
vector<Z2kRectangle<TAU, PASSIVE_MULT_BITS> > c_output;
OTVoleBase<Z2<MAC_BITS>, Z2<S>>* mac_vole;
Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num) :
OTMultiplier<Z2<MAC_BITS>, Z2<S>, Z2<MAC_BITS>,
Z2kRectangle<MAC_BITS, MAC_BITS>,
Z2kRectangle<TAU, MAC_BITS> >(generator, thread_num)
{
}
Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num);
};
#endif /* OT_OTMULTIPLIER_H_ */

358
OT/OTVole.cpp Normal file
View File

@@ -0,0 +1,358 @@
// (C) 2018 University of Bristol. See License.txt
#include "OTVole.h"
#include "Tools/oct.h"
#include "Auth/Subroutines.h"
//#define OTVOLE_TIMER
template <class T, class U>
void OTVoleBase<T, U>::evaluate(vector<T>& output, const vector<T>& newReceiverInput) {
const int N1 = newReceiverInput.size() + 1;
output.resize(newReceiverInput.size());
vector<octetStream> os(2);
if (this->ot_role & SENDER) {
T extra;
extra.randomize(local_prng);
vector<T> _corr(newReceiverInput);
_corr.push_back(extra);
corr_prime = Row<T>(_corr);
for (int i = 0; i < S; ++i)
{
t0[i] = Row<T>(N1);
t0[i].randomize(this->G_sender[i][0]);
t1[i] = Row<T>(N1);
t1[i].randomize(this->G_sender[i][1]);
Row<T> u = corr_prime + t1[i] + t0[i];
u.pack(os[0]);
}
}
send_if_ot_sender(this->player, os, this->ot_role);
if (this->ot_role & RECEIVER) {
for (int i = 0; i < S; ++i)
{
t[i] = Row<T>(N1);
t[i].randomize(this->G_receiver[i]);
int choice_bit = this->baseReceiverInput.get_bit(i);
if (choice_bit == 1) {
u[i].unpack(os[1]);
a[i] = u[i] - t[i];
} else {
a[i] = t[i];
u[i].unpack(os[1]);
}
}
}
os[0].reset_write_head();
os[1].reset_write_head();
if (!this->passive_only) {
this->consistency_check(os);
}
Row<T> res(N1);
if (this->ot_role & RECEIVER) {
for (int i = 0; i < S; ++i)
res += (a[i] << i);
}
if (this->ot_role & SENDER) {
for (int i = 0; i < S; ++i)
res -= (t0[i] << i);
}
for (int i = 0; i < N1 - 1; ++i)
output[i] = res.rows[i];
}
template <class T, class U>
void OTVoleBase<T, U>::evaluate(vector<T>& output, int nValues, const BitVector& newReceiverInput) {
if (newReceiverInput.size() != (size_t) nValues * T::N_BITS)
throw invalid_length();
vector<T> values(nValues);
for (int i = 0; i < nValues; ++i)
values[i] = T(newReceiverInput.get_ptr_to_bit(i, T::N_BITS));
evaluate(output, values);
}
template <class T, class U>
void OTVoleBase<T, U>::set_coeffs(__m128i* coefficients, PRNG& G, int num_blocks) const {
avx_memzero(coefficients, num_blocks);
for (int i = 0; i < num_blocks; ++i)
coefficients[i] = G.get_doubleword();
}
template <class T, class U>
void OTVoleBase<T, U>::hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients) {
octet hash[VOLE_HASH_SIZE] = {0};
this->hash_row(hash, row, coefficients);
os.append(hash, VOLE_HASH_SIZE);
}
template <class T, class U>
void OTVoleBase<T, U>::hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients) {
const __m128i* blocks = (const __m128i*) row.get_ptr();
int num_blocks = (row.size() * T::size()) / 16;
__m128i prods[2];
avx_memzero(prods, sizeof(prods));
__m128i res[2];
avx_memzero(res, sizeof(res));
for (int i = 0; i < num_blocks; ++i) {
mul128(blocks[i], coefficients[i], &prods[0], &prods[1]);
res[0] ^= prods[0];
res[1] ^= prods[1];
}
int total_bytes = row.size() * T::size();
if (total_bytes % 16 != 0) {
// need to handle "tail" bytes that did not evenly fit into block
int overflow_size = total_bytes % 16;
const octet* overflow_bytes = ((octet*) row.get_ptr()) + total_bytes - overflow_size;
octet bytes[16] = {0};
memcpy(bytes, overflow_bytes, overflow_size);
const __m128i* extra = (const __m128i*) bytes;
mul128(*extra, coefficients[num_blocks], &prods[0], &prods[1]);
res[0] ^= prods[0];
res[1] ^= prods[1];
}
crypto_generichash(hash, crypto_generichash_BYTES,
(octet*) res, crypto_generichash_BYTES, NULL, 0);
}
template <class T, class U>
void OTVoleBase<T, U>::consistency_check(vector<octetStream>& os) {
PRNG coef_prng_sender;
PRNG coef_prng_receiver;
if (this->ot_role & RECEIVER) {
coef_prng_receiver.ReSeed();
os[0].append(coef_prng_receiver.get_seed(), SEED_SIZE);
}
send_if_ot_receiver(this->player, os, this->ot_role);
if (this->ot_role & SENDER) {
octet seed[SEED_SIZE];
os[1].consume(seed, SEED_SIZE);
coef_prng_sender.SetSeed(seed);
}
os[0].reset_write_head();
os[1].reset_write_head();
if (this->ot_role & SENDER) {
#ifdef OTVOLE_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
int total_bytes = t0[0].size() * T::size();
int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0);
__m128i coefficients[num_blocks];
this->set_coeffs(coefficients, coef_prng_sender, num_blocks);
Row<T> t00(t0.size()), t01(t0.size()), t10(t0.size()), t11(t0.size());
for (int alpha = 0; alpha < S; ++alpha)
{
for (int beta = 0; beta < S; ++beta)
{
t00 = t0[alpha] - t0[beta];
t01 = t0[alpha] - t1[beta];
t10 = t1[alpha] - t0[beta];
t11 = t1[alpha] - t1[beta];
this->hash_row(os[0], t00, coefficients);
this->hash_row(os[0], t01, coefficients);
this->hash_row(os[0], t10, coefficients);
this->hash_row(os[0], t11, coefficients);
}
}
#ifdef OTVOLE_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tCheck time sender: " << elapsed/1000000 << endl << flush;
#endif
}
send_if_ot_sender(this->player, os, this->ot_role);
if (this->ot_role & RECEIVER) {
#ifdef OTVOLE_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
int total_bytes = t[0].size() * T::size();
int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0);
__m128i coefficients[num_blocks];
this->set_coeffs(coefficients, coef_prng_receiver, num_blocks);
octet h00[VOLE_HASH_SIZE] = {0};
octet h01[VOLE_HASH_SIZE] = {0};
octet h10[VOLE_HASH_SIZE] = {0};
octet h11[VOLE_HASH_SIZE] = {0};
vector<vector<octet*>> hashes(2);
hashes[0] = {h00, h01};
hashes[1] = {h10, h11};
for (int alpha = 0; alpha < S; ++alpha)
{
for (int beta = 0; beta < S; ++beta)
{
os[1].consume(hashes[0][0], VOLE_HASH_SIZE);
os[1].consume(hashes[0][1], VOLE_HASH_SIZE);
os[1].consume(hashes[1][0], VOLE_HASH_SIZE);
os[1].consume(hashes[1][1], VOLE_HASH_SIZE);
int choice_alpha = this->baseReceiverInput.get_bit(alpha);
int choice_beta = this->baseReceiverInput.get_bit(beta);
Row<T> tmp = t[alpha] - t[beta];
octet* choice_hash = hashes[choice_alpha][choice_beta];
octet diff_t[VOLE_HASH_SIZE] = {0};
this->hash_row(diff_t, tmp, coefficients);
octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta];
octet other_diff[VOLE_HASH_SIZE] = {0};
tmp = u[alpha] - u[beta] - t[alpha] + t[beta];
this->hash_row(other_diff, tmp, coefficients);
if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) {
throw consistency_check_fail();
}
if (!OCTETS_EQUAL(not_choice_hash, other_diff, VOLE_HASH_SIZE)) {
throw consistency_check_fail();
}
if (alpha != beta && u[alpha] == u[beta]) {
throw consistency_check_fail();
}
}
}
#ifdef OTVOLE_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tCheck receiver: " << elapsed/1000000 << endl << flush;
#endif
}
}
template <class T, class U>
void OTVole<T, U>::consistency_check(vector<octetStream>& os) {
PRNG sender_prg;
PRNG receiver_prg;
if (this->ot_role & RECEIVER) {
receiver_prg.ReSeed();
os[0].append(receiver_prg.get_seed(), SEED_SIZE);
}
send_if_ot_receiver(this->player, os, this->ot_role);
if (this->ot_role & SENDER) {
octet seed[SEED_SIZE];
os[1].consume(seed, SEED_SIZE);
sender_prg.SetSeed(seed);
}
os[0].reset_write_head();
os[1].reset_write_head();
if (this->ot_role & SENDER) {
#ifdef OTVOLE_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
int total_bytes = this->t0[0].size() * T::size();
int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0);
__m128i coefficients[num_blocks];
this->set_coeffs(coefficients, sender_prg, num_blocks);
Row<T> t00(this->t0.size()), t01(this->t0.size()), t10(this->t0.size()), t11(this->t0.size());
for (int alpha = 0; alpha < U::N_BITS; ++alpha)
{
for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i)
{
int beta = sender_prg.get_uint(U::N_BITS);
t00 = this->t0[alpha] - this->t0[beta];
t01 = this->t0[alpha] - this->t1[beta];
t10 = this->t1[alpha] - this->t0[beta];
t11 = this->t1[alpha] - this->t1[beta];
this->hash_row(os[0], t00, coefficients);
this->hash_row(os[0], t01, coefficients);
this->hash_row(os[0], t10, coefficients);
this->hash_row(os[0], t11, coefficients);
}
}
#ifdef OTVOLE_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tCheck time sender: " << elapsed/1000000 << endl << flush;
#endif
}
send_if_ot_sender(this->player, os, this->ot_role);
if (this->ot_role & RECEIVER) {
#ifdef OTVOLE_TIMER
timeval totalstartv, totalendv;
gettimeofday(&totalstartv, NULL);
#endif
int total_bytes = this->t[0].size() * T::size();
int num_blocks = (total_bytes) / 16 + ((total_bytes % 16) != 0);
__m128i coefficients[num_blocks];
this->set_coeffs(coefficients, receiver_prg, num_blocks);
octet h00[VOLE_HASH_SIZE] = {0};
octet h01[VOLE_HASH_SIZE] = {0};
octet h10[VOLE_HASH_SIZE] = {0};
octet h11[VOLE_HASH_SIZE] = {0};
vector<vector<octet*>> hashes(2);
hashes[0] = {h00, h01};
hashes[1] = {h10, h11};
for (int alpha = 0; alpha < U::N_BITS; ++alpha)
{
for (int i = 0; i < NUM_VOLE_CHALLENGES; ++i)
{
int beta = receiver_prg.get_uint(U::N_BITS);
os[1].consume(hashes[0][0], VOLE_HASH_SIZE);
os[1].consume(hashes[0][1], VOLE_HASH_SIZE);
os[1].consume(hashes[1][0], VOLE_HASH_SIZE);
os[1].consume(hashes[1][1], VOLE_HASH_SIZE);
int choice_alpha = this->baseReceiverInput.get_bit(alpha);
int choice_beta = this->baseReceiverInput.get_bit(beta);
Row<T> tmp = this->t[alpha] - this->t[beta];
octet* choice_hash = hashes[choice_alpha][choice_beta];
octet diff_t[VOLE_HASH_SIZE] = {0};
this->hash_row(diff_t, tmp, coefficients);
octet* not_choice_hash = hashes[1 - choice_alpha][1 - choice_beta];
octet other_diff[VOLE_HASH_SIZE] = {0};
tmp = this->u[alpha] - this->u[beta] - this->t[alpha] + this->t[beta];
this->hash_row(other_diff, tmp, coefficients);
if (!OCTETS_EQUAL(choice_hash, diff_t, VOLE_HASH_SIZE)) {
throw consistency_check_fail();
}
if (!OCTETS_EQUAL(not_choice_hash, other_diff, VOLE_HASH_SIZE)) {
throw consistency_check_fail();
}
if (alpha != beta && this->u[alpha] == this->u[beta]) {
throw consistency_check_fail();
}
}
}
#ifdef OTVOLE_TIMER
gettimeofday(&totalendv, NULL);
double elapsed = timeval_diff(&totalstartv, &totalendv);
cout << "\t\tCheck receiver: " << elapsed/1000000 << endl << flush;
#endif
}
}
template class OTVoleBase<Z2<256>, Z2<96>>;
template class OTVoleBase<Z2<192>, Z2<64>>;
template class OTVoleBase<Z2<96>, Z2<32>>;
template class OTVole<Z2<256>, Z2<96>>;
template class OTVole<Z2<192>, Z2<64>>;
template class OTVole<Z2<96>, Z2<32>>;

86
OT/OTVole.h Normal file
View File

@@ -0,0 +1,86 @@
// (C) 2018 University of Bristol. See License.txt
#ifndef _OTVOLE
#define _OTVOLE
#include "Math/Z2k.h"
#include "OTExtension.h"
#include "Row.h"
using namespace std;
template <class T, class U>
class OTVoleBase : public OTExtension
{
public:
static const int S = U::N_BITS;
OTVoleBase(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
TwoPartyPlayer* player,
const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput,
OT_ROLE role=BOTH,
bool passive=false)
: OTExtension(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive),
corr_prime(),
t0(U::N_BITS),
t1(U::N_BITS),
u(U::N_BITS),
t(U::N_BITS),
a(U::N_BITS) {
// need to flip roles for OT extension init, reset to original role here
this->ot_role = role;
local_prng.ReSeed();
}
void evaluate(vector<T>& output, const vector<T>& newReceiverInput);
void evaluate(vector<T>& output, int nValues, const BitVector& newReceiverInput);
protected:
// Sender fields
Row<T> corr_prime;
vector<Row<T>> t0, t1;
// Receiver fields
vector<Row<T>> u, t, a;
// Both
PRNG local_prng;
virtual void consistency_check (vector<octetStream>& os);
void set_coeffs(__m128i* coefficients, PRNG& G, int num_elements) const;
void hash_row(octetStream& os, const Row<T>& row, const __m128i* coefficients);
void hash_row(octet* hash, const Row<T>& row, const __m128i* coefficients);
};
template <class T, class U>
class OTVole : public OTVoleBase<T, U>
{
public:
OTVole(int nbaseOTs, int baseLength,
int nloops, int nsubloops,
TwoPartyPlayer* player,
const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput,
OT_ROLE role=BOTH,
bool passive=false)
: OTVoleBase<T, U>(nbaseOTs, baseLength, nloops, nsubloops, player, baseReceiverInput,
baseSenderInput, baseReceiverOutput, INV_ROLE(role), passive) {
}
protected:
void consistency_check(vector<octetStream>& os);
};
#endif

View File

@@ -102,8 +102,11 @@ void Rectangle<U, V>::unpack(octetStream& o)
#define X(N,L) \
template class Rectangle<Z2<N>, Z2<L> > ; \
template void Rectangle<Z2<N>, Z2<L> >::to(Z2<L>& result); \
template void Rectangle<Z2<N>, Z2<L> >::to(Z2<N>& result); \
template void Rectangle<Z2<N>, Z2<L> >::to(Z2<N + L>& result); \
template void Rectangle<Z2<N>, Z2<L> >::to(Z2<L - N>& result); \
Y(64, 96)
X(96, 160)
Y(64, 64)
Y(32, 32)

View File

@@ -11,18 +11,22 @@
#include "Math/Z2k.h"
#define TAU(K, S) 2 * K + 4 * S
#define Y(K, S) X(TAU(K, S), K + S) X(K + S, K + S)
#define Y(K, S) X(TAU(K, S), K + S) X(S, K + 2 * S)
template <class U, class V>
class Rectangle
{
public:
typedef V RowType;
static const int N_ROWS = U::N_BITS;
static const int N_COLUMNS = V::N_BITS;
static const int N_ROW_BYTES = V::N_BYTES;
V rows[N_ROWS];
static size_t size() { return N_ROWS * RowType::size(); }
bool operator==(const Rectangle<U,V>& other) const;
bool operator!=(const Rectangle<U,V>& other) const
{ return not (*this == other); }
@@ -32,6 +36,8 @@ public:
Rectangle<U, V>& operator+=(const Rectangle<U, V>& other);
Rectangle<U, V> operator-(const Rectangle<U, V> & other);
template <class T>
Rectangle<U, V>& sub(Rectangle<U, V>& other) { return other.rsub_(*this); }
template <class T>
Rectangle<U, V>& rsub(Rectangle<U, V>& other) { return rsub_(other); }
Rectangle<U, V>& rsub_(Rectangle<U, V>& other);

113
OT/Row.cpp Normal file
View File

@@ -0,0 +1,113 @@
#include "OT/Row.h"
#include "Exceptions/Exceptions.h"
template<class T>
bool Row<T>::operator ==(const Row<T>& other) const
{
return rows == other.rows;
}
template<class T>
Row<T>& Row<T>::operator +=(const Row<T>& other)
{
if (rows.size() != other.size()) {
throw invalid_length();
}
for (size_t i = 0; i < this->size(); i++)
rows[i] += other.rows[i];
return *this;
}
template<class T>
Row<T>& Row<T>::operator -=(const Row<T>& other)
{
if (rows.size() != other.size()) {
throw invalid_length();
}
for (size_t i = 0; i < this->size(); i++)
rows[i] -= other.rows[i];
return *this;
}
template<class T>
Row<T>& Row<T>::operator *=(const T& other)
{
for (size_t i = 0; i < this->size(); i++)
rows[i] = rows[i] * other;
return *this;
}
template<class T>
Row<T> Row<T>::operator *(const T& other)
{
Row<T> res = *this;
res *= other;
return res;
}
template<class T>
Row<T> Row<T>::operator +(const Row<T>& other)
{
Row<T> res = other;
res += *this;
return res;
}
template<class T>
Row<T> Row<T>::operator -(const Row<T>& other)
{
Row<T> res = *this;
res-=other;
return res;
}
template<class T>
void Row<T>::randomize(PRNG& G)
{
for (size_t i = 0; i < this->size(); i++)
rows[i].randomize(G);
}
template<class T>
Row<T> Row<T>::operator<<(int i) const {
if (i >= T::size() * 8) {
throw invalid_params();
}
Row<T> res = *this;
for (size_t j = 0; j < this->size(); j++)
res.rows[j] = res.rows[j] << i;
return res;
}
template<class T>
void Row<T>::pack(octetStream& o) const
{
o.store(this->size());
for (size_t i = 0; i < this->size(); i++)
rows[i].pack(o);
}
template<class T>
void Row<T>::unpack(octetStream& o)
{
size_t size;
o.get(size);
this->rows.resize(size);
for (size_t i = 0; i < this->size(); i++)
rows[i].unpack(o);
}
template <class V>
ostream& operator<<(ostream& o, const Row<V>& x)
{
for (size_t i = 0; i < x.size(); ++i)
o << x.rows[i] << " | ";
return o;
}
template class Row<Z2<96>>;
template ostream& operator<<(ostream& o, const Row<Z2<96>>& x);
template class Row<Z2<192>>;
template ostream& operator<<(ostream& o, const Row<Z2<192>>& x);
template class Row<Z2<256>>;
template ostream& operator<<(ostream& o, const Row<Z2<256>>& x);

52
OT/Row.h Normal file
View File

@@ -0,0 +1,52 @@
#ifndef OT_ROW_H_
#define OT_ROW_H_
#include "Math/Z2k.h"
#include "Math/gf2nlong.h"
#define VOLE_HASH_SIZE crypto_generichash_BYTES
template <class T>
class Row
{
public:
vector<T> rows;
Row(int size) : rows(size) {}
Row() : rows() {}
Row(const vector<T>& _rows) : rows(_rows) {}
bool operator==(const Row<T>& other) const;
bool operator!=(const Row<T>& other) const { return not (*this == other); }
Row<T>& operator+=(const Row<T>& other);
Row<T>& operator-=(const Row<T>& other);
Row<T>& operator*=(const T& other);
Row<T> operator*(const T& other);
Row<T> operator+(const Row<T> & other);
Row<T> operator-(const Row<T> & other);
Row<T> operator<<(int i) const;
// fine, since elements in vector are allocated contiguously
const void* get_ptr() const { return rows[0].get_ptr(); }
void randomize(PRNG& G);
void pack(octetStream& o) const;
void unpack(octetStream& o);
size_t size() const { return rows.size(); }
template <class V>
friend ostream& operator<<(ostream& o, const Row<V>& x);
};
template <int K>
using Z2kRow = Row<Z2<K>>;
#endif /* OT_ROW_H_ */

View File

@@ -112,6 +112,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
primeField = opt.get("-P")->isSet;
bonding = opt.get("-b")->isSet;
z2k = opt.get("-Z")->isSet;
check |= z2k;
bigint p;
if (output)
@@ -134,7 +135,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
G.ReSeed();
mac_key2.randomize(G);
mac_keyp.randomize(G);
mac_key296.randomize(G);
mac_keyz.randomize(G);
}
void TripleMachine::run()
@@ -168,8 +169,8 @@ void TripleMachine::run()
generators[i]->lock();
if (z2k)
{
pthread_create(&threads[i], 0, run_ngenerator_thread<Z2<64> >,
generators[i]);
pthread_create(&threads[i], 0, run_ngenerator_thread<Z2<SPDZ2K_K + SPDZ2K_S> >,
generators[i]);
continue;
}
if (primeField)
@@ -222,8 +223,9 @@ void TripleMachine::run()
void TripleMachine::output_mac_keys()
{
if (z2k)
write_mac_key(prep_data_dir, my_num, mac_key296);
if (z2k) {
write_mac_key(prep_data_dir, my_num, mac_keyz);
}
else
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2);
}
@@ -240,5 +242,15 @@ template<> gfp TripleMachine::get_mac_key()
template<> Z2<96> TripleMachine::get_mac_key()
{
return mac_key296;
return mac_keyz;
}
template<> Z2<64> TripleMachine::get_mac_key()
{
return mac_keyz;
}
template<> Z2<32> TripleMachine::get_mac_key()
{
return mac_keyz;
}

View File

@@ -17,7 +17,7 @@ class TripleMachine : public OfflineMachineBase
{
gf2n mac_key2;
gfp mac_keyp;
Z2<96> mac_key296;
Z2<SPDZ2K_S> mac_keyz;
public:
int nloops;

View File

@@ -193,10 +193,10 @@ void BufferHelper<U,V>::purge()
template class Buffer< Share<gfp>, Share<gfp> >;
template class Buffer< Share<gf2n>, Share<gf2n> >;
template class Buffer< Share<Z2<64> >, Share<Z2<64> > >;
template class Buffer< Share<Z2<160> >, Share<Z2<160> > >;
template class Buffer< InputTuple<gfp>, RefInputTuple<gfp> >;
template class Buffer< InputTuple<gf2n>, RefInputTuple<gf2n> >;
template class Buffer< InputTuple<Z2<64> >, RefInputTuple<Z2<64> > >;
template class Buffer< InputTuple<Z2<160> >, RefInputTuple<Z2<160> > >;
template class Buffer< gfp, gfp >;
template class Buffer< gf2n, gf2n >;

View File

@@ -69,13 +69,13 @@ class BufferHelper
public:
Buffer< U<gfp>, V<gfp> > bufferp;
Buffer< U<gf2n>, V<gf2n> > buffer2;
Buffer< U<Z2<64> >, V<Z2<64> > > bufferz2k;
Buffer< U<Z2<160> >, V<Z2<160> > > bufferz2k;
ifstream* files[N_DATA_FIELD_TYPE];
BufferHelper() { memset(files, 0, sizeof(files)); }
void input(V<gfp>& a) { bufferp.input(a); }
void input(V<gf2n>& a) { buffer2.input(a); }
void input(V<Z2<64> >& a) { bufferz2k.input(a); }
void input(V<Z2<160> >& a) { bufferz2k.input(a); }
BufferBase& get_buffer(DataFieldType field_type);
void setup(DataFieldType field_type, string filename, int tuple_length, const char* data_type = 0);
void close();

View File

@@ -6,8 +6,8 @@
#include <iomanip>
const char* Data_Files::field_names[] = { "p", "2", "Z2^64" };
const char* Data_Files::long_field_names[] = { "gfp", "gf2n", "Z2^64" };
const char* Data_Files::field_names[] = { "p", "2", "Z2^160" };
const char* Data_Files::long_field_names[] = { "gfp", "gf2n", "Z2^160" };
const bool Data_Files::implemented[N_DATA_FIELD_TYPE][N_DTYPE] = {
{ true, true, true, true, false, false },
{ true, true, true, true, true, true },

View File

@@ -31,7 +31,7 @@ Machine::Machine(int my_number, Names& playerNames,
prep_dir_prefix = get_prep_dir(N.num_players(), lgp, lg2);
read_setup(prep_dir_prefix);
char filename[1024];
char filename[2048];
int nn;
sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number);

View File

@@ -122,7 +122,7 @@ namespace Config {
pubkeys[i].resize(crypto_sign_PUBLICKEYBYTES);
infile.read((char*)&pubkeys[i][0],pubkeys[i].size());
}
} catch (ConfigError e) {
} catch (ConfigError& e) {
pubkeys.resize(0);
}

View File

@@ -10,20 +10,27 @@
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Math/bigint.h"
#include "Math/Z2k.h"
#include <unistd.h>
void MMO::zeroIV()
{
octet key[AES_BLK_SIZE];
memset(key,0,AES_BLK_SIZE*sizeof(octet));
setIV(key);
if (N_KEYS > (1 << 8))
throw not_implemented();
for (int i = 0; i < N_KEYS; i++)
{
octet key[AES_BLK_SIZE];
memset(key, 0, AES_BLK_SIZE * sizeof(octet));
key[i] = i;
setIV(i, key);
}
}
void MMO::setIV(octet key[AES_BLK_SIZE])
void MMO::setIV(int i, octet key[AES_BLK_SIZE])
{
aes_schedule(IV,key);
aes_schedule(IV[i],key);
}
@@ -50,26 +57,91 @@ void MMO::encrypt_and_xor(void* output, const void* input, const octet* key,
_mm_storeu_si128(((__m128i*)output) + indices[i], out[i]);
}
template <>
void MMO::hashOneBlock<gf2n>(octet* output, octet* input)
template <class T, int N>
void MMO::hashBlocks(void* output, const void* input)
{
encrypt_and_xor<1>(output, input, IV);
int n_blocks = DIV_CEIL(T::size(), 16);
if (n_blocks > N_KEYS)
throw runtime_error("not enough MMO keys");
__m128i tmp[N];
int block_size = sizeof(tmp[0]);
for (int i = 0; i < n_blocks; i++)
{
encrypt_and_xor<N>(tmp, input, IV[i]);
for (int j = 0; j < N; j++)
memcpy((char*)output + j * sizeof(T) + i * block_size, &tmp[j],
min(T::size() - i * block_size, block_size));
}
}
template <>
void MMO::hashOneBlock<gfp>(octet* output, octet* input)
void MMO::hashBlocks<gfp, 1>(void* output, const void* input)
{
encrypt_and_xor<1>(output, input, IV);
if (gfp::get_ZpD().get_t() != 2)
throw not_implemented();
encrypt_and_xor<1>(output, input, IV[0]);
while (mpn_cmp((mp_limb_t*)output, gfp::get_ZpD().get_prA(), gfp::t()) >= 0)
encrypt_and_xor<1>(output, output, IV);
encrypt_and_xor<1>(output, output, IV[0]);
}
template <>
void MMO::hashBlockWise<gf2n,128>(octet* output, octet* input)
{
for (int i = 0; i < 16; i++)
encrypt_and_xor<8>(&((__m128i*)output)[i*8], &((__m128i*)input)[i*8], IV);
encrypt_and_xor<8>(&((__m128i*)output)[i*8], &((__m128i*)input)[i*8], IV[0]);
}
template <>
void MMO::hashBlocks<gfp, 8>(void* output, const void* input)
{
if (gfp::get_ZpD().get_t() != 2)
throw not_implemented();
__m128i* in = (__m128i*)input;
__m128i* out = (__m128i*)output;
encrypt_and_xor<8>(out, in, IV[0]);
int left = 8;
int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7};
while (left)
{
int now_left = 0;
for (int j = 0; j < left; j++)
if (mpn_cmp((mp_limb_t*)&out[indices[j]], gfp::get_ZpD().get_prA(), gfp::t()) >= 0)
{
indices[now_left] = indices[j];
now_left++;
}
left = now_left;
// and now my favorite hack
switch (left) {
case 8:
ecb_aes_128_encrypt<8>(out, out, IV[0], indices);
break;
case 7:
ecb_aes_128_encrypt<7>(out, out, IV[0], indices);
break;
case 6:
ecb_aes_128_encrypt<6>(out, out, IV[0], indices);
break;
case 5:
ecb_aes_128_encrypt<5>(out, out, IV[0], indices);
break;
case 4:
ecb_aes_128_encrypt<4>(out, out, IV[0], indices);
break;
case 3:
ecb_aes_128_encrypt<3>(out, out, IV[0], indices);
break;
case 2:
ecb_aes_128_encrypt<2>(out, out, IV[0], indices);
break;
case 1:
ecb_aes_128_encrypt<1>(out, out, IV[0], indices);
break;
default:
break;
}
}
}
template <>
@@ -79,49 +151,11 @@ void MMO::hashBlockWise<gfp,128>(octet* output, octet* input)
{
__m128i* in = &((__m128i*)input)[i*8];
__m128i* out = &((__m128i*)output)[i*8];
encrypt_and_xor<8>(out, in, IV);
int left = 8;
int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7};
while (left)
{
int now_left = 0;
for (int j = 0; j < left; j++)
if (mpn_cmp((mp_limb_t*)&out[indices[j]], gfp::get_ZpD().get_prA(), gfp::t()) >= 0)
{
indices[now_left] = indices[j];
now_left++;
}
left = now_left;
// and now my favorite hack
switch (left) {
case 8:
ecb_aes_128_encrypt<8>(out, out, IV, indices);
break;
case 7:
ecb_aes_128_encrypt<7>(out, out, IV, indices);
break;
case 6:
ecb_aes_128_encrypt<6>(out, out, IV, indices);
break;
case 5:
ecb_aes_128_encrypt<5>(out, out, IV, indices);
break;
case 4:
ecb_aes_128_encrypt<4>(out, out, IV, indices);
break;
case 3:
ecb_aes_128_encrypt<3>(out, out, IV, indices);
break;
case 2:
ecb_aes_128_encrypt<2>(out, out, IV, indices);
break;
case 1:
ecb_aes_128_encrypt<1>(out, out, IV, indices);
break;
default:
break;
}
}
hashBlocks<gfp,8>(out, in);
}
}
#define ZZ(F,N) \
template void MMO::hashBlocks<F,N>(void*, const void*);
#define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8)
Z(gf2n) Z(Z2<64>) Z(Z2<128>) Z(Z2<160>) Z(Z2<256>)

View File

@@ -13,7 +13,8 @@
// Matyas-Meyer-Oseas hashing
class MMO
{
octet IV[176] __attribute__((aligned (16)));
static const int N_KEYS = 2;
octet IV[N_KEYS][176] __attribute__((aligned (16)));
template<int N>
static void encrypt_and_xor(void* output, const void* input,
@@ -25,9 +26,11 @@ class MMO
public:
MMO() { zeroIV(); }
void zeroIV();
void setIV(octet key[AES_BLK_SIZE]);
void setIV(int i, octet key[AES_BLK_SIZE]);
template <class T>
void hashOneBlock(octet* output, octet* input);
void hashOneBlock(void* output, const void* input) { hashBlocks<T, 1>((T*)output, input); }
template <class T, int N>
void hashBlocks(void* output, const void* input);
template <class T, int N>
void hashBlockWise(octet* output, octet* input);
template <class T>

View File

@@ -18,7 +18,6 @@ typedef unsigned char octet;
typedef unsigned long word;
#endif
inline int CEIL_LOG2(int x)
{
int result = 0;

19
Tools/oct.h Normal file
View File

@@ -0,0 +1,19 @@
#ifndef TOOLS_OCT_H_
#define TOOLS_OCT_H_
typedef unsigned char octet;
inline void PRINT_OCTET(const octet* bytes, size_t size) {
for (size_t i = 0; i < size; ++i)
cout << hex << (int) bytes[i];
cout << flush << endl;
}
inline bool OCTETS_EQUAL(const octet* left, const octet* right, int size) {
for (int i = 0; i < size; ++i)
if (left[i] != right[i])
return false;
return true;
}
#endif

View File

@@ -36,16 +36,6 @@ void octetStream::assign(const octetStream& os)
}
void octetStream::swap(octetStream& os)
{
const size_t size = sizeof(octetStream);
char tmp[size];
memcpy(tmp, this, size);
memcpy(this, &os, size);
memcpy(&os, tmp, size);
}
octetStream::octetStream(size_t maxlen)
{
mxlen=maxlen; len=0; ptr=0;

View File

@@ -42,10 +42,10 @@ class octetStream
void resize(size_t l);
void resize_precise(size_t l);
void reserve(size_t l);
void clear();
void assign(const octetStream& os);
void swap(octetStream& os);
octetStream() : len(0), mxlen(0), ptr(0), data(0) {}
octetStream(size_t maxlen);
@@ -177,6 +177,11 @@ inline void octetStream::resize_precise(size_t l)
mxlen=l;
}
inline void octetStream::reserve(size_t l)
{
if (len + l > mxlen)
resize_precise(len + l);
}
inline void octetStream::append(const octet* x, const size_t l)
{

View File

@@ -3,6 +3,7 @@
#include "Tools/random.h"
#include "Math/bigint.h"
#include "Auth/Subroutines.h"
#include <stdio.h>
#include <sodium.h>
@@ -38,6 +39,12 @@ void PRNG::SetSeed(PRNG& G)
SetSeed(tmp);
}
void PRNG::SecureSeed(Player& player)
{
Create_Random_Seed(seed, player, SEED_SIZE);
InitSeed();
}
void PRNG::InitSeed()
{
#ifdef USE_AES
@@ -136,7 +143,25 @@ unsigned int PRNG::get_uint()
return ans;
}
unsigned int PRNG::get_uint(int upper)
{
// adopting Java 7 implementation of bounded nextInt here
if (upper <= 0)
throw invalid_argument("Must be positive");
// power of 2 case
if ((upper & (upper - 1)) == 0) {
unsigned int r = (upper < 255) ? get_uchar() : get_uint();
// zero out higher order bits
return r % upper;
}
// not power of 2
int r, reduced;
do {
r = (upper < 255) ? get_uchar() : get_uint();
reduced = r % upper;
} while (r - reduced + (upper - 1) < 0);
return reduced;
}
unsigned char PRNG::get_uchar()
{
@@ -147,7 +172,6 @@ unsigned char PRNG::get_uchar()
return ans;
}
__m128i PRNG::get_doubleword()
{
if (cnt > RAND_SIZE - 16)

View File

@@ -6,6 +6,7 @@
#include "Tools/octetStream.h"
#include "Tools/sha1.h"
#include "Tools/aes.h"
#include "Tools/avx_memcpy.h"
#define USE_AES
@@ -19,6 +20,7 @@
#define RAND_SIZE (PIPELINES * AES_BLK_SIZE)
#endif
class Player;
/* This basically defines a randomness expander, if using
* as a real PRG on an input stream you should first collapse
@@ -62,12 +64,14 @@ class PRNG
// Set seed from array
void SetSeed(const unsigned char*);
void SetSeed(PRNG& G);
void SecureSeed(Player& player);
void InitSeed();
double get_double();
bool get_bit() { return get_uchar() & 1; }
unsigned char get_uchar();
unsigned int get_uint();
unsigned int get_uint(int upper);
void get_bigint(bigint& res, int n_bits, bool positive = true);
void get(bigint& res, int n_bits, bool positive = true);
void get(int& res, int n_bits, bool positive = true);
@@ -82,9 +86,34 @@ class PRNG
__m128i get_doubleword();
void get_octetStream(octetStream& ans,int len);
void get_octets(octet* ans, int len);
template <int L>
void get_octets_(octet* ans);
template <class T>
T get();
const octet* get_seed() const
{ return seed; }
};
template <class T>
T PRNG::get()
{
T res;
res.randomize(*this);
return res;
}
template<int L>
inline void PRNG::get_octets_(octet* ans)
{
if (L < RAND_SIZE - cnt)
{
avx_memcpy(ans, random + cnt, L);
cnt += L;
}
else
get_octets(ans, L);
}
#endif

View File

@@ -70,6 +70,75 @@ void check_triples(int n_players, string type_char = "")
delete[] inputFiles;
}
template <class T, class U>
void check_triples_Z2k(int n_players, string type_char = "", bool macs = true)
{
ifstream* inputFiles = new ifstream[n_players];
for (int i = 0; i < n_players; i++)
{
stringstream ss;
ss << get_prep_dir(n_players, 128, 128) << "Triples-";
if (type_char.size())
ss << type_char;
else
ss << T::type_char();
ss << "-P" << i;
inputFiles[i].open(ss.str().c_str());
cout << "Opening file " << ss.str() << endl;
}
int j = 0;
while (inputFiles[0].peek() != EOF)
{
U a,b,c,cc,tmp,prod;
T dummy;
vector<T> as(n_players), bs(n_players), cs(n_players);
for (int i = 0; i < n_players; i++)
{
as[i].input(inputFiles[i], false);
if (macs)
dummy.input(inputFiles[i], false);
bs[i].input(inputFiles[i], false);
if (macs)
dummy.input(inputFiles[i], false);
cs[i].input(inputFiles[i], false);
if (macs)
dummy.input(inputFiles[i], false);
}
a = accumulate(as.begin(), as.end(), U());
b = accumulate(bs.begin(), bs.end(), U());
c = accumulate(cs.begin(), cs.end(), U());
prod = a * b;
if (prod != c)
{
cout << T::type_string() << ": Error in " << j << endl;
cout << "a " << a << " " << as[0] << " " << as[1] << endl;
cout << "b " << b << " " << bs[0] << " " << bs[1] << endl;
cout << "c " << c << " " << cs[0] << " " << cs[1] << endl;
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
{
tmp = as[i] * bs[j];
cc += tmp;
cout << "a" << i << " * b" << j << " " << tmp << endl;
}
cout << "c " << c << endl;
cout << "cc " << cc << endl;
cout << "a*b " << prod << endl;
cout << "DID YOU INDICATE THE CORRECT NUMBER OF PLAYERS?" << endl;
return;
}
j++;
}
cout << dec << j << " correct triples of type " << T::type_string() << endl;
delete[] inputFiles;
}
int main(int argc, char** argv)
{
int n_players = 2;
@@ -79,5 +148,5 @@ int main(int argc, char** argv)
gfp::init_field(gfp::pr(), false);
check_triples<gf2n>(n_players);
check_triples<gfp>(n_players);
check_triples<Z2<160> >(n_players, "Z2^64");
check_triples_Z2k<Z2<160>, Z2<64>>(n_players);
}