Optimized matrix multiplication; data type for quantization.

This commit is contained in:
Marcel Keller
2019-02-14 15:15:37 +11:00
parent 216fbdf1e1
commit b6a18675e8
108 changed files with 2676 additions and 999 deletions

1
.gitignore vendored
View File

@@ -62,6 +62,7 @@ Programs/Public-Input/*
*.a *.a
*.static *.static
*.d *.d
local/
# Packages # # Packages #
############ ############

View File

@@ -466,3 +466,5 @@ template class MAC_Check_Base<ShamirShare<gfp>>;
template class MAC_Check_Base<ShamirShare<gf2n>>; template class MAC_Check_Base<ShamirShare<gf2n>>;
template class MAC_Check_Base<MaliciousShamirShare<gfp>>; template class MAC_Check_Base<MaliciousShamirShare<gfp>>;
template class MAC_Check_Base<MaliciousShamirShare<gf2n>>; template class MAC_Check_Base<MaliciousShamirShare<gf2n>>;
template class MAC_Check_Base<Share<gfp>>;
template class MAC_Check_Base<Share<gf2n>>;

View File

@@ -85,14 +85,15 @@ void CommMaliciousRepMC<T>::POpen_Begin(vector<typename T::clear>& values,
for (auto& x : S) for (auto& x : S)
for (int i = 0; i < 2; i++) for (int i = 0; i < 2; i++)
x[i].pack(os[1 - i]); x[i].pack(os[1 - i]);
P.send_relative(os); P.pass_around(os[0], 1);
P.pass_around(os[1], 2);
} }
template<class T> template<class T>
void CommMaliciousRepMC<T>::POpen_End(vector<typename T::clear>& values, void CommMaliciousRepMC<T>::POpen_End(vector<typename T::clear>& values,
const vector<T>& S, const Player& P) const vector<T>& S, const Player& P)
{ {
P.receive_relative(os); (void) P;
if (os[0] != os[1]) if (os[0] != os[1])
throw mac_fail(); throw mac_fail();
values.clear(); values.clear();

View File

@@ -4,7 +4,7 @@
*/ */
#include "MaliciousShamirMC.h" #include "MaliciousShamirMC.h"
#include "Processor/ShamirMachine.h" #include "Machines/ShamirMachine.h"
template<class T> template<class T>
MaliciousShamirMC<T>::MaliciousShamirMC() MaliciousShamirMC<T>::MaliciousShamirMC()

View File

@@ -8,7 +8,7 @@
#include "MAC_Check.h" #include "MAC_Check.h"
#include "Math/ShamirShare.h" #include "Math/ShamirShare.h"
#include "Processor/ShamirMachine.h" #include "Machines/ShamirMachine.h"
template<class T> template<class T>
class ShamirMC : public MAC_Check_Base<T> class ShamirMC : public MAC_Check_Base<T>

View File

@@ -14,7 +14,12 @@ void AuthValue::assign(const word& value, const int128& mac_key, bool not_first_
share = 0; share = 0;
else else
share = value; share = value;
#ifdef __PCLMUL__
mac = _mm_clmulepi64_si128(_mm_cvtsi64_si128(mac_key.get_lower()), _mm_cvtsi64_si128(value), 0); mac = _mm_clmulepi64_si128(_mm_cvtsi64_si128(mac_key.get_lower()), _mm_cvtsi64_si128(value), 0);
#else
(void) mac_key;
throw runtime_error("need to compile with PCLMUL support");
#endif
} }
ostream& operator<<(ostream& o, const AuthValue& auth_value) ostream& operator<<(ostream& o, const AuthValue& auth_value)

View File

@@ -12,6 +12,7 @@
#include <string.h> #include <string.h>
#include "Tools/FlexBuffer.h" #include "Tools/FlexBuffer.h"
#include "Math/gf2nlong.h"
using namespace std; using namespace std;
@@ -54,8 +55,7 @@ ostream& operator<<(ostream& o, const __m128i& x);
inline bool Key::operator==(const Key& other) { inline bool Key::operator==(const Key& other) {
__m128i neq = _mm_xor_si128(r, other.r); return int128(r) == other.r;
return _mm_test_all_zeros(neq,neq);
} }
inline Key& Key::operator-=(const Key& other) { inline Key& Key::operator-=(const Key& other) {
@@ -88,7 +88,14 @@ inline void Key::set_signal(bool signal)
inline Key Key::doubling(int i) const inline Key Key::doubling(int i) const
{ {
#ifdef __AVX2__
return _mm_sllv_epi64(r, _mm_set_epi64x(i, i)); return _mm_sllv_epi64(r, _mm_set_epi64x(i, i));
#else
uint64_t halfs[2];
halfs[1] = _mm_cvtsi128_si64(_mm_unpackhi_epi64(r, r)) << i;
halfs[0] = _mm_cvtsi128_si64(r) << i;
return _mm_loadu_si128((__m128i*)halfs);
#endif
} }

View File

@@ -8,7 +8,7 @@
#include <mutex> #include <mutex>
#include <boost/atomic.hpp> #include <boost/atomic.hpp>
#include <sys/sysinfo.h> #include <boost/thread.hpp>
#include "Register.h" #include "Register.h"
#include "GarbledGate.h" #include "GarbledGate.h"
@@ -31,7 +31,7 @@ class BooleanCircuit;
#ifndef N_EVAL_THREADS #ifndef N_EVAL_THREADS
// Default Intel desktop processor has 8 half cores. // Default Intel desktop processor has 8 half cores.
// This is beneficial if only one AES available per full core. // This is beneficial if only one AES available per full core.
#define N_EVAL_THREADS (get_nprocs()) #define N_EVAL_THREADS (thread::hardware_concurrency())
#endif #endif

View File

@@ -9,6 +9,8 @@
#include "GC/Instruction.hpp" #include "GC/Instruction.hpp"
#include "GC/Program.hpp" #include "GC/Program.hpp"
#include "Processor/Instruction.hpp"
namespace GC namespace GC
{ {

View File

@@ -1,4 +1,5 @@
#include "aes.h" #include "aes.h"
#include <stdexcept>
#ifdef _WIN32 #ifdef _WIN32
#include "StdAfx.h" #include "StdAfx.h"
@@ -10,6 +11,7 @@ void AES_128_Key_Expansion(const unsigned char *userkey, AES_KEY *aesKey)
//block *kp = (block *)&aesKey; //block *kp = (block *)&aesKey;
aesKey->rd_key[0] = x0 = _mm_loadu_si128((block*)userkey); aesKey->rd_key[0] = x0 = _mm_loadu_si128((block*)userkey);
x2 = _mm_setzero_si128(); x2 = _mm_setzero_si128();
#ifdef __AES__
EXPAND_ASSIST(x0, x1, x2, x0, 255, 1); aesKey->rd_key[1] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 1); aesKey->rd_key[1] = x0;
EXPAND_ASSIST(x0, x1, x2, x0, 255, 2); aesKey->rd_key[2] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 2); aesKey->rd_key[2] = x0;
EXPAND_ASSIST(x0, x1, x2, x0, 255, 4); aesKey->rd_key[3] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 4); aesKey->rd_key[3] = x0;
@@ -20,198 +22,8 @@ void AES_128_Key_Expansion(const unsigned char *userkey, AES_KEY *aesKey)
EXPAND_ASSIST(x0, x1, x2, x0, 255, 128); aesKey->rd_key[8] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 128); aesKey->rd_key[8] = x0;
EXPAND_ASSIST(x0, x1, x2, x0, 255, 27); aesKey->rd_key[9] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 27); aesKey->rd_key[9] = x0;
EXPAND_ASSIST(x0, x1, x2, x0, 255, 54); aesKey->rd_key[10] = x0; EXPAND_ASSIST(x0, x1, x2, x0, 255, 54); aesKey->rd_key[10] = x0;
} #else
(void) x1, (void) x2;
throw std::runtime_error("need to compile with AES-NI support");
#endif
void AES_192_Key_Expansion(const unsigned char *userkey, AES_KEY *aesKey)
{
__m128i x0,x1,x2,x3,tmp,*kp = (block *)&aesKey;
kp[0] = x0 = _mm_loadu_si128((block*)userkey);
tmp = x3 = _mm_loadu_si128((block*)(userkey+16));
x2 = _mm_setzero_si128();
EXPAND192_STEP(1,1);
EXPAND192_STEP(4,4);
EXPAND192_STEP(7,16);
EXPAND192_STEP(10,64);
}
void AES_256_Key_Expansion(const unsigned char *userkey, AES_KEY *aesKey)
{
__m128i x0, x1, x2, x3;/* , *kp = (block *)&aesKey;*/
aesKey->rd_key[0] = x0 = _mm_loadu_si128((block*)userkey);
aesKey->rd_key[1] = x3 = _mm_loadu_si128((block*)(userkey + 16));
x2 = _mm_setzero_si128();
EXPAND_ASSIST(x0, x1, x2, x3, 255, 1); aesKey->rd_key[2] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 1); aesKey->rd_key[3] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 2); aesKey->rd_key[4] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 2); aesKey->rd_key[5] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 4); aesKey->rd_key[6] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 4); aesKey->rd_key[7] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 8); aesKey->rd_key[8] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 8); aesKey->rd_key[9] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 16); aesKey->rd_key[10] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 16); aesKey->rd_key[11] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 32); aesKey->rd_key[12] = x0;
EXPAND_ASSIST(x3, x1, x2, x0, 170, 32); aesKey->rd_key[13] = x3;
EXPAND_ASSIST(x0, x1, x2, x3, 255, 64); aesKey->rd_key[14] = x0;
}
void AES_set_encrypt_key(const unsigned char *userKey, const int bits, AES_KEY *aesKey)
{
if (bits == 128) {
AES_128_Key_Expansion(userKey, aesKey);
} else if (bits == 192) {
AES_192_Key_Expansion(userKey, aesKey);
} else if (bits == 256) {
AES_256_Key_Expansion(userKey, aesKey);
}
aesKey->rounds = 6 + bits / 32;
}
void AES_encryptC(block *in, block *out, AES_KEY *aesKey)
{
int j, rnds = ROUNDS(aesKey);
const __m128i *sched = ((__m128i *)(aesKey->rd_key));
__m128i tmp = _mm_load_si128((__m128i*)in);
tmp = _mm_xor_si128(tmp, sched[0]);
for (j = 1; j<rnds; j++) tmp = _mm_aesenc_si128(tmp, sched[j]);
tmp = _mm_aesenclast_si128(tmp, sched[j]);
_mm_store_si128((__m128i*)out, tmp);
}
void AES_ecb_encrypt(block *blk, AES_KEY *aesKey) {
unsigned j, rnds = ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
*blk = _mm_xor_si128(*blk, sched[0]);
for (j = 1; j<rnds; ++j)
*blk = _mm_aesenc_si128(*blk, sched[j]);
*blk = _mm_aesenclast_si128(*blk, sched[j]);
}
void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *aesKey) {
unsigned i,j,rnds=ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
for (i=0; i<nblks; ++i)
blks[i] =_mm_xor_si128(blks[i], sched[0]);
for(j=1; j<rnds; ++j)
for (i=0; i<nblks; ++i)
blks[i] = _mm_aesenc_si128(blks[i], sched[j]);
for (i=0; i<nblks; ++i)
blks[i] =_mm_aesenclast_si128(blks[i], sched[j]);
}
void AES_ecb_encrypt_blks_4(block *blks, AES_KEY *aesKey) {
unsigned j, rnds = ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
blks[0] = _mm_xor_si128(blks[0], sched[0]);
blks[1] = _mm_xor_si128(blks[1], sched[0]);
blks[2] = _mm_xor_si128(blks[2], sched[0]);
blks[3] = _mm_xor_si128(blks[3], sched[0]);
for (j = 1; j < rnds; ++j){
blks[0] = _mm_aesenc_si128(blks[0], sched[j]);
blks[1] = _mm_aesenc_si128(blks[1], sched[j]);
blks[2] = _mm_aesenc_si128(blks[2], sched[j]);
blks[3] = _mm_aesenc_si128(blks[3], sched[j]);
}
blks[0] = _mm_aesenclast_si128(blks[0], sched[j]);
blks[1] = _mm_aesenclast_si128(blks[1], sched[j]);
blks[2] = _mm_aesenclast_si128(blks[2], sched[j]);
blks[3] = _mm_aesenclast_si128(blks[3], sched[j]);
}
void AES_ecb_encrypt_blks_2_in_out(block *in, block *out, AES_KEY *aesKey) {
unsigned j, rnds = ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
out[0] = _mm_xor_si128(in[0], sched[0]);
out[1] = _mm_xor_si128(in[1], sched[0]);
for (j = 1; j < rnds; ++j){
out[0] = _mm_aesenc_si128(out[0], sched[j]);
out[1] = _mm_aesenc_si128(out[1], sched[j]);
}
out[0] = _mm_aesenclast_si128(out[0], sched[j]);
out[1] = _mm_aesenclast_si128(out[1], sched[j]);
}
void AES_ecb_encrypt_blks_4_in_out(block *in, block *out, AES_KEY *aesKey) {
unsigned j, rnds = ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
//block temp[4];
out[0] = _mm_xor_si128(in[0], sched[0]);
out[1] = _mm_xor_si128(in[1], sched[0]);
out[2] = _mm_xor_si128(in[2], sched[0]);
out[3] = _mm_xor_si128(in[3], sched[0]);
for (j = 1; j < rnds; ++j){
out[0] = _mm_aesenc_si128(out[0], sched[j]);
out[1] = _mm_aesenc_si128(out[1], sched[j]);
out[2] = _mm_aesenc_si128(out[2], sched[j]);
out[3] = _mm_aesenc_si128(out[3], sched[j]);
}
out[0] = _mm_aesenclast_si128(out[0], sched[j]);
out[1] = _mm_aesenclast_si128(out[1], sched[j]);
out[2] = _mm_aesenclast_si128(out[2], sched[j]);
out[3] = _mm_aesenclast_si128(out[3], sched[j]);
}
void AES_ecb_encrypt_chunk_in_out(block *in, block *out, unsigned nblks, AES_KEY *aesKey) {
int numberOfLoops = nblks / 8;
int blocksPipeLined = numberOfLoops * 8;
int remainingEncrypts = nblks - blocksPipeLined;
unsigned j, rnds = ROUNDS(aesKey);
const block *sched = ((block *)(aesKey->rd_key));
for (int i = 0; i < numberOfLoops; i++){
out[0 + i * 8] = _mm_xor_si128(in[0 + i * 8], sched[0]);
out[1 + i * 8] = _mm_xor_si128(in[1 + i * 8], sched[0]);
out[2 + i * 8] = _mm_xor_si128(in[2 + i * 8], sched[0]);
out[3 + i * 8] = _mm_xor_si128(in[3 + i * 8], sched[0]);
out[4 + i * 8] = _mm_xor_si128(in[4 + i * 8], sched[0]);
out[5 + i * 8] = _mm_xor_si128(in[5 + i * 8], sched[0]);
out[6 + i * 8] = _mm_xor_si128(in[6 + i * 8], sched[0]);
out[7 + i * 8] = _mm_xor_si128(in[7 + i * 8], sched[0]);
for (j = 1; j < rnds; ++j){
out[0 + i * 8] = _mm_aesenc_si128(out[0 + i * 8], sched[j]);
out[1 + i * 8] = _mm_aesenc_si128(out[1 + i * 8], sched[j]);
out[2 + i * 8] = _mm_aesenc_si128(out[2 + i * 8], sched[j]);
out[3 + i * 8] = _mm_aesenc_si128(out[3 + i * 8], sched[j]);
out[4 + i * 8] = _mm_aesenc_si128(out[4 + i * 8], sched[j]);
out[5 + i * 8] = _mm_aesenc_si128(out[5 + i * 8], sched[j]);
out[6 + i * 8] = _mm_aesenc_si128(out[6 + i * 8], sched[j]);
out[7 + i * 8] = _mm_aesenc_si128(out[7 + i * 8], sched[j]);
}
out[0 + i * 8] = _mm_aesenclast_si128(out[0 + i * 8], sched[j]);
out[1 + i * 8] = _mm_aesenclast_si128(out[1 + i * 8], sched[j]);
out[2 + i * 8] = _mm_aesenclast_si128(out[2 + i * 8], sched[j]);
out[3 + i * 8] = _mm_aesenclast_si128(out[3 + i * 8], sched[j]);
out[4 + i * 8] = _mm_aesenclast_si128(out[4 + i * 8], sched[j]);
out[5 + i * 8] = _mm_aesenclast_si128(out[5 + i * 8], sched[j]);
out[6 + i * 8] = _mm_aesenclast_si128(out[6 + i * 8], sched[j]);
out[7 + i * 8] = _mm_aesenclast_si128(out[7 + i * 8], sched[j]);
}
for (int i = blocksPipeLined; i < blocksPipeLined + remainingEncrypts; ++i){
out[i] = _mm_xor_si128(in[i], sched[0]);
for (j = 1; j < rnds; ++j)
{
out[i] = _mm_aesenc_si128(out[i], sched[j]);
}
out[i] = _mm_aesenclast_si128(out[i], sched[j]);
}
} }

View File

@@ -24,7 +24,6 @@
typedef struct { block rd_key[15]; int rounds; } AES_KEY; typedef struct { block rd_key[15]; int rounds; } AES_KEY;
#define ROUNDS(ctx) ((ctx)->rounds)
#define EXPAND_ASSIST(v1,v2,v3,v4,shuff_const,aes_const) \ #define EXPAND_ASSIST(v1,v2,v3,v4,shuff_const,aes_const) \
v2 = _mm_aeskeygenassist_si128(v4,aes_const); \ v2 = _mm_aeskeygenassist_si128(v4,aes_const); \
@@ -37,36 +36,6 @@ typedef struct { block rd_key[15]; int rounds; } AES_KEY;
v2 = _mm_shuffle_epi32(v2,shuff_const); \ v2 = _mm_shuffle_epi32(v2,shuff_const); \
v1 = _mm_xor_si128(v1,v2) v1 = _mm_xor_si128(v1,v2)
#define EXPAND192_STEP(idx,aes_const) \
EXPAND_ASSIST(x0,x1,x2,x3,85,aes_const); \
x3 = _mm_xor_si128(x3,_mm_slli_si128 (x3, 4)); \
x3 = _mm_xor_si128(x3,_mm_shuffle_epi32(x0, 255)); \
kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), \
_mm_castsi128_ps(x0), 68)); \
kp[idx+1] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), \
_mm_castsi128_ps(x3), 78)); \
EXPAND_ASSIST(x0,x1,x2,x3,85,(aes_const*2)); \
x3 = _mm_xor_si128(x3,_mm_slli_si128 (x3, 4)); \
x3 = _mm_xor_si128(x3,_mm_shuffle_epi32(x0, 255)); \
kp[idx+2] = x0; tmp = x3
void AES_128_Key_Expansion(const unsigned char *userkey, AES_KEY* aesKey); void AES_128_Key_Expansion(const unsigned char *userkey, AES_KEY* aesKey);
void AES_192_Key_Expansion(const unsigned char *userkey, AES_KEY* aesKey);
void AES_256_Key_Expansion(const unsigned char *userkey, AES_KEY* aesKey);
void AES_set_encrypt_key(const unsigned char *userKey, const int bits, AES_KEY *aesKey);
void AES_encryptC(block *in, block *out, AES_KEY *aesKey);
void AES_ecb_encrypt(block *blk, AES_KEY *aesKey);
void AES_ecb_encrypt_blks(block *blks, unsigned nblks, AES_KEY *aesKey);
void AES_ecb_encrypt_blks_4(block *blk, AES_KEY *aesKey);
void AES_ecb_encrypt_blks_4_in_out(block *in, block *out, AES_KEY *aesKey);
void AES_ecb_encrypt_blks_2_in_out(block *in, block *out, AES_KEY *aesKey);
void AES_ecb_encrypt_chunk_in_out(block *in, block *out, unsigned nblks, AES_KEY *aesKey);
#endif /* PROTOCOL_INC_AES_H_ */ #endif /* PROTOCOL_INC_AES_H_ */

View File

@@ -36,7 +36,7 @@ Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsig
_servaddr.sin_addr.s_addr = htonl(INADDR_ANY); _servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
_servaddr.sin_port = htons(_port); _servaddr.sin_port = htons(_port);
if( 0 != bind(_servfd, (struct sockaddr *) &_servaddr, sizeof(_servaddr)) ) if( 0 != ::bind(_servfd, (struct sockaddr *) &_servaddr, sizeof(_servaddr)) )
printf("Server:: Error binding to %d: \n%s\n", _port, strerror(errno)); printf("Server:: Error binding to %d: \n%s\n", _port, strerror(errno));
if(0 != listen(_servfd, _expected_clients)) if(0 != listen(_servfd, _expected_clients))

View File

@@ -1,21 +0,0 @@
/*
* prf.cpp
*
*/
#include "prf.h"
#include "aes.h"
#include "proto_utils.h"
void PRF_single(const Key& key, char* input, char* output)
{
// printf("prf_single\n");
// std::cout << *key;
// phex(input, 16);
AES_KEY aes_key;
AES_128_Key_Expansion((const unsigned char*)(&(key.r)), &aes_key);
aes_key.rounds=10;
AES_encryptC((block*)input, (block*)output, &aes_key);
// phex(output, 16);
}

View File

@@ -11,15 +11,6 @@
#include "Tools/aes.h" #include "Tools/aes.h"
void PRF_single(const Key& key, char* input, char* output);
inline Key PRF_single(const Key& key, const Key& input)
{
Key output;
PRF_single(key, (char*)&input, (char*)&output);
return output;
}
inline void PRF_chunk(const Key& key, char* input, char* output, int number) inline void PRF_chunk(const Key& key, char* input, char* output, int number)
{ {
__m128i* in = (__m128i*)input; __m128i* in = (__m128i*)input;

View File

@@ -1,5 +1,11 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here. The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here.
## 0.0.7 (Feb 14, 2019)
- Simplified installation on macOS
- Optimized matrix multiplication
- Data type for quantization
## 0.0.6 (Jan 5, 2019) ## 0.0.6 (Jan 5, 2019)
- Shamir secret sharing - Shamir secret sharing

18
CONFIG
View File

@@ -4,6 +4,7 @@ OPTIM= -O3
#PROF = -pg #PROF = -pg
#DEBUG = -DDEBUG #DEBUG = -DDEBUG
#MEMPROTECT = -DMEMPROTECT #MEMPROTECT = -DMEMPROTECT
GDEBUG = -g
# set this to your preferred local storage directory # set this to your preferred local storage directory
PREP_DIR = '-DPREP_DIR="Player-Data/"' PREP_DIR = '-DPREP_DIR="Player-Data/"'
@@ -16,8 +17,15 @@ USE_NTL = 0
USE_GF2N_LONG = 1 USE_GF2N_LONG = 1
# set to -march=<architecture> for optimization # set to -march=<architecture> for optimization
# AVX2 support (Haswell or later) changes the bit matrix transpose # AES-NI is required for BMR
ARCH = -mtune=native -mavx # PCLMUL is required for GF(2^128) computation
# AVX2 support (Haswell or later) is used to optimize OT
# AVX/AVX2 is required for replicated binary secret sharing
# BMI2 is used to optimize multiplication modulo a prime
ARCH = -mtune=native -msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2
# allow to set compiler in CONFIG.mine
CXX = g++
#use CONFIG.mine to overwrite DIR settings #use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine -include CONFIG.mine
@@ -30,7 +38,7 @@ endif
# Default is 2, which suffices for 128-bit p # Default is 2, which suffices for 128-bit p
# MOD = -DMAX_MOD_SZ=2 # MOD = -DMAX_MOD_SZ=2
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) -lm -lpthread LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS += -lboost_system -lssl -lcrypto LDLIBS += -lboost_system -lssl -lcrypto
ifeq ($(USE_NTL),1) ifeq ($(USE_NTL),1)
@@ -44,8 +52,6 @@ endif
BOOST = -lboost_system -lboost_thread $(MY_BOOST) BOOST = -lboost_system -lboost_thread $(MY_BOOST)
CXX = g++ -no-pie CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror
#CXX = clang++
CFLAGS += $(ARCH) $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 -mavx -mavx2 -mbmi2 --std=c++11 -Werror
CPPFLAGS = $(CFLAGS) CPPFLAGS = $(CFLAGS)
LD = $(CXX) LD = $(CXX)

View File

@@ -14,6 +14,7 @@
#include "Processor/Data_Files.h" #include "Processor/Data_Files.h"
#include "Auth/fake-stuff.hpp" #include "Auth/fake-stuff.hpp"
#include "Processor/Data_Files.hpp"
#include <sstream> #include <sstream>
#include <fstream> #include <fstream>

View File

@@ -6,11 +6,10 @@ from Compiler import util, oram, floatingpoint
import Compiler.GC.instructions as inst import Compiler.GC.instructions as inst
import operator import operator
class bits(Tape.Register): class bits(Tape.Register, _structure):
n = 40 n = 40
size = 1 size = 1
PreOp = staticmethod(floatingpoint.PreOpN) PreOp = staticmethod(floatingpoint.PreOpN)
MemValue = staticmethod(lambda value: MemValue(value))
decomposed = None decomposed = None
@staticmethod @staticmethod
def PreOR(l): def PreOR(l):
@@ -72,7 +71,10 @@ class bits(Tape.Register):
def n_elements(): def n_elements():
return 1 return 1
@classmethod @classmethod
def load_mem(cls, address, mem_type=None): def load_mem(cls, address, mem_type=None, size=None):
if size not in (None, 1):
v = [cls.load_mem(address + i) for i in range(size)]
return cls.vec(v)
res = cls() res = cls()
if mem_type == 'sd': if mem_type == 'sd':
return cls.load_dynamic_mem(address) return cls.load_dynamic_mem(address)
@@ -376,6 +378,9 @@ class sbits(bits):
class sbitvec(object): class sbitvec(object):
@classmethod @classmethod
def get_type(cls, n):
return cls
@classmethod
def from_vec(cls, vector): def from_vec(cls, vector):
res = cls() res = cls()
res.v = list(vector) res.v = list(vector)
@@ -419,6 +424,15 @@ class sbitvec(object):
@classmethod @classmethod
def conv(cls, other): def conv(cls, other):
return cls.from_vec(other.v) return cls.from_vec(other.v)
@property
def size(self):
return self.v[0].n
def store_in_mem(self, address):
for i, x in enumerate(self.elements()):
x.store_in_mem(address + i)
def bit_decompose(self):
return self.v
bit_compose = from_vec
class bit(object): class bit(object):
n = 1 n = 1
@@ -499,7 +513,9 @@ class sbitint(_bitint, _number, sbits):
bin_type = None bin_type = None
types = {} types = {}
@classmethod @classmethod
def get_type(cls, n): def get_type(cls, n, other=None):
if isinstance(other, sbitvec):
return sbitvec
if n in cls.types: if n in cls.types:
return cls.types[n] return cls.types[n]
sbits_type = sbits.get_type(n) sbits_type = sbits.get_type(n)
@@ -511,6 +527,12 @@ class sbitint(_bitint, _number, sbits):
cls.types[n] = _ cls.types[n] = _
return _ return _
@classmethod @classmethod
def combo_type(cls, other):
if isinstance(other, sbitintvec):
return sbitintvec
else:
return cls
@classmethod
def new(cls, value=None, n=None): def new(cls, value=None, n=None):
return cls.get_type(n)(value) return cls.get_type(n)(value)
def set_length(*args): def set_length(*args):
@@ -523,7 +545,9 @@ class sbitint(_bitint, _number, sbits):
return super(sbitint, cls).bit_compose(bits) return super(sbitint, cls).bit_compose(bits)
def force_bit_decompose(self, n_bits=None): def force_bit_decompose(self, n_bits=None):
return sbits.bit_decompose(self, n_bits) return sbits.bit_decompose(self, n_bits)
def TruncMul(self, other, k, m, kappa=None): def TruncMul(self, other, k, m, kappa=None, nearest=False):
if nearest:
raise CompilerError('round to nearest not implemented')
self_bits = self.bit_decompose() self_bits = self.bit_decompose()
other_bits = other.bit_decompose() other_bits = other.bit_decompose()
if len(self_bits) + len(other_bits) != k: if len(self_bits) + len(other_bits) != k:
@@ -532,10 +556,12 @@ class sbitint(_bitint, _number, sbits):
(len(self_bits), len(other_bits), k)) (len(self_bits), len(other_bits), k))
t = self.get_type(k) t = self.get_type(k)
a = t.bit_compose(self_bits + [self_bits[-1]] * (k - len(self_bits))) a = t.bit_compose(self_bits + [self_bits[-1]] * (k - len(self_bits)))
t = other.get_type(k)
b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits))) b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits)))
product = a * b product = a * b
res_bits = product.bit_decompose()[m:k] res_bits = product.bit_decompose()[m:k]
return self.bit_compose(res_bits) t = self.combo_type(other)
return t.bit_compose(res_bits)
def Norm(self, k, f, kappa=None, simplex_flag=False): def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self) absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding #next 2 lines actually compute the SufOR for little indian encoding
@@ -557,17 +583,32 @@ class sbitint(_bitint, _number, sbits):
bits = self.bit_decompose() bits = self.bit_decompose()
bits += [bits[-1]] * (n - len(bits)) bits += [bits[-1]] * (n - len(bits))
return self.get_type(n).bit_compose(bits) return self.get_type(n).bit_compose(bits)
def __mul__(self, other):
if isinstance(other, sbitintvec):
return other * self
else:
return super(sbitint, self).__mul__(other)
class sbitintvec(sbitvec): class sbitintvec(sbitvec):
def __add__(self, other): def __add__(self, other):
if other is 0: if other is 0:
return self return self
assert(len(self.v) == len(other.v)) assert(len(self.v) == len(other.v))
return self.from_vec(sbitint.bit_adder(self.v, other.v)) v = sbitint.bit_adder(self.v, other.v)
return self.from_vec(v)
__radd__ = __add__ __radd__ = __add__
def less_than(self, other, *args, **kwargs): def less_than(self, other, *args, **kwargs):
assert(len(self.v) == len(other.v)) assert(len(self.v) == len(other.v))
return self.from_vec(sbitint.bit_less_than(self.v, other.v)) return self.from_vec(sbitint.bit_less_than(self.v, other.v))
def __mul__(self, other):
assert isinstance(other, sbitint)
matrix = [[x * b for x in self.v] for b in other.bit_decompose()]
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v)
__rmul__ = __mul__
reduce_after_mul = lambda x: x
sbitint.vec = sbitintvec
class cbitfix(object): class cbitfix(object):
def __init__(self, value): def __init__(self, value):
@@ -585,6 +626,13 @@ class sbitfix(_fix):
def set_precision(cls, f, k=None): def set_precision(cls, f, k=None):
super(cls, sbitfix).set_precision(f, k) super(cls, sbitfix).set_precision(f, k)
cls.int_type = sbitint.get_type(cls.k) cls.int_type = sbitint.get_type(cls.k)
@classmethod
def load_mem(cls, address, size=None):
if size not in (None, 1):
v = [cls.int_type.load_mem(address + i) for i in range(size)]
return sbitfixvec._new(sbitintvec(v))
else:
return super(sbitfix, cls).load_mem(address)
def __xor__(self, other): def __xor__(self, other):
return type(self)(self.v ^ other.v) return type(self)(self.v ^ other.v)
def __mul__(self, other): def __mul__(self, other):
@@ -596,3 +644,17 @@ class sbitfix(_fix):
__rmul__ = __mul__ __rmul__ = __mul__
sbitfix.set_precision(20, 41) sbitfix.set_precision(20, 41)
class sbitfixvec(_fix):
int_type = sbitintvec
float_type = type(None)
@staticmethod
@property
def f():
return sbitfix.f
@staticmethod
@property
def k():
return sbitfix.k
sbitfix.vec = sbitfixvec

View File

@@ -43,8 +43,7 @@ class StraightlineAllocator:
if base.vector: if base.vector:
for i,r in enumerate(base.vector): for i,r in enumerate(base.vector):
r.i = self.alloc[base] + i r.i = self.alloc[base] + i
else: base.i = self.alloc[base]
base.i = self.alloc[base]
def dealloc_reg(self, reg, inst, free): def dealloc_reg(self, reg, inst, free):
self.dealloc.add(reg) self.dealloc.add(reg)
@@ -57,6 +56,7 @@ class StraightlineAllocator:
return return
free[reg.reg_type, base.size].add(self.alloc[base]) free[reg.reg_type, base.size].add(self.alloc[base])
if inst.is_vec() and base.vector: if inst.is_vec() and base.vector:
self.defined[base] = inst
for i in base.vector: for i in base.vector:
self.defined[i] = inst self.defined[i] = inst
else: else:

View File

@@ -108,67 +108,54 @@ def Trunc(d, a, k, m, kappa, signed):
def TruncRing(d, a, k, m, signed): def TruncRing(d, a, k, m, signed):
a_prime = Mod2mRing(None, a, k, m, signed) a_prime = Mod2mRing(None, a, k, m, signed)
a -= a_prime a -= a_prime
res = TruncZeroesInRing(a, k, m, signed) res = TruncLeakyInRing(a, k, m, signed)
if d is not None: if d is not None:
movs(d, res) movs(d, res)
return res return res
def TruncZeroesInRing(a, k, m, signed): def TruncZeroes(a, k, m, signed):
if program.options.ring:
return TruncLeakyInRing(a, k, m, signed)
else:
import types
tmp = types.cint()
inv2m(tmp, m)
return a * tmp
def TruncLeakyInRing(a, k, m, signed):
""" """
Returns a >> m. Returns a >> m.
Requires 2^m | a and a < 2^k. Requires a < 2^k and leaks a % 2^m (needs to be constant or random).
""" """
assert k > m
assert int(program.options.ring) >= k
from types import sint, intbitint, cint, cgf2n from types import sint, intbitint, cint, cgf2n
n_bits = k - m n_bits = k - m
n_shift = int(program.options.ring) - n_bits n_shift = int(program.options.ring) - n_bits
r_bits = [sint.get_random_bit() for i in range(n_bits)] r_bits = [sint.get_random_bit() for i in range(n_bits)]
r = sint.bit_compose(r_bits) r = sint.bit_compose(r_bits)
shifted = ((a << (n_shift - m)) - (r << n_shift)).reveal()
masked = shifted >> n_shift
res_bits = intbitint.bit_adder(r_bits, masked.bit_decompose(n_bits))
res = sint.bit_compose(res_bits)
if signed: if signed:
res = sint.conv(res_bits[-1].if_else(res - (sint(1) << n_bits), a += (1 << (k - 1))
res)) shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits, 0)
res = (u << n_bits) + masked - r
if signed:
res -= (1 << (n_bits - 1))
return res return res
def TruncRoundNearest(a, k, m, kappa): def TruncRoundNearest(a, k, m, kappa, signed=False):
""" """
Returns a / 2^m, rounded to the nearest integer. Returns a / 2^m, rounded to the nearest integer.
k: bit length of m k: bit length of a
m: compile-time integer m: compile-time integer
""" """
from types import sint, cint from types import sint
from library import reveal, load_int_to_secret res = sint()
if m == 1: Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed)
if program.options.ring: return res
lsb = Mod2mRing(None, a, k, 1, False)
return TruncRing(None, a + lsb, k + 1, 1, False)
else:
lsb = sint()
Mod2(lsb, a, k, kappa, False)
return (a + lsb) / 2
r_dprime = sint()
r_prime = sint()
r = [sint() for i in range(m)]
u = sint()
PRandM(r_dprime, r_prime, r, k, m, kappa)
c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
c_prime = c % (cint(1) << (m - 1))
if const_rounds:
BitLTC1(u, c_prime, r[:-1], kappa)
else:
BitLTL(u, c_prime, r[:-1], kappa)
bit = ((c - c_prime) >> (m - 1)) % 2
xor = bit + u - 2 * bit * u
prod = xor * r[-1]
# u_prime = xor * u + (1 - xor) * r[-1]
u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
d = (a - a_prime) >> m
rounding = xor + r[-1] - 2 * prod
return d + rounding
def Mod2m(a_prime, a, k, m, kappa, signed): def Mod2m(a_prime, a, k, m, kappa, signed):
""" """

View File

@@ -369,6 +369,8 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
x, pow2m = B2U(m, l, kappa) x, pow2m = B2U(m, l, kappa)
#assert(pow2m.value == 2**m.value) #assert(pow2m.value == 2**m.value)
#assert(sum(b.value for b in x) == m.value) #assert(sum(b.value for b in x) == m.value)
if program.Program.prog.options.ring and not compute_modulo:
return TruncInRing(a, l, pow2m)
for i in range(l): for i in range(l):
bit(r[i]) bit(r[i])
t1 = two_power(i) * r[i] t1 = two_power(i) * r[i]
@@ -418,7 +420,7 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa) overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
if program.Program.prog.options.ring: if program.Program.prog.options.ring:
s = (1 - overflow) * t + \ s = (1 - overflow) * t + \
comparison.TruncZeroesInRing(overflow * t, length, 1, False) comparison.TruncLeakyInRing(overflow * t, length, 1, False)
else: else:
s = (1 - overflow) * t + overflow * t / 2 s = (1 - overflow) * t + overflow * t / 2
return s, overflow return s, overflow
@@ -484,9 +486,22 @@ def TruncPr(a, k, m, kappa=None):
def TruncPrRing(a, k, m): def TruncPrRing(a, k, m):
if m == 0: if m == 0:
return a return a
res = types.sint() n_ring = int(program.Program.prog.options.ring)
comparison.TruncRing(res, a, k, m, True) if k == n_ring:
return res for i in range(m):
a += types.sint.get_random_bit() << i
return comparison.TruncLeakyInRing(a, k, m, True)
else:
from types import sint
# extra bit to mask overflow
r_bits = [sint.get_random_bit() for i in range(k + 1)]
n_shift = n_ring - len(r_bits)
tmp = a + sint.bit_compose(r_bits)
masked = (tmp << n_shift).reveal()
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m))
return res
def TruncPrField(a, k, m, kappa=None): def TruncPrField(a, k, m, kappa=None):
if kappa is None: if kappa is None:
@@ -504,27 +519,26 @@ def TruncPrField(a, k, m, kappa=None):
d = (a - a_prime) / two_to_m d = (a - a_prime) / two_to_m
return d return d
def SDiv(a, b, l, kappa): def SDiv(a, b, l, kappa, round_nearest=False):
theta = int(ceil(log(l / 3.5) / log(2))) theta = int(ceil(log(l / 3.5) / log(2)))
alpha = two_power(2*l) alpha = two_power(2*l)
beta = 1 / types.cint(two_power(l))
w = types.cint(int(2.9142 * two_power(l))) - 2 * b w = types.cint(int(2.9142 * two_power(l))) - 2 * b
x = alpha - b * w x = alpha - b * w
y = a * w y = a * w
y = TruncPr(y, 2 * l, l, kappa) y = y.round(2 * l + 1, l, kappa, round_nearest)
x2 = types.sint() x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False) comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
x1 = (x - x2) * beta x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
for i in range(theta-1): for i in range(theta-1):
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa) y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = TruncPr(y, 2 * l + 1, l + 1, kappa) y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa) x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa) x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
x2 = types.sint() x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l, l, kappa, False) comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
x1 = (x - x2) * beta x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa) y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
y = TruncPr(y, 2 * l + 1, l - 1, kappa) y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
return y return y
def SDiv_mono(a, b, l, kappa): def SDiv_mono(a, b, l, kappa):

View File

@@ -933,13 +933,6 @@ class print_reg_plain(base.IOInstruction):
code = base.opcodes['PRINTREGPLAIN'] code = base.opcodes['PRINTREGPLAIN']
arg_format = ['c'] arg_format = ['c']
@base.vectorize
class print_float_plain(base.IOInstruction):
__slots__ = []
code = base.opcodes['PRINTFLOATPLAIN']
arg_format = ['c', 'c', 'c', 'c']
class print_int(base.IOInstruction): class print_int(base.IOInstruction):
r""" Print only the value of register \verb|ci| to stdout. """ r""" Print only the value of register \verb|ci| to stdout. """
__slots__ = [] __slots__ = []
@@ -1358,10 +1351,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
data_type = 'triple' data_type = 'triple'
def get_repeat(self): def get_repeat(self):
if program.options.ring: return len(self.args) / 3
return 0
else:
return len(self.args) / 3
# def expand(self): # def expand(self):
# s = [program.curr_block.new_reg('s') for i in range(9)] # s = [program.curr_block.new_reg('s') for i in range(9)]
@@ -1378,6 +1368,73 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
# adds(s[8], s[7], s[6]) # adds(s[8], s[7], s[6])
# addm(self.args[0], s[8], c[2]) # addm(self.args[0], s[8], c[2])
@base.gf2n
class mulrs(base.VarArgsInstruction, base.DataInstruction):
""" Secret multiplication $s_i = s_j \cdot s_k$. """
__slots__ = []
code = base.opcodes['MULRS']
arg_format = tools.cycle(['int','sw','s','s'])
data_type = 'triple'
is_vec = lambda self: True
def __init__(self, res, x, y):
assert y.size == 1
assert res.size == x.size
base.Instruction.__init__(self, res.size, res, x, y)
def get_repeat(self):
return sum(self.args[::4])
def get_def(self):
return sum((arg.get_all() for arg in self.args[1::4]), [])
def get_used(self):
return sum((arg.get_all()
for arg in self.args[2::4] + self.args[3::4]), [])
@base.gf2n
class dotprods(base.VarArgsInstruction, base.DataInstruction):
""" Secret dot product. """
__slots__ = []
code = base.opcodes['DOTPRODS']
data_type = 'triple'
def __init__(self, *args):
flat_args = []
for i in range(0, len(args), 3):
res, x, y = args[i:i+3]
assert len(x) == len(y)
flat_args += [2 * len(x) + 2, res]
for x, y in zip(x, y):
flat_args += [x, y]
base.Instruction.__init__(self, *flat_args)
@property
def arg_format(self):
field = 'g' if self.is_gf2n() else ''
for i in self.bases():
yield 'int'
yield 's' + field + 'w'
for j in range(self.args[i] - 2):
yield 's' + field
def bases(self):
i = 0
while i < len(self.args):
yield i
i += self.args[i]
def get_repeat(self):
return sum(self.args[i] / 2 for i in self.bases())
def get_def(self):
return [self.args[i + 1] for i in self.bases()]
def get_used(self):
for i in self.bases():
for reg in self.args[i + 2:i + self.args[i]]:
yield reg
### ###
### CISC-style instructions ### CISC-style instructions
### ###

View File

@@ -87,6 +87,8 @@ opcodes = dict(
# Open # Open
OPEN = 0xA5, OPEN = 0xA5,
MULS = 0xA6, MULS = 0xA6,
MULRS = 0xA7,
DOTPRODS = 0xA8,
# Data access # Data access
TRIPLE = 0x50, TRIPLE = 0x50,
BIT = 0x51, BIT = 0x51,
@@ -177,19 +179,14 @@ def int_to_bytes(x):
return [(x >> 8*i) % 256 for i in (3,2,1,0)] return [(x >> 8*i) % 256 for i in (3,2,1,0)]
global_vector_size = 1 global_vector_size_stack = []
global_vector_size_depth = 0
global_instruction_type_stack = ['modp'] global_instruction_type_stack = ['modp']
def set_global_vector_size(size): def set_global_vector_size(size):
global global_vector_size, global_vector_size_depth stack = global_vector_size_stack
if size == 1: if size == 1 and not stack:
return return
if global_vector_size == 1 or global_vector_size == size: stack.append(size)
global_vector_size = size
global_vector_size_depth += 1
else:
raise CompilerError('Cannot set global vector size when already set')
def set_global_instruction_type(t): def set_global_instruction_type(t):
if t == 'modp' or t == 'gf2n': if t == 'modp' or t == 'gf2n':
@@ -198,17 +195,19 @@ def set_global_instruction_type(t):
raise CompilerError('Invalid type %s for setting global instruction type') raise CompilerError('Invalid type %s for setting global instruction type')
def reset_global_vector_size(): def reset_global_vector_size():
global global_vector_size, global_vector_size_depth stack = global_vector_size_stack
if global_vector_size_depth > 0: if global_vector_size_stack:
global_vector_size_depth -= 1 stack.pop()
if global_vector_size_depth == 0:
global_vector_size = 1
def reset_global_instruction_type(): def reset_global_instruction_type():
global_instruction_type_stack.pop() global_instruction_type_stack.pop()
def get_global_vector_size(): def get_global_vector_size():
return global_vector_size stack = global_vector_size_stack
if stack:
return stack[-1]
else:
return 1
def get_global_instruction_type(): def get_global_instruction_type():
return global_instruction_type_stack[-1] return global_instruction_type_stack[-1]
@@ -243,10 +242,17 @@ def vectorize(instruction, global_dict=None):
@functools.wraps(instruction) @functools.wraps(instruction)
def maybe_vectorized_instruction(*args, **kwargs): def maybe_vectorized_instruction(*args, **kwargs):
if global_vector_size == 1: size = get_global_vector_size()
for arg in args:
try:
size = arg.size
break
except:
pass
if size == 1:
return instruction(*args, **kwargs) return instruction(*args, **kwargs)
else: else:
return Vectorized_Instruction(global_vector_size, *args, **kwargs) return Vectorized_Instruction(size, *args, **kwargs)
maybe_vectorized_instruction.vec_ins = Vectorized_Instruction maybe_vectorized_instruction.vec_ins = Vectorized_Instruction
maybe_vectorized_instruction.std_ins = instruction maybe_vectorized_instruction.std_ins = instruction
@@ -287,6 +293,8 @@ def gf2n(instruction):
else: else:
__format.append(__f[0] + 'g' + __f[1:]) __format.append(__f[0] + 'g' + __f[1:])
arg_format[:] = __format arg_format[:] = __format
elif isinstance(arg_format, property):
pass
else: else:
for __f in arg_format.args: for __f in arg_format.args:
reformat(__f) reformat(__f)

View File

@@ -1,4 +1,4 @@
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single
from Compiler.instructions import * from Compiler.instructions import *
from Compiler.util import tuplify,untuplify from Compiler.util import tuplify,untuplify
from Compiler import instructions,instructions_base,comparison,program,util from Compiler import instructions,instructions_base,comparison,program,util
@@ -98,6 +98,7 @@ def print_ln_if(cond, s):
if cond: if cond:
print_ln(s) print_ln(s)
else: else:
s += ' ' * ((len(s) + 3) % 4)
s += '\n' s += '\n'
while s: while s:
cond.print_if(s[:4]) cond.print_if(s[:4])
@@ -816,7 +817,7 @@ def map_reduce_single(n_parallel, n_loops, initializer, reducer, mem_state=None)
n_parallel = n_parallel or 1 n_parallel = n_parallel or 1
if mem_state is None: if mem_state is None:
# default to list of MemValues to allow varying types # default to list of MemValues to allow varying types
mem_state = [type(x).MemValue(x) for x in initializer()] mem_state = [MemValue(x) for x in initializer()]
use_array = False use_array = False
else: else:
# use Arrays for multithread version # use Arrays for multithread version

View File

@@ -68,6 +68,10 @@ class Program(object):
Compiler.instructions.gasm_open_class, \ Compiler.instructions.gasm_open_class, \
Compiler.instructions.muls_class, \ Compiler.instructions.muls_class, \
Compiler.instructions.gmuls_class, \ Compiler.instructions.gmuls_class, \
Compiler.instructions.mulrs_class, \
Compiler.instructions.gmulrs, \
Compiler.instructions.dotprods_class, \
Compiler.instructions.gdotprods, \
Compiler.instructions.asm_input_class, \ Compiler.instructions.asm_input_class, \
Compiler.instructions.gasm_input_class] Compiler.instructions.gasm_input_class]
import Compiler.GC.instructions as gc import Compiler.GC.instructions as gc
@@ -433,6 +437,7 @@ class Tape:
self.alloc_pool = scope.alloc_pool self.alloc_pool = scope.alloc_pool
else: else:
self.alloc_pool = defaultdict(set) self.alloc_pool = defaultdict(set)
self.purged = False
def new_reg(self, reg_type, size=None): def new_reg(self, reg_type, size=None):
return self.parent.new_reg(reg_type, size=size) return self.parent.new_reg(reg_type, size=size)
@@ -468,6 +473,23 @@ class Tape:
offset = self.get_offset(self.exit_block) offset = self.get_offset(self.exit_block)
self.exit_condition.set_relative_jump(offset) self.exit_condition.set_relative_jump(offset)
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
def purge(self):
relevant = lambda inst: inst.add_usage.__func__ is not \
Compiler.instructions_base.Instruction.add_usage.__func__
self.usage_instructions = filter(relevant, self.instructions)
del self.instructions
del self.defined_registers
self.purged = True
def add_usage(self, req_node):
if self.purged:
instructions = self.usage_instructions
else:
instructions = self.instructions
for inst in instructions:
inst.add_usage(req_node)
def __str__(self): def __str__(self):
return self.name return self.name
@@ -507,6 +529,8 @@ class Tape:
self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc'
def purge(self): def purge(self):
for block in self.basicblocks:
block.purge()
self._is_empty = (len(self.basicblocks) == 0) self._is_empty = (len(self.basicblocks) == 0)
del self.reg_values del self.reg_values
del self.basicblocks del self.basicblocks
@@ -772,8 +796,7 @@ class Tape:
def aggregate(self, *args): def aggregate(self, *args):
self.num = Tape.ReqNum() self.num = Tape.ReqNum()
for block in self.blocks: for block in self.blocks:
for inst in block.instructions: block.add_usage(self)
inst.add_usage(self)
res = reduce(lambda x,y: x + y.aggregate(self.name), res = reduce(lambda x,y: x + y.aggregate(self.name),
self.children, self.num) self.children, self.num)
return res return res
@@ -842,7 +865,7 @@ class Tape:
if size is None: if size is None:
size = Compiler.instructions_base.get_global_vector_size() size = Compiler.instructions_base.get_global_vector_size()
self.size = size self.size = size
if i: if i is not None:
self.i = i self.i = i
else: else:
self.i = program.reg_counter[reg_type] self.i = program.reg_counter[reg_type]
@@ -878,23 +901,29 @@ class Tape:
raise CompilerError('Cannot reset size of vector register') raise CompilerError('Cannot reset size of vector register')
def set_vectorbase(self, vectorbase): def set_vectorbase(self, vectorbase):
if self.vectorbase != self: if self.vectorbase is not self:
raise CompilerError('Cannot assign one register' \ raise CompilerError('Cannot assign one register' \
'to several vectors') 'to several vectors')
self.vectorbase = vectorbase self.vectorbase = vectorbase
def _new_by_number(self, i):
return Tape.Register(self.reg_type, self.program, size=1, i=i)
def create_vector_elements(self): def create_vector_elements(self):
if self.vector: if self.vector:
return return
elif self.size == 1: elif self.size == 1:
self.vector = [self] self.vector = [self]
return return
self.vector = [self] self.vector = []
for i in range(1,self.size): for i in range(self.size):
reg = Tape.Register(self.reg_type, self.program, size=1, i=self.i+i) reg = self._new_by_number(self.i + i)
reg.set_vectorbase(self) reg.set_vectorbase(self)
self.vector.append(reg) self.vector.append(reg)
def get_all(self):
return self.vector or [self]
def __getitem__(self, index): def __getitem__(self, index):
if not self.vector: if not self.vector:
self.create_vector_elements() self.create_vector_elements()

File diff suppressed because it is too large Load Diff

View File

@@ -27,11 +27,11 @@ def greater_than(a, b, bits):
else: else:
return a.greater_than(b, bits) return a.greater_than(b, bits)
def pow2(a, bits): def pow2_value(a, bit_length=None, security=None):
if isinstance(a, int): if is_constant_float(a):
return 2**a return 2**a
else: else:
return a.pow2(bits) return a.pow2(bit_length, security)
def mod2m(a, b, bits, signed): def mod2m(a, b, bits, signed):
if isinstance(a, int): if isinstance(a, int):
@@ -95,7 +95,16 @@ def cond_swap(cond, a, b):
def log2(x): def log2(x):
#print 'Compute log2 of', x #print 'Compute log2 of', x
return int(math.ceil(math.log(x, 2))) if is_constant_float(x):
return int(math.ceil(math.log(x, 2)))
else:
return x.log2()
def round_to_int(x):
if is_constant_float(x):
return int(round(x))
else:
return x.round_to_int()
def tree_reduce(function, sequence): def tree_reduce(function, sequence):
sequence = list(sequence) sequence = list(sequence)
@@ -165,3 +174,9 @@ def long_one(x):
except: except:
pass pass
return 1 return 1
def expand(x, size):
try:
return x.expand_to_vector(size)
except AttributeError:
return x

View File

@@ -340,12 +340,14 @@ void MultiEncCommit<FD>::add_ciphertexts(vector<Ciphertext>& ciphertexts,
template class SimpleEncCommitBase<gfp, FFT_Data, bigint>; template class SimpleEncCommitBase<gfp, FFT_Data, bigint>;
template class SimpleEncCommit<gfp, FFT_Data, bigint>; template class SimpleEncCommit<gfp, FFT_Data, bigint>;
template class SimpleEncCommitFactory<FFT_Data>;
template class SummingEncCommit<FFT_Data>; template class SummingEncCommit<FFT_Data>;
template class NonInteractiveProofSimpleEncCommit<FFT_Data>; template class NonInteractiveProofSimpleEncCommit<FFT_Data>;
template class MultiEncCommit<FFT_Data>; template class MultiEncCommit<FFT_Data>;
template class SimpleEncCommitBase<gf2n_short, P2Data, int>; template class SimpleEncCommitBase<gf2n_short, P2Data, int>;
template class SimpleEncCommit<gf2n_short, P2Data, int>; template class SimpleEncCommit<gf2n_short, P2Data, int>;
template class SimpleEncCommitFactory<P2Data>;
template class SummingEncCommit<P2Data>; template class SummingEncCommit<P2Data>;
template class NonInteractiveProofSimpleEncCommit<P2Data>; template class NonInteractiveProofSimpleEncCommit<P2Data>;
template class MultiEncCommit<P2Data>; template class MultiEncCommit<P2Data>;

View File

@@ -14,6 +14,7 @@
#include "Tools/benchmarking.h" #include "Tools/benchmarking.h"
#include "Auth/fake-stuff.hpp" #include "Auth/fake-stuff.hpp"
#include "Processor/Data_Files.hpp"
#include <sstream> #include <sstream>
#include <fstream> #include <fstream>

View File

@@ -13,8 +13,9 @@
#include "Program.hpp" #include "Program.hpp"
#include "Thread.hpp" #include "Thread.hpp"
#include "ThreadMaster.hpp" #include "ThreadMaster.hpp"
#include "Auth/MaliciousRepMC.hpp"
#include "Processor/Replicated.hpp" #include "Processor/Machine.hpp"
#include "Processor/Instruction.hpp"
namespace GC namespace GC
{ {

View File

@@ -8,6 +8,7 @@
#include "Math/Setup.h" #include "Math/Setup.h"
#include "Auth/MaliciousRepMC.hpp" #include "Auth/MaliciousRepMC.hpp"
#include "Processor/Data_Files.hpp"
namespace GC namespace GC
{ {
@@ -61,7 +62,7 @@ void MaliciousRepThread::and_(Processor<MaliciousRepSecret>& processor,
int n_bits = args[i]; int n_bits = args[i];
int left = args[i + 2]; int left = args[i + 2];
int right = args[i + 3]; int right = args[i + 3];
triples.push_back({0}); triples.push_back({{0}});
DataF.get(DATA_TRIPLE, triples.back().data()); DataF.get(DATA_TRIPLE, triples.back().data());
shares.push_back((processor.S[left] - triples.back()[0]).mask(n_bits)); shares.push_back((processor.S[left] - triples.back()[0]).mask(n_bits));
MaliciousRepSecret y_ext; MaliciousRepSecret y_ext;

View File

@@ -13,7 +13,7 @@ union matrix32x8
__m256i whole; __m256i whole;
octet rows[32]; octet rows[32];
matrix32x8(__m256i x = _mm256_setzero_si256()) : whole(x) {} matrix32x8(const __m256i& x = _mm256_setzero_si256()) : whole(x) {}
matrix32x8(square64& input, int x, int y) matrix32x8(square64& input, int x, int y)
{ {
@@ -23,6 +23,7 @@ union matrix32x8
void transpose(square64& output, int x, int y) void transpose(square64& output, int x, int y)
{ {
#ifdef __AVX2__
for (int j = 0; j < 8; j++) for (int j = 0; j < 8; j++)
{ {
int row = _mm256_movemask_epi8(whole); int row = _mm256_movemask_epi8(whole);
@@ -31,6 +32,10 @@ union matrix32x8
// _mm_movemask_epi8 uses most significant bit, hence +7-j // _mm_movemask_epi8 uses most significant bit, hence +7-j
output.halfrows[8*x+7-j][y] = row; output.halfrows[8*x+7-j][y] = row;
} }
#else
(void) output, (void) x, (void) y;
throw runtime_error("need to compile with AVX2 support");
#endif
} }
}; };
@@ -51,8 +56,10 @@ case I: \
HIGHS = _mm256_unpackhi_epi##I(A, B); \ HIGHS = _mm256_unpackhi_epi##I(A, B); \
break; break;
void zip(int chunk_size, __m256i& lows, __m256i& highs, __m256i a, __m256i b) void zip(int chunk_size, __m256i& lows, __m256i& highs,
const __m256i& a, const __m256i& b)
{ {
#ifdef __AVX2__
switch (chunk_size) switch (chunk_size)
{ {
ZIP_CASE(8, lows, highs, a, b); ZIP_CASE(8, lows, highs, a, b);
@@ -67,6 +74,10 @@ void zip(int chunk_size, __m256i& lows, __m256i& highs, __m256i a, __m256i b)
default: default:
throw invalid_argument("not supported"); throw invalid_argument("not supported");
} }
#else
(void) chunk_size, (void) lows, (void) highs, (void) a, (void) b;
throw runtime_error("need to compile with AVX2 support");
#endif
} }
void square64::transpose(int n_rows, int n_cols) void square64::transpose(int n_rows, int n_cols)

49
Machines/Rep.cpp Normal file
View File

@@ -0,0 +1,49 @@
/*
* Rep.cpp
*
*/
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
template<>
Preprocessing<Rep3Share<gfp>>* Preprocessing<Rep3Share<gfp>>::get_live_prep(
SubProcessor<Rep3Share<gfp>>* proc)
{
return new ReplicatedPrep<Rep3Share<gfp>>(proc);
}
template<>
Preprocessing<Rep3Share<gf2n>>* Preprocessing<Rep3Share<gf2n>>::get_live_prep(
SubProcessor<Rep3Share<gf2n>>* proc)
{
return new ReplicatedPrep<Rep3Share<gf2n>>(proc);
}
template<>
Preprocessing<Rep3Share<Integer>>* Preprocessing<Rep3Share<Integer>>::get_live_prep(
SubProcessor<Rep3Share<Integer>>* proc)
{
return new ReplicatedRingPrep<Rep3Share<Integer>>(proc);
}
template<>
Preprocessing<MaliciousRep3Share<gfp>>* Preprocessing<MaliciousRep3Share<gfp>>::get_live_prep(
SubProcessor<MaliciousRep3Share<gfp>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousRep3Share<gfp>>(proc);
}
template<>
Preprocessing<MaliciousRep3Share<gf2n>>* Preprocessing<MaliciousRep3Share<gf2n>>::get_live_prep(
SubProcessor<MaliciousRep3Share<gf2n>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousRep3Share<gf2n>>(proc);
}
template class Machine<Rep3Share<Integer>, Rep3Share<gf2n>>;
template class Machine<Rep3Share<gfp>, Rep3Share<gf2n>>;
template class Machine<MaliciousRep3Share<gfp>, MaliciousRep3Share<gf2n>>;

6
Machines/SPDZ.cpp Normal file
View File

@@ -0,0 +1,6 @@
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
template class Machine<sgfp, Share<gf2n>>;

View File

@@ -3,13 +3,17 @@
* *
*/ */
#include "ShamirMachine.h" #include <Machines/ShamirMachine.h>
#include "Math/ShamirShare.h" #include "Math/ShamirShare.h"
#include "Math/MaliciousShamirShare.h" #include "Math/MaliciousShamirShare.h"
#include "Math/gfp.h" #include "Math/gfp.h"
#include "Math/gf2n.h" #include "Math/gf2n.h"
#include "ReplicatedMachine.hpp" #include "Processor/ReplicatedMachine.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
ShamirMachine* ShamirMachine::singleton = 0; ShamirMachine* ShamirMachine::singleton = 0;
@@ -64,5 +68,38 @@ ShamirMachineSpec<T>::ShamirMachineSpec(int argc, const char** argv) :
ReplicatedMachine<T<gfp>, T<gf2n>>(argc, argv, "shamir", opt, nparties); ReplicatedMachine<T<gfp>, T<gf2n>>(argc, argv, "shamir", opt, nparties);
} }
template<>
Preprocessing<ShamirShare<gfp>>* Preprocessing<ShamirShare<gfp>>::get_live_prep(
SubProcessor<ShamirShare<gfp>>* proc)
{
return new ReplicatedPrep<ShamirShare<gfp>>(proc);
}
template<>
Preprocessing<ShamirShare<gf2n>>* Preprocessing<ShamirShare<gf2n>>::get_live_prep(
SubProcessor<ShamirShare<gf2n>>* proc)
{
return new ReplicatedPrep<ShamirShare<gf2n>>(proc);
}
template<>
Preprocessing<MaliciousShamirShare<gfp>>* Preprocessing<MaliciousShamirShare<gfp>>::get_live_prep(
SubProcessor<MaliciousShamirShare<gfp>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousShamirShare<gfp>>(proc);
}
template<>
Preprocessing<MaliciousShamirShare<gf2n>>* Preprocessing<MaliciousShamirShare<gf2n>>::get_live_prep(
SubProcessor<MaliciousShamirShare<gf2n>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousShamirShare<gf2n>>(proc);
}
template class ShamirMachineSpec<ShamirShare>; template class ShamirMachineSpec<ShamirShare>;
template class ShamirMachineSpec<MaliciousShamirShare>; template class ShamirMachineSpec<MaliciousShamirShare>;
template class Machine<ShamirShare<gfp>, ShamirShare<gf2n>>;
template class Machine<MaliciousShamirShare<gfp>, MaliciousShamirShare<gf2n>>;

View File

@@ -3,8 +3,8 @@
* *
*/ */
#ifndef PROCESSOR_SHAMIRMACHINE_H_ #ifndef MACHINES_SHAMIRMACHINE_H_
#define PROCESSOR_SHAMIRMACHINE_H_ #define MACHINES_SHAMIRMACHINE_H_
#include "Tools/ezOptionParser.h" #include "Tools/ezOptionParser.h"
@@ -31,4 +31,4 @@ public:
ShamirMachineSpec(int argc, const char** argv); ShamirMachineSpec(int argc, const char** argv);
}; };
#endif /* PROCESSOR_SHAMIRMACHINE_H_ */ #endif /* MACHINES_SHAMIRMACHINE_H_ */

View File

@@ -28,22 +28,25 @@ endif
COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH)
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC)
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) Processor/BaseMachine.o Processor/ProcessorBase.o
LIB = libSPDZ.a LIB = libSPDZ.a
LIBHM = libhm.a
LIBSIMPLEOT = SimpleOT/libsimpleot.a LIBSIMPLEOT = SimpleOT/libsimpleot.a
# used for dependency generation # used for dependency generation
OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp))
DEPS := $(OBJS:.o=.d) DEPS := $(OBJS:.o=.d)
all: gen_input online offline externalIO yao replicated shamir all: gen_input online offline externalIO yao replicated shamir
ifeq ($(USE_GF2N_LONG),1) ifeq ($(USE_GF2N_LONG),1)
ifneq ($(OS), Darwin)
all: bmr all: bmr
endif endif
endif
ifeq ($(USE_NTL),1) ifeq ($(USE_NTL),1)
all: overdrive she-offline all: overdrive she-offline
@@ -78,21 +81,40 @@ rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x
replicated: rep-field rep-ring rep-bin replicated: rep-field rep-ring rep-bin
tldr: malicious-rep-field-party.x Setup.x tldr:
-echo ARCH = -march=native >> CONFIG.mine
$(MAKE) malicious-rep-field-party.x Setup.x
ifeq ($(OS), Darwin)
tldr: mac-setup
else
tldr: mpir
endif
shamir: shamir-party.x malicious-shamir-party.x galois-degree.x shamir: shamir-party.x malicious-shamir-party.x galois-degree.x
Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) $(LIBHM): Machines/Rep.o Machines/ShamirMachine.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ Fake-Offline.cpp $(COMMON) $(PROCESSOR) $(LDLIBS) $(AR) -csr $@ $^
static/%.x: %.cpp $(LIBHM) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static-dir:
@ mkdir static 2> /dev/null; true
static-hm: static-dir $(patsubst %.cpp, static/%.x, $(wildcard *ring*.cpp *field*.cpp *shamir*.cpp ))
Fake-Offline.x: Fake-Offline.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) Auth/fake-stuff.hpp Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) Auth/fake-stuff.hpp
$(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS) $(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(LDLIBS)
Server.x: Server.cpp $(COMMON) Server.x: Server.cpp $(COMMON)
$(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS) $(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS)
Player-Online.x: Player-Online.cpp $(COMMON) $(PROCESSOR) Player-Online.x: Player-Online.cpp Machines/SPDZ.o $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) Player-Online.cpp -o Player-Online.x $(COMMON) $(PROCESSOR) $(LDLIBS) $(CXX) $(CFLAGS) -o Player-Online.x $^ $(LDLIBS)
Setup.x: Setup.cpp $(COMMON) Setup.x: Setup.cpp $(COMMON)
$(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS) $(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS)
@@ -134,13 +156,13 @@ endif
bmr-clean: bmr-clean:
-rm BMR/*.o BMR/*/*.o GC/*.o -rm BMR/*.o BMR/*/*.o GC/*.o
client-setup.x: client-setup.cpp $(COMMON) $(PROCESSOR) client-setup.x: client-setup.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(PROCESSOR) bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) $(PROCESSOR) bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
ifeq ($(USE_NTL),1) ifeq ($(USE_NTL),1)
@@ -172,19 +194,19 @@ replicated-bin-party.x: $(COMMON) $(GC) replicated-bin-party.cpp
malicious-rep-bin-party.x: $(COMMON) $(GC) malicious-rep-bin-party.cpp malicious-rep-bin-party.x: $(COMMON) $(GC) malicious-rep-bin-party.cpp
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
replicated-ring-party.x: replicated-ring-party.cpp $(PROCESSOR) $(COMMON) replicated-ring-party.x: replicated-ring-party.cpp Machines/Rep.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
replicated-field-party.x: replicated-field-party.cpp $(PROCESSOR) $(COMMON) replicated-field-party.x: replicated-field-party.cpp Machines/Rep.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
malicious-rep-field-party.x: malicious-rep-field-party.cpp $(PROCESSOR) $(COMMON) malicious-rep-field-party.x: malicious-rep-field-party.cpp Machines/Rep.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
shamir-party.x: shamir-party.cpp $(PROCESSOR) $(COMMON) shamir-party.x: shamir-party.cpp Machines/ShamirMachine.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
malicious-shamir-party.x: malicious-shamir-party.cpp $(PROCESSOR) $(COMMON) malicious-shamir-party.x: malicious-shamir-party.cpp Machines/ShamirMachine.o $(PROCESSOR) $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
$(LIBSIMPLEOT): SimpleOT/Makefile $(LIBSIMPLEOT): SimpleOT/Makefile
@@ -196,15 +218,31 @@ OT/BaseOT.o: SimpleOT/Makefile
SimpleOT/Makefile: SimpleOT/Makefile:
git submodule update --init SimpleOT git submodule update --init SimpleOT
.PHONY: mpir .PHONY: mpir-setup mpir-global mpir
mpir: mpir-setup:
git submodule update --init mpir git submodule update --init mpir
cd mpir; \ cd mpir; \
autoreconf -i; \ autoreconf -i; \
autoreconf -i; \ autoreconf -i
- $(MAKE) -C mpir clean
mpir-global: mpir-setup
cd mpir; \
./configure --enable-cxx; ./configure --enable-cxx;
$(MAKE) -C mpir $(MAKE) -C mpir
sudo $(MAKE) -C mpir install sudo $(MAKE) -C mpir install
mpir: mpir-setup
cd mpir; \
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C mpir install
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,./local/lib -L./local/lib >> CONFIG.mine
mac-setup:
brew install openssl boost libsodium mpir yasm
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib >> CONFIG.mine
clean: clean:
-rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o -rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x

View File

@@ -17,6 +17,8 @@ public:
static char type_char() { return 'B'; } static char type_char() { return 'B'; }
static DataFieldType field_type() { return DATA_GF2; } static DataFieldType field_type() { return DATA_GF2; }
static bool allows(Dtype dtype) { return dtype == DATA_TRIPLE or dtype == DATA_BIT; }
BitVec() {} BitVec() {}
BitVec(long a) : IntBase(a) {} BitVec(long a) : IntBase(a) {}
BitVec(const IntBase& a) : IntBase(a) {} BitVec(const IntBase& a) : IntBase(a) {}

View File

@@ -27,6 +27,8 @@ public:
static void init_default(int lgp) { (void)lgp; } static void init_default(int lgp) { (void)lgp; }
static bool allows(Dtype type) { return type <= DATA_BIT; }
IntBase() { a = 0; } IntBase() { a = 0; }
IntBase(long a) : a(a) {} IntBase(long a) : a(a) {}
@@ -79,7 +81,7 @@ class Integer : public IntBase
typedef Integer clear; typedef Integer clear;
static char type_char() { return 'R'; } static char type_char() { return 'R'; }
static DataFieldType field_type() { return DATA_INT64; } static DataFieldType field_type() { return DATA_INT; }
static void reqbl(int n); static void reqbl(int n);

View File

@@ -83,6 +83,11 @@ public:
*this = Rep3Share(aa, my_num) - S; *this = Rep3Share(aa, my_num) - S;
} }
clear local_mul(const Rep3Share& other) const
{
return (*this)[0] * other.sum() + (*this)[1] * other[0];
}
void mul_by_bit(const Rep3Share& x, const T& y) void mul_by_bit(const Rep3Share& x, const T& y)
{ {
(void) x, (void) y; (void) x, (void) y;

View File

@@ -196,6 +196,7 @@ inline void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) c
} }
} }
#ifdef __BMI2__
template <int T> template <int T>
inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
{ {
@@ -219,17 +220,20 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
else else
{ inline_mpn_copyi(z,ans+T,T); } { inline_mpn_copyi(z,ans+T,T); }
} }
#endif
inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
{ {
switch (t) switch (t)
{ {
#ifdef __BMI2__
case 2: case 2:
Mont_Mult_<2>(z, x, y); Mont_Mult_<2>(z, x, y);
break; break;
case 1: case 1:
Mont_Mult_<1>(z, x, y); Mont_Mult_<1>(z, x, y);
break; break;
#endif
default: default:
Mont_Mult_variable(z, x, y); Mont_Mult_variable(z, x, y);
break; break;

View File

@@ -7,7 +7,8 @@
#define MATH_FIELD_TYPES_H_ #define MATH_FIELD_TYPES_H_
enum DataFieldType { DATA_MODP, DATA_GF2N, DATA_GF2, DATA_INT64, N_DATA_FIELD_TYPE }; enum DataFieldType { DATA_INT, DATA_GF2N, DATA_GF2, N_DATA_FIELD_TYPE };
enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE };
#endif /* MATH_FIELD_TYPES_H_ */ #endif /* MATH_FIELD_TYPES_H_ */

View File

@@ -91,7 +91,11 @@ void gf2n_short::init_field(int nn)
mask=(1ULL<<n)-1; mask=(1ULL<<n)-1;
#ifdef __PCLMUL__
useC=(Check_CPU_support_AES()==0); useC=(Check_CPU_support_AES()==0);
#else
useC = true;
#endif
} }
@@ -222,6 +226,7 @@ void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y)
} }
else else
{ /* Use Intel Instructions */ { /* Use Intel Instructions */
#ifdef __PCLMUL__
__m128i xx,yy,zz; __m128i xx,yy,zz;
uint64_t c[] __attribute__((aligned (16))) = { 0,0 }; uint64_t c[] __attribute__((aligned (16))) = { 0,0 };
xx=_mm_set1_epi64x(x.a); xx=_mm_set1_epi64x(x.a);
@@ -230,6 +235,9 @@ void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y)
_mm_store_si128((__m128i*)c,zz); _mm_store_si128((__m128i*)c,zz);
lo=c[0]; lo=c[0];
hi=c[1]; hi=c[1];
#else
throw runtime_error("need to compile with PCLMUL support");
#endif
} }
reduce(hi,lo); reduce(hi,lo);

View File

@@ -80,6 +80,8 @@ class gf2n_short
static int default_length() { return 40; } static int default_length() { return 40; }
static bool allows(Dtype type) { (void) type; return true; }
word get() const { return a; } word get() const { return a; }
word get_word() const { return a; } word get_word() const { return a; }

View File

@@ -31,9 +31,13 @@ public:
int128(const word& upper, const word& lower) : a(_mm_set_epi64x(upper, lower)) { } int128(const word& upper, const word& lower) : a(_mm_set_epi64x(upper, lower)) { }
word get_lower() const { return (word)_mm_cvtsi128_si64(a); } word get_lower() const { return (word)_mm_cvtsi128_si64(a); }
word get_upper() const { return _mm_extract_epi64(a, 1); } word get_upper() const { return _mm_cvtsi128_si64(_mm_unpackhi_epi64(a, a)); }
#ifdef __SSE41__
bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); } bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); }
#else
bool operator==(const int128& other) const { return get_lower() == other.get_lower() and get_upper() == other.get_upper(); }
#endif
bool operator!=(const int128& other) const { return !(*this == other); } bool operator!=(const int128& other) const { return !(*this == other); }
int128 operator<<(const int& other) const; int128 operator<<(const int& other) const;
@@ -125,6 +129,8 @@ class gf2n_long
static int default_length() { return 128; } static int default_length() { return 128; }
static bool allows(Dtype type) { (void) type; return true; }
int128 get() const { return a; } int128 get() const { return a; }
__m128i to_m128i() const { return a.a; } __m128i to_m128i() const { return a.a; }
word get_word() const { return _mm_cvtsi128_si64(a.a); } word get_word() const { return _mm_cvtsi128_si64(a.a); }
@@ -278,10 +284,15 @@ inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2)
{ {
__m128i tmp3, tmp4, tmp5, tmp6; __m128i tmp3, tmp4, tmp5, tmp6;
#ifdef __PCLMUL__
tmp3 = _mm_clmulepi64_si128(a, b, 0x00); tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
tmp4 = _mm_clmulepi64_si128(a, b, 0x10); tmp4 = _mm_clmulepi64_si128(a, b, 0x10);
tmp5 = _mm_clmulepi64_si128(a, b, 0x01); tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
tmp6 = _mm_clmulepi64_si128(a, b, 0x11); tmp6 = _mm_clmulepi64_si128(a, b, 0x11);
#else
(void) a, (void) b;
throw runtime_error("need to compile with PCLMUL support");
#endif
tmp4 = _mm_xor_si128(tmp4, tmp5); tmp4 = _mm_xor_si128(tmp4, tmp5);
tmp5 = _mm_slli_si128(tmp4, 8); tmp5 = _mm_slli_si128(tmp4, 8);

View File

@@ -151,6 +151,18 @@ void gfp::reqbl(int n)
} }
} }
bool gfp::allows(Dtype type)
{
switch(type)
{
case DATA_BITGF2NTRIPLE:
case DATA_BITTRIPLE:
return false;
default:
return true;
}
}
void to_signed_bigint(bigint& ans, const gfp& x) void to_signed_bigint(bigint& ans, const gfp& x)
{ {
to_bigint(ans, x); to_bigint(ans, x);

View File

@@ -47,7 +47,7 @@ class gfp
static Zp_Data& get_ZpD() static Zp_Data& get_ZpD()
{ return ZpD; } { return ZpD; }
static DataFieldType field_type() { return DATA_MODP; } static DataFieldType field_type() { return DATA_INT; }
static char type_char() { return 'p'; } static char type_char() { return 'p'; }
static string type_string() { return "gfp"; } static string type_string() { return "gfp"; }
@@ -55,6 +55,8 @@ class gfp
static void reqbl(int n); static void reqbl(int n);
static bool allows(Dtype type);
void assign(const gfp& g) { a=g.a; } void assign(const gfp& g) { a=g.a; }
void assign_zero() { assignZero(a,ZpD); } void assign_zero() { assignZero(a,ZpD); }
void assign_one() { assignOne(a,ZpD); } void assign_one() { assignOne(a,ZpD); }
@@ -108,7 +110,7 @@ class gfp
bool is_zero() const { return isZero(a,ZpD); } bool is_zero() const { return isZero(a,ZpD); }
bool is_one() const { return isOne(a,ZpD); } bool is_one() const { return isOne(a,ZpD); }
bool is_bit() const { return is_zero() or is_one(); } bool is_bit() const { return is_zero() or is_one(); }
bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); }
bool operator==(const gfp& y) const { return equal(y); } bool operator==(const gfp& y) const { return equal(y); }

View File

@@ -192,6 +192,7 @@ inline void mpn_add_n_use_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_lim
} }
} }
#ifdef __BMI2__
template <int L, int M, bool ADD> template <int L, int M, bool ADD>
inline void mpn_addmul_1_fixed__(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x) inline void mpn_addmul_1_fixed__(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x)
{ {
@@ -267,5 +268,6 @@ inline void mpn_mul_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y
{ {
mpn_mul_fixed_<N + M, N, M>(res, x, y); mpn_mul_fixed_<N + M, N, M>(res, x, y);
} }
#endif
#endif /* MATH_MPN_FIXED_H_ */ #endif /* MATH_MPN_FIXED_H_ */

View File

@@ -8,8 +8,8 @@
template <class T> template <class T>
T operator*(const bool& x, const T& y) { return x ? y : T(); } T operator*(const bool& x, const T& y) { return x ? y : T(); }
template <class T> //template <class T>
T operator*(const T& y, const bool& x) { return x ? y : T(); } //T operator*(const T& y, const bool& x) { return x ? y : T(); }
template <class T> template <class T>
T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; } T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; }

View File

@@ -76,6 +76,7 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
CryptoPlayer::~CryptoPlayer() CryptoPlayer::~CryptoPlayer()
{ {
close_client_socket(plaintext_player.socket(my_num()));
plaintext_player.sockets.clear(); plaintext_player.sockets.clear();
for (int i = 0; i < num_players(); i++) for (int i = 0; i < num_players(); i++)
delete sockets[i]; delete sockets[i];

View File

@@ -188,10 +188,15 @@ template<>
MultiPlayer<int>::~MultiPlayer() MultiPlayer<int>::~MultiPlayer()
{ {
/* Close down the sockets */ /* Close down the sockets */
for (int i=0; i<nplayers; i++) for (auto socket : sockets)
close_client_socket(sockets[i]); close_client_socket(socket);
close_client_socket(send_to_self_socket);
} }
template<class T>
MultiPlayer<T>::~MultiPlayer()
{
}
Player::~Player() Player::~Player()
{ {
@@ -336,6 +341,11 @@ void MultiPlayer<T>::exchange(int other, octetStream& o) const
sent += o.get_length(); sent += o.get_length();
} }
void Player::exchange_relative(int offset, octetStream& o) const
{
exchange(get_player(offset), o);
}
template<class T> template<class T>
void MultiPlayer<T>::pass_around(octetStream& o, int offset) const void MultiPlayer<T>::pass_around(octetStream& o, int offset) const

View File

@@ -169,6 +169,10 @@ public:
void receive_relative(vector<octetStream>& o) const; void receive_relative(vector<octetStream>& o) const;
void receive_relative(int offset, octetStream& o) const; void receive_relative(int offset, octetStream& o) const;
// exchange data with minimal memory usage
void exchange(int other, octetStream& o) const = 0;
void exchange_relative(int offset, octetStream& o) const;
/* Broadcast and Receive data to/from all players /* Broadcast and Receive data to/from all players
* - Assumes o[player_no] contains the thing broadcast by me * - Assumes o[player_no] contains the thing broadcast by me
*/ */
@@ -204,7 +208,7 @@ public:
// portnum bases in each thread // portnum bases in each thread
MultiPlayer(const Names& Nms,int id_base=0); MultiPlayer(const Names& Nms,int id_base=0);
virtual ~MultiPlayer() {} virtual ~MultiPlayer();
T socket(int i) const { return sockets[i]; } T socket(int i) const { return sockets[i]; }

View File

@@ -151,7 +151,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
void close_client_socket(int socket) void close_client_socket(int socket)
{ {
if (close(socket) < 0 and errno != EBADF) if (close(socket))
{ {
char tmp[1000]; char tmp[1000];
sprintf(tmp, "close(%d)", socket); sprintf(tmp, "close(%d)", socket);

View File

@@ -56,16 +56,27 @@ int get_ack(int socket);
extern unsigned long long sent_amount, sent_counter; extern unsigned long long sent_amount, sent_counter;
inline size_t send_non_blocking(int socket, octet* msg, size_t len)
{
int j = send(socket,msg,len,0);
if (j < 0)
{
if (errno != EINTR)
{ error("Send error - 1 "); }
else
return 0;
}
return j;
}
template<> template<>
inline void send(int socket,octet *msg,size_t len) inline void send(int socket,octet *msg,size_t len)
{ {
size_t i = 0; size_t i = 0;
while (i < len) while (i < len)
{ {
int j = send(socket,msg+i,len-i,0); i += send_non_blocking(socket, msg + i, len - i);
i += j;
if (j < 0 and errno != EINTR)
{ error("Send error - 1 "); }
} }
sent_amount += len; sent_amount += len;

View File

@@ -14,12 +14,17 @@
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket; typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket;
inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length)
{
return socket->write_some(boost::asio::buffer(data, length));
}
template<> template<>
inline void send(ssl_socket* socket, octet* data, size_t length) inline void send(ssl_socket* socket, octet* data, size_t length)
{ {
size_t sent = 0; size_t sent = 0;
while (sent < length) while (sent < length)
sent += socket->write_some(boost::asio::buffer(data + sent, length - sent)); sent += send_non_blocking(socket, data + sent, length - sent);
} }
template<> template<>

View File

@@ -384,8 +384,7 @@ bool square128::operator==(square128& other)
{ {
for (int i = 0; i < 128; i++) for (int i = 0; i < 128; i++)
{ {
__m128i tmp = rows[i] ^ other.rows[i]; if (int128(rows[i]) != other.rows[i])
if (not _mm_test_all_zeros(tmp, tmp))
return false; return false;
} }
return true; return true;

View File

@@ -137,7 +137,7 @@ class BitVector
memcpy(bytes + offset, (octet*)&w, sizeof(word)); memcpy(bytes + offset, (octet*)&w, sizeof(word));
} }
int128 get_int128(int i) const { return _mm_lddqu_si128((__m128i*)bytes + i); } int128 get_int128(int i) const { return _mm_loadu_si128((__m128i*)bytes + i); }
void set_int128(int i, int128 a) { *((__m128i*)bytes + i) = a.a; } void set_int128(int i, int128 a) { *((__m128i*)bytes + i) = a.a; }
int get_bit(int i) const int get_bit(int i) const

View File

@@ -143,6 +143,15 @@ int main(int argc, const char** argv)
"-N", // Flag token. "-N", // Flag token.
"--nparties" // Flag token. "--nparties" // Flag token.
); );
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use encrypted channels.", // Help description.
"-e", // Flag token.
"--encrypted" // Flag token.
);
opt.resetArgs(); opt.resetArgs();
opt.parse(argc, argv); opt.parse(argc, argv);
@@ -250,7 +259,8 @@ int main(int argc, const char** argv)
{ {
Machine<sgfp, Share<gf2n>>(playerno, playerNames, progname, memtype, lg2, Machine<sgfp, Share<gf2n>>(playerno, playerNames, progname, memtype, lg2,
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
opt.get("--threads")->isSet, max_broadcast, false, false, opt.get("--threads")->isSet, max_broadcast,
opt.get("--encrypted")->isSet, false,
online_opts).run(); online_opts).run();
if (server) if (server)

94
Processor/BaseMachine.cpp Normal file
View File

@@ -0,0 +1,94 @@
/*
* BaseMachine.cpp
*
*/
#include "BaseMachine.h"
#include <iostream>
using namespace std;
BaseMachine* BaseMachine::singleton = 0;
BaseMachine& BaseMachine::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
BaseMachine::BaseMachine() : nthreads(0)
{
if (not singleton)
singleton = this;
}
void BaseMachine::load_schedule(string progname)
{
this->progname = progname;
string fname = "Programs/Schedules/" + progname + ".sch";
cerr << "Opening file " << fname << endl;
inpf.open(fname);
if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); }
int nprogs;
inpf >> nthreads;
inpf >> nprogs;
cerr << "Number of threads I will run in parallel = " << nthreads << endl;
cerr << "Number of program sequences I need to load = " << nprogs << endl;
// Load in the programs
string threadname;
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
string filename = "Programs/Bytecode/" + threadname + ".bc";
cerr << "Loading program " << i << " from " << filename << endl;
load_program(threadname, filename);
}
}
void BaseMachine::print_compiler()
{
char compiler[1000];
inpf.get();
inpf.getline(compiler, 1000);
if (compiler[0] != 0)
cerr << "Compiler: " << compiler << endl;
inpf.close();
}
void BaseMachine::load_program(string threadname, string filename)
{
(void)threadname;
(void)filename;
throw not_implemented();
}
void BaseMachine::time()
{
cout << "Elapsed time: " << timer[0].elapsed() << endl;
}
void BaseMachine::start(int n)
{
cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " after " << timer[n].idle() << endl;
timer[n].start();
}
void BaseMachine::stop(int n)
{
timer[n].stop();
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl;
}
void BaseMachine::print_timers()
{
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
timer.erase(0);
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
}

View File

@@ -7,22 +7,34 @@
#define PROCESSOR_BEAVER_H_ #define PROCESSOR_BEAVER_H_
#include <vector> #include <vector>
#include <array>
using namespace std; using namespace std;
#include "Replicated.h"
template<class T> class SubProcessor; template<class T> class SubProcessor;
template<class T> class MAC_Check_Base; template<class T> class MAC_Check_Base;
class Player; class Player;
template<class T> template<class T>
class Beaver class Beaver : public ProtocolBase<T>
{ {
vector<T> shares;
vector<typename T::clear> opened;
vector<array<T, 3>> triples;
typename vector<typename T::clear>::iterator it;
typename vector<array<T, 3>>::iterator triple;
SubProcessor<T>* proc;
public: public:
Player& P; Player& P;
static void muls(const vector<int>& reg, SubProcessor<T>& proc, Beaver(Player& P) : proc(0), P(P) {}
MAC_Check_Base<T>& MC, int size);
Beaver(Player& P) : P(P) {} void init_mul(SubProcessor<T>* proc);
typename T::clear prepare_mul(const T& x, const T& y);
void exchange();
T finalize_mul();
}; };
#endif /* PROCESSOR_BEAVER_H_ */ #endif /* PROCESSOR_BEAVER_H_ */

View File

@@ -7,44 +7,46 @@
#include <array> #include <array>
template<class T> template<class T>
void Beaver<T>::muls(const vector<int>& reg, SubProcessor<T>& proc, MAC_Check_Base<T>& MC, void Beaver<T>::init_mul(SubProcessor<T>* proc)
int size)
{ {
assert(reg.size() % 3 == 0); this->proc = proc;
int n = reg.size() / 3;
vector<T>& shares = proc.Sh_PO;
vector<typename T::clear>& opened = proc.PO;
shares.clear(); shares.clear();
vector<array<T, 3>> triples(n * size); opened.clear();
auto triple = triples.begin(); triples.clear();
}
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++) template<class T>
{ typename T::clear Beaver<T>::prepare_mul(const T& x, const T& y)
proc.DataF.get(DATA_TRIPLE, triple->data()); {
for (int k = 0; k < 2; k++) triples.push_back({{}});
shares.push_back(proc.S[reg[i * 3 + k + 1] + j] - (*triple)[k]); auto& triple = triples.back();
triple++; proc->DataF.get(DATA_TRIPLE, triple.data());
} shares.push_back(x - triple[0]);
shares.push_back(y - triple[1]);
MC.POpen_Begin(opened, shares, proc.P); return 0;
MC.POpen_End(opened, shares, proc.P); }
auto it = opened.begin();
triple = triples.begin(); template<class T>
void Beaver<T>::exchange()
for (int i = 0; i < n; i++) {
for (int j = 0; j < size; j++) proc->MC.POpen(opened, shares, P);
{ it = opened.begin();
typename T::clear masked[2]; triple = triples.begin();
T& tmp = (*triple)[2]; }
for (int k = 0; k < 2; k++)
{ template<class T>
masked[k] = *it++; T Beaver<T>::finalize_mul()
tmp += (masked[k] * (*triple)[1 - k]); {
} typename T::clear masked[2];
tmp.add(tmp, masked[0] * masked[1], proc.P.my_num(), MC.get_alphai()); T& tmp = (*triple)[2];
proc.S[reg[i * 3] + j] = tmp; for (int k = 0; k < 2; k++)
triple++; {
} masked[k] = *it++;
tmp += (masked[k] * (*triple)[1 - k]);
}
tmp.add(tmp, masked[0] * masked[1], P.my_num(), proc->MC.get_alphai());
triple++;
return tmp;
} }

View File

@@ -9,7 +9,7 @@
#include "Math/gf2n.h" #include "Math/gf2n.h"
#include "Math/Share.h" #include "Math/Share.h"
#include "Math/field_types.h" #include "Math/field_types.h"
#include "Processor/Buffer.h" #include "Tools/Buffer.h"
#include "Processor/InputTuple.h" #include "Processor/InputTuple.h"
#include "Tools/Lock.h" #include "Tools/Lock.h"
#include "Networking/Player.h" #include "Networking/Player.h"
@@ -18,8 +18,6 @@
#include <map> #include <map>
using namespace std; using namespace std;
enum Dtype { DATA_TRIPLE, DATA_SQUARE, DATA_BIT, DATA_INVERSE, DATA_BITTRIPLE, DATA_BITGF2NTRIPLE, N_DTYPE };
class DataTag class DataTag
{ {
int t[4]; int t[4];
@@ -95,8 +93,6 @@ class Sub_Data_Files : public Preprocessing<T>
{ {
template<class U> friend class Sub_Data_Files; template<class U> friend class Sub_Data_Files;
static const bool implemented[N_DTYPE];
static map<DataTag, int> tuple_lengths; static map<DataTag, int> tuple_lengths;
static Lock tuple_lengths_lock; static Lock tuple_lengths_lock;

View File

@@ -9,77 +9,17 @@
#include "Math/MaliciousShamirShare.h" #include "Math/MaliciousShamirShare.h"
#include "Processor/MaliciousRepPrep.hpp" #include "Processor/MaliciousRepPrep.hpp"
#include "Processor/Replicated.hpp" //#include "Processor/Replicated.hpp"
#include "Processor/ReplicatedPrep.hpp" #include "Processor/ReplicatedPrep.hpp"
#include "Processor/Input.hpp" //#include "Processor/Input.hpp"
#include "Processor/ReplicatedInput.hpp" //#include "Processor/ReplicatedInput.hpp"
#include "Processor/Shamir.hpp" //#include "Processor/Shamir.hpp"
#include "Auth/MaliciousShamirMC.hpp" //#include "Auth/MaliciousShamirMC.hpp"
#include <iomanip> #include <iomanip>
#include <numeric> #include <numeric>
const char* DataPositions::field_names[] = { "gfp", "gf2n", "bit", "int64" }; const char* DataPositions::field_names[] = { "int", "gf2n", "bit" };
template<>
const bool Sub_Data_Files<sgfp>::implemented[N_DTYPE] =
{ true, true, true, true, false, false }
;
template<>
const bool Sub_Data_Files<Share<gf2n>>::implemented[N_DTYPE] =
{ true, true, true, true, true, true }
;
template<>
const bool Sub_Data_Files<Rep3Share<Integer>>::implemented[N_DTYPE] =
{ false, false, true, false, false, false }
;
template<>
const bool Sub_Data_Files<Rep3Share<gfp>>::implemented[N_DTYPE] =
{ true, true, true, true, false, false }
;
template<>
const bool Sub_Data_Files<Rep3Share<gf2n>>::implemented[N_DTYPE] =
{ true, true, true, true, false, false }
;
template<>
const bool Sub_Data_Files<MaliciousRep3Share<gfp>>::implemented[N_DTYPE] =
{ true, true, true, true, false, false }
;
template<>
const bool Sub_Data_Files<MaliciousRep3Share<gf2n>>::implemented[N_DTYPE] =
{ true, true, true, true, false, false }
;
template<>
const bool Sub_Data_Files<GC::MaliciousRepSecret>::implemented[N_DTYPE] =
{ true, false, true, false, false, false }
;
template<>
const bool Sub_Data_Files<ShamirShare<gfp>>::implemented[N_DTYPE] =
{ false, false, false, false, false, false }
;
template<>
const bool Sub_Data_Files<ShamirShare<gf2n>>::implemented[N_DTYPE] =
{ false, false, false, false, false, false }
;
template<>
const bool Sub_Data_Files<MaliciousShamirShare<gfp>>::implemented[N_DTYPE] =
{ false, false, false, false, false, false }
;
template<>
const bool Sub_Data_Files<MaliciousShamirShare<gf2n>>::implemented[N_DTYPE] =
{ false, false, false, false, false, false }
;
const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 }; const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 };
@@ -88,73 +28,6 @@ Lock Sub_Data_Files<T>::tuple_lengths_lock;
template<class T> template<class T>
map<DataTag, int> Sub_Data_Files<T>::tuple_lengths; map<DataTag, int> Sub_Data_Files<T>::tuple_lengths;
template<>
Preprocessing<Rep3Share<gfp>>* Preprocessing<Rep3Share<gfp>>::get_live_prep(
SubProcessor<Rep3Share<gfp>>* proc)
{
return new ReplicatedPrep<Rep3Share<gfp>>(proc);
}
template<>
Preprocessing<Rep3Share<gf2n>>* Preprocessing<Rep3Share<gf2n>>::get_live_prep(
SubProcessor<Rep3Share<gf2n>>* proc)
{
return new ReplicatedPrep<Rep3Share<gf2n>>(proc);
}
template<>
Preprocessing<Rep3Share<Integer>>* Preprocessing<Rep3Share<Integer>>::get_live_prep(
SubProcessor<Rep3Share<Integer>>* proc)
{
return new ReplicatedRingPrep<Rep3Share<Integer>>(proc);
}
template<>
Preprocessing<MaliciousRep3Share<gfp>>* Preprocessing<MaliciousRep3Share<gfp>>::get_live_prep(
SubProcessor<MaliciousRep3Share<gfp>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousRep3Share<gfp>>(proc);
}
template<>
Preprocessing<MaliciousRep3Share<gf2n>>* Preprocessing<MaliciousRep3Share<gf2n>>::get_live_prep(
SubProcessor<MaliciousRep3Share<gf2n>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousRep3Share<gf2n>>(proc);
}
template<>
Preprocessing<ShamirShare<gfp>>* Preprocessing<ShamirShare<gfp>>::get_live_prep(
SubProcessor<ShamirShare<gfp>>* proc)
{
return new ReplicatedPrep<ShamirShare<gfp>>(proc);
}
template<>
Preprocessing<ShamirShare<gf2n>>* Preprocessing<ShamirShare<gf2n>>::get_live_prep(
SubProcessor<ShamirShare<gf2n>>* proc)
{
return new ReplicatedPrep<ShamirShare<gf2n>>(proc);
}
template<>
Preprocessing<MaliciousShamirShare<gfp>>* Preprocessing<MaliciousShamirShare<gfp>>::get_live_prep(
SubProcessor<MaliciousShamirShare<gfp>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousShamirShare<gfp>>(proc);
}
template<>
Preprocessing<MaliciousShamirShare<gf2n>>* Preprocessing<MaliciousShamirShare<gf2n>>::get_live_prep(
SubProcessor<MaliciousShamirShare<gf2n>>* proc)
{
(void) proc;
return new MaliciousRepPrep<MaliciousShamirShare<gf2n>>(proc);
}
template<class T> template<class T>
Preprocessing<T>* Preprocessing<T>::get_live_prep(SubProcessor<T>* proc) Preprocessing<T>* Preprocessing<T>::get_live_prep(SubProcessor<T>* proc)
{ {
@@ -270,7 +143,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
string suffix = get_suffix(thread_num); string suffix = get_suffix(thread_num);
for (int dtype = 0; dtype < N_DTYPE; dtype++) for (int dtype = 0; dtype < N_DTYPE; dtype++)
{ {
if (implemented[dtype]) if (T::clear::allows(Dtype(dtype)))
{ {
sprintf(filename,(prep_data_dir + "%s-%s-P%d%s").c_str(),DataPositions::dtype_names[dtype], sprintf(filename,(prep_data_dir + "%s-%s-P%d%s").c_str(),DataPositions::dtype_names[dtype],
(T::type_short()).c_str(),my_num,suffix.c_str()); (T::type_short()).c_str(),my_num,suffix.c_str());
@@ -329,7 +202,7 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
{ {
DataFieldType field_type = T::field_type(); DataFieldType field_type = T::field_type();
for (int dtype = 0; dtype < N_DTYPE; dtype++) for (int dtype = 0; dtype < N_DTYPE; dtype++)
if (implemented[dtype]) if (T::clear::allows(Dtype(dtype)))
buffers[dtype].seekg(pos.files[field_type][dtype]); buffers[dtype].seekg(pos.files[field_type][dtype]);
for (int j = 0; j < num_players; j++) for (int j = 0; j < num_players; j++)
if (j == my_num) if (j == my_num)
@@ -437,19 +310,3 @@ void Sub_Data_Files<T>::get(vector<T>& S, DataTag tag, const vector<int>& regs,
for (unsigned int i = 0; i < regs.size(); i++) for (unsigned int i = 0; i < regs.size(); i++)
extended[tag].input(S[regs[i] + j]); extended[tag].input(S[regs[i] + j]);
} }
template class Sub_Data_Files<Share<gf2n>>;
template class Sub_Data_Files<sgfp>;
template class Sub_Data_Files<Rep3Share<Integer>>;
template class Sub_Data_Files<Rep3Share<gfp>>;
template class Sub_Data_Files<Rep3Share<gf2n>>;
template class Sub_Data_Files<GC::MaliciousRepSecret>;
template class Sub_Data_Files<MaliciousRep3Share<gfp>>;
template class Sub_Data_Files<MaliciousRep3Share<gf2n>>;
template class Data_Files<sgfp, Share<gf2n>>;
template class Data_Files<Rep3Share<Integer>, Rep3Share<gf2n>>;
template class Data_Files<Rep3Share<gfp>, Rep3Share<gf2n>>;
template class Data_Files<MaliciousRep3Share<gfp>, MaliciousRep3Share<gf2n>>;
template class Data_Files<ShamirShare<gfp>, ShamirShare<gf2n>>;
template class Data_Files<MaliciousShamirShare<gfp>, MaliciousShamirShare<gf2n>>;

View File

@@ -10,7 +10,7 @@
using namespace std; using namespace std;
#include "Math/Share.h" #include "Math/Share.h"
#include "Processor/Buffer.h" #include "Tools/Buffer.h"
#include "Tools/time-func.h" #include "Tools/time-func.h"
class ArithmeticProcessor; class ArithmeticProcessor;

View File

@@ -92,6 +92,8 @@ enum
// Open // Open
OPEN = 0xA5, OPEN = 0xA5,
MULS = 0xA6, MULS = 0xA6,
MULRS = 0xA7,
DOTPRODS = 0xA8,
// Data access // Data access
TRIPLE = 0x50, TRIPLE = 0x50,
BIT = 0x51, BIT = 0x51,
@@ -214,6 +216,8 @@ enum
// Open // Open
GOPEN = 0x1A5, GOPEN = 0x1A5,
GMULS = 0x1A6, GMULS = 0x1A6,
GMULRS = 0x1A7,
GDOTPRODS = 0x1A8,
// Data access // Data access
GTRIPLE = 0x150, GTRIPLE = 0x150,
GBIT = 0x151, GBIT = 0x151,

View File

@@ -11,15 +11,15 @@
#include "Auth/ShamirMC.h" #include "Auth/ShamirMC.h"
#include "Math/MaliciousShamirShare.h" #include "Math/MaliciousShamirShare.h"
#include "Processor/Processor.hpp" //#include "Processor/Processor.hpp"
#include "Processor/Binary_File_IO.hpp" #include "Processor/Binary_File_IO.hpp"
#include "Processor/Input.hpp" //#include "Processor/Input.hpp"
#include "Processor/Beaver.hpp" //#include "Processor/Beaver.hpp"
#include "Processor/Shamir.hpp" //#include "Processor/Shamir.hpp"
#include "Processor/ShamirInput.hpp" //#include "Processor/ShamirInput.hpp"
#include "Processor/Replicated.hpp" //#include "Processor/Replicated.hpp"
#include "Auth/MaliciousRepMC.hpp" //#include "Auth/MaliciousRepMC.hpp"
#include "Auth/ShamirMC.hpp" //#include "Auth/ShamirMC.hpp"
#include <stdlib.h> #include <stdlib.h>
#include <algorithm> #include <algorithm>
@@ -33,6 +33,7 @@
#undef DEBUG #undef DEBUG
// Convert modp to signed bigint of a given bit length // Convert modp to signed bigint of a given bit length
inline
void to_signed_bigint(bigint& bi, const gfp& x, int len) void to_signed_bigint(bigint& bi, const gfp& x, int len)
{ {
to_bigint(bi, x); to_bigint(bi, x);
@@ -53,7 +54,7 @@ void to_signed_bigint(bigint& bi, const gfp& x, int len)
bi = -bi; bi = -bi;
} }
inline
void Instruction::parse(istream& s) void Instruction::parse(istream& s)
{ {
n=0; start.resize(0); n=0; start.resize(0);
@@ -70,7 +71,7 @@ void Instruction::parse(istream& s)
parse_operands(s, pos); parse_operands(s, pos);
} }
inline
void BaseInstruction::parse_operands(istream& s, int pos) void BaseInstruction::parse_operands(istream& s, int pos)
{ {
int num_var_args = 0; int num_var_args = 0;
@@ -277,7 +278,7 @@ void BaseInstruction::parse_operands(istream& s, int pos)
case CRASH: case CRASH:
case STARTGRIND: case STARTGRIND:
case STOPGRIND: case STOPGRIND:
break; break;
// instructions with 4 register operands // instructions with 4 register operands
case PRINTFLOATPLAIN: case PRINTFLOATPLAIN:
get_vector(4, start, s); get_vector(4, start, s);
@@ -288,6 +289,10 @@ void BaseInstruction::parse_operands(istream& s, int pos)
case GOPEN: case GOPEN:
case MULS: case MULS:
case GMULS: case GMULS:
case MULRS:
case GMULRS:
case DOTPRODS:
case GDOTPRODS:
case INPUT: case INPUT:
case GINPUT: case GINPUT:
num_var_args = get_int(s); num_var_args = get_int(s);
@@ -385,7 +390,7 @@ void BaseInstruction::parse_operands(istream& s, int pos)
} }
} }
inline
bool Instruction::get_offline_data_usage(DataPositions& usage) bool Instruction::get_offline_data_usage(DataPositions& usage)
{ {
switch (opcode) switch (opcode)
@@ -415,6 +420,7 @@ bool Instruction::get_offline_data_usage(DataPositions& usage)
} }
} }
inline
int BaseInstruction::get_reg_type() const int BaseInstruction::get_reg_type() const
{ {
switch (opcode) { switch (opcode) {
@@ -451,10 +457,27 @@ int BaseInstruction::get_reg_type() const
} }
} }
inline
unsigned BaseInstruction::get_max_reg(int reg_type) const unsigned BaseInstruction::get_max_reg(int reg_type) const
{ {
if (get_reg_type() != reg_type) { return 0; } if (get_reg_type() != reg_type) { return 0; }
switch (opcode)
{
case DOTPRODS:
{
int res = 0;
auto it = start.begin();
while (it != start.end())
{
int n = *it;
res = max(res, *it++);
it += n - 1;
}
return res;
}
}
const int *begin, *end; const int *begin, *end;
if (start.size()) if (start.size())
{ {
@@ -473,6 +496,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
return res + size; return res + size;
} }
inline
unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const
{ {
if (get_reg_type() == reg_type and is_direct_memory_access(sec_type)) if (get_reg_type() == reg_type and is_direct_memory_access(sec_type))
@@ -481,6 +505,7 @@ unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const
return 0; return 0;
} }
inline
bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
{ {
if (sec_type == SECRET) if (sec_type == SECRET)
@@ -514,7 +539,7 @@ bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const
} }
inline
ostream& operator<<(ostream& s,const Instruction& instr) ostream& operator<<(ostream& s,const Instruction& instr)
{ {
s << instr.opcode << " : "; s << instr.opcode << " : ";
@@ -1386,6 +1411,18 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case GMULS: case GMULS:
Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size); Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size);
return; return;
case MULRS:
Proc.Procp.protocol.mulrs(start, Proc.Procp);
return;
case GMULRS:
Proc.Proc2.protocol.mulrs(start, Proc.Proc2);
return;
case DOTPRODS:
Proc.Procp.protocol.dotprods(start, Proc.Procp);
return;
case GDOTPRODS:
Proc.Proc2.protocol.dotprods(start, Proc.Proc2);
return;
case JMP: case JMP:
Proc.PC += (signed int) n; Proc.PC += (signed int) n;
break; break;
@@ -1701,10 +1738,3 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
while (Proc.PC<size) while (Proc.PC<size)
{ p[Proc.PC].execute(Proc); } { p[Proc.PC].execute(Proc); }
} }
template void Program::execute(Processor<sgfp, Share<gf2n>>& Proc) const;
template void Program::execute(Processor<Rep3Share<Integer>, Rep3Share<gf2n>>& Proc) const;
template void Program::execute(Processor<Rep3Share<gfp>, Rep3Share<gf2n>>& Proc) const;
template void Program::execute(Processor<MaliciousRep3Share<gfp>, MaliciousRep3Share<gf2n>>& Proc) const;
template void Program::execute(Processor<ShamirShare<gfp>, ShamirShare<gf2n>>& Proc) const;
template void Program::execute(Processor<MaliciousShamirShare<gfp>, MaliciousShamirShare<gf2n>>& Proc) const;

View File

@@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <atomic>
using namespace std; using namespace std;
template<class sint, class sgf2n> template<class sint, class sgf2n>
@@ -71,6 +72,8 @@ class Machine : public BaseMachine
OnlineOptions opts; OnlineOptions opts;
atomic<size_t> data_sent;
Machine(int my_number, Names& playerNames, string progname, Machine(int my_number, Names& playerNames, string progname,
string memtype, int lg2, bool direct, int opening_sum, bool parallel, string memtype, int lg2, bool direct, int opening_sum, bool parallel,
bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep, bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep,

View File

@@ -5,6 +5,7 @@
#include "ShamirInput.hpp" #include "ShamirInput.hpp"
#include "Shamir.hpp" #include "Shamir.hpp"
#include "Replicated.hpp" #include "Replicated.hpp"
#include "Beaver.hpp"
#include "Auth/ShamirMC.hpp" #include "Auth/ShamirMC.hpp"
#include "Auth/MaliciousShamirMC.hpp" #include "Auth/MaliciousShamirMC.hpp"
@@ -24,24 +25,6 @@
#include <pthread.h> #include <pthread.h>
using namespace std; using namespace std;
BaseMachine* BaseMachine::singleton = 0;
BaseMachine& BaseMachine::s()
{
if (singleton)
return *singleton;
else
throw runtime_error("no singleton");
}
BaseMachine::BaseMachine() : nthreads(0)
{
if (singleton)
throw runtime_error("there can only be one");
else
singleton = this;
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames, Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
string progname_str, string memtype, int lg2, bool direct, string progname_str, string memtype, int lg2, bool direct,
@@ -50,7 +33,8 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
: my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false), : my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false),
direct(direct), opening_sum(opening_sum), parallel(parallel), direct(direct), opening_sum(opening_sum), parallel(parallel),
receive_threads(receive_threads), max_broadcast(max_broadcast), receive_threads(receive_threads), max_broadcast(max_broadcast),
use_encryption(use_encryption), live_prep(live_prep), opts(opts) use_encryption(use_encryption), live_prep(live_prep), opts(opts),
data_sent(0)
{ {
if (opening_sum < 2) if (opening_sum < 2)
this->opening_sum = N.num_players(); this->opening_sum = N.num_players();
@@ -166,42 +150,6 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
} }
} }
void BaseMachine::load_schedule(string progname)
{
this->progname = progname;
string fname = "Programs/Schedules/" + progname + ".sch";
cerr << "Opening file " << fname << endl;
inpf.open(fname);
if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); }
int nprogs;
inpf >> nthreads;
inpf >> nprogs;
cerr << "Number of threads I will run in parallel = " << nthreads << endl;
cerr << "Number of program sequences I need to load = " << nprogs << endl;
// Load in the programs
string threadname;
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
string filename = "Programs/Bytecode/" + threadname + ".bc";
cerr << "Loading program " << i << " from " << filename << endl;
load_program(threadname, filename);
}
}
void BaseMachine::print_compiler()
{
char compiler[1000];
inpf.get();
inpf.getline(compiler, 1000);
if (compiler[0] != 0)
cerr << "Compiler: " << compiler << endl;
inpf.close();
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
void Machine<sint, sgf2n>::load_program(string threadname, string filename) void Machine<sint, sgf2n>::load_program(string threadname, string filename)
{ {
@@ -345,6 +293,7 @@ void Machine<sint, sgf2n>::run()
cerr << "Finish timer: " << finish_timer.elapsed() << endl; cerr << "Finish timer: " << finish_timer.elapsed() << endl;
cerr << "Process timer: " << proc_timer.elapsed() << endl; cerr << "Process timer: " << proc_timer.elapsed() << endl;
print_timers(); print_timers();
cerr << "Data sent = " << data_sent / 1e6 << " MB" << endl;
if (opening_sum < N.num_players() && !direct) if (opening_sum < N.num_players() && !direct)
cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl; cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl;
@@ -368,12 +317,6 @@ void Machine<sint, sgf2n>::run()
outf << M2 << Mp << Mi; outf << M2 << Mp << Mi;
outf.close(); outf.close();
extern unsigned long long sent_amount, sent_counter;
cerr << "Data sent = " << sent_amount << " bytes in "
<< sent_counter << " calls,";
cerr << sent_amount / sent_counter / N.num_players()
<< " bytes per call" << endl;
for (int dtype = 0; dtype < N_DTYPE; dtype++) for (int dtype = 0; dtype < N_DTYPE; dtype++)
{ {
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t="; cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
@@ -407,48 +350,8 @@ string Machine<sint, sgf2n>::memory_filename()
return PREP_DIR "Memory-" + sint::type_short() + "-P" + to_string(my_number); return PREP_DIR "Memory-" + sint::type_short() + "-P" + to_string(my_number);
} }
void BaseMachine::load_program(string threadname, string filename)
{
(void)threadname;
(void)filename;
throw not_implemented();
}
void BaseMachine::time()
{
cout << "Elapsed time: " << timer[0].elapsed() << endl;
}
void BaseMachine::start(int n)
{
cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " after " << timer[n].idle() << endl;
timer[n].start();
}
void BaseMachine::stop(int n)
{
timer[n].stop();
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl;
}
void BaseMachine::print_timers()
{
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
timer.erase(0);
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
void Machine<sint, sgf2n>::reqbl(int n) void Machine<sint, sgf2n>::reqbl(int n)
{ {
sint::clear::reqbl(n); sint::clear::reqbl(n);
} }
template class Machine<sgfp, Share<gf2n>>;
template class Machine<Rep3Share<Integer>, Rep3Share<gf2n>>;
template class Machine<Rep3Share<gfp>, Rep3Share<gf2n>>;
template class Machine<MaliciousRep3Share<gfp>, MaliciousRep3Share<gf2n>>;
template class Machine<ShamirShare<gfp>, ShamirShare<gf2n>>;
template class Machine<MaliciousShamirShare<gfp>, MaliciousShamirShare<gf2n>>;

View File

@@ -5,7 +5,7 @@
#include "MaliciousRepPrep.h" #include "MaliciousRepPrep.h"
#include "Auth/Subroutines.h" #include "Auth/Subroutines.h"
#include "Auth/MaliciousRepMC.hpp" //#include "Auth/MaliciousRepMC.hpp"
template<class T> template<class T>
MaliciousRepPrep<T>::MaliciousRepPrep(SubProcessor<T>* proc) : MaliciousRepPrep<T>::MaliciousRepPrep(SubProcessor<T>* proc) :
@@ -41,7 +41,7 @@ template<class T>
void MaliciousRepPrep<T>::buffer_triples() void MaliciousRepPrep<T>::buffer_triples()
{ {
auto& triples = this->triples; auto& triples = this->triples;
auto& buffer_size = this->buffer_size; auto buffer_size = this->buffer_size;
clear_tmp(); clear_tmp();
Player& P = honest_prep.protocol->P; Player& P = honest_prep.protocol->P;
triples.clear(); triples.clear();
@@ -51,8 +51,8 @@ void MaliciousRepPrep<T>::buffer_triples()
T f, g, h; T f, g, h;
honest_prep.get_three(DATA_TRIPLE, a, b, c); honest_prep.get_three(DATA_TRIPLE, a, b, c);
honest_prep.get_three(DATA_TRIPLE, f, g, h); honest_prep.get_three(DATA_TRIPLE, f, g, h);
triples.push_back({a, b, c}); triples.push_back({{a, b, c}});
check_triples.push_back({f, g, h}); check_triples.push_back({{f, g, h}});
} }
auto t = Create_Random<typename T::clear>(P); auto t = Create_Random<typename T::clear>(P);
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
@@ -86,7 +86,7 @@ template<class T>
void MaliciousRepPrep<T>::buffer_squares() void MaliciousRepPrep<T>::buffer_squares()
{ {
auto& squares = this->squares; auto& squares = this->squares;
auto& buffer_size = this->buffer_size; auto buffer_size = this->buffer_size;
clear_tmp(); clear_tmp();
Player& P = honest_prep.protocol->P; Player& P = honest_prep.protocol->P;
squares.clear(); squares.clear();
@@ -96,8 +96,8 @@ void MaliciousRepPrep<T>::buffer_squares()
T f, h; T f, h;
honest_prep.get_two(DATA_SQUARE, a, b); honest_prep.get_two(DATA_SQUARE, a, b);
honest_prep.get_two(DATA_SQUARE, f, h); honest_prep.get_two(DATA_SQUARE, f, h);
squares.push_back({a, b}); squares.push_back({{a, b}});
check_squares.push_back({f, h}); check_squares.push_back({{f, h}});
} }
auto t = Create_Random<typename T::clear>(P); auto t = Create_Random<typename T::clear>(P);
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
@@ -132,7 +132,7 @@ template<class T>
void MaliciousRepPrep<T>::buffer_bits() void MaliciousRepPrep<T>::buffer_bits()
{ {
auto& bits = this->bits; auto& bits = this->bits;
auto& buffer_size = this->buffer_size; auto buffer_size = this->buffer_size;
clear_tmp(); clear_tmp();
Player& P = honest_prep.protocol->P; Player& P = honest_prep.protocol->P;
bits.clear(); bits.clear();
@@ -142,7 +142,7 @@ void MaliciousRepPrep<T>::buffer_bits()
honest_prep.get_one(DATA_BIT, a); honest_prep.get_one(DATA_BIT, a);
honest_prep.get_two(DATA_SQUARE, f, h); honest_prep.get_two(DATA_SQUARE, f, h);
bits.push_back(a); bits.push_back(a);
check_squares.push_back({f, h}); check_squares.push_back({{f, h}});
} }
auto t = Create_Random<typename T::clear>(P); auto t = Create_Random<typename T::clear>(P);
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)

View File

@@ -170,6 +170,8 @@ void* Sub_Main_Func(void* ptr)
cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl; cerr << "Thread " << num << " timer: " << thread_timer.elapsed() << endl;
cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl; cerr << "Thread " << num << " wait timer: " << wait_timer.elapsed() << endl;
machine.data_sent += P.sent;
delete MC2; delete MC2;
delete MCp; delete MCp;
delete player; delete player;

View File

@@ -41,7 +41,7 @@ class SubProcessor
template<class sint, class sgf2n> friend class Processor; template<class sint, class sgf2n> friend class Processor;
template<class U> friend class SPDZ; template<class U> friend class SPDZ;
template<class U> friend class PrepLessProtocol; template<class U> friend class ProtocolBase;
template<class U> friend class Beaver; template<class U> friend class Beaver;
public: public:
@@ -61,7 +61,9 @@ public:
void POpen_Stop(const vector<int>& reg,const Player& P,int size); void POpen_Stop(const vector<int>& reg,const Player& P,int size);
void POpen(const vector<int>& reg,const Player& P,int size); void POpen(const vector<int>& reg,const Player& P,int size);
void muls(const vector<int>& reg,const Player& P,int size); void muls(const vector<int>& reg, int size);
void mulrs(const vector<int>& reg);
void dotprods(const vector<int>& reg);
vector<T>& get_S() vector<T>& get_S()
{ {

View File

@@ -50,7 +50,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
template<class sint, class sgf2n> template<class sint, class sgf2n>
Processor<sint, sgf2n>::~Processor() Processor<sint, sgf2n>::~Processor()
{ {
cerr << "Sent " << sent << " elements in " << rounds << " rounds" << endl; cerr << "Opened " << sent << " elements in " << rounds << " rounds" << endl;
} }
template<class sint, class sgf2n> template<class sint, class sgf2n>
@@ -485,6 +485,85 @@ void SubProcessor<T>::POpen(const vector<int>& reg, const Player& P,
POpen_Stop(dest, P, size); POpen_Stop(dest, P, size);
} }
template<class T>
void SubProcessor<T>::muls(const vector<int>& reg, int size)
{
assert(reg.size() % 3 == 0);
int n = reg.size() / 3;
SubProcessor<T>& proc = *this;
protocol.init_mul(&proc);
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++)
{
auto& x = proc.S[reg[3 * i + 1] + j];
auto& y = proc.S[reg[3 * i + 2] + j];
protocol.prepare_mul(x, y);
}
protocol.exchange();
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++)
{
proc.S[reg[3 * i] + j] = protocol.finalize_mul();
}
protocol.counter += n * size;
}
template<class T>
void SubProcessor<T>::mulrs(const vector<int>& reg)
{
assert(reg.size() % 4 == 0);
int n = reg.size() / 4;
SubProcessor<T>& proc = *this;
protocol.init_mul(&proc);
for (int i = 0; i < n; i++)
for (int j = 0; j < reg[4 * i]; j++)
{
auto& x = proc.S[reg[4 * i + 2] + j];
auto& y = proc.S[reg[4 * i + 3]];
protocol.prepare_mul(x, y);
}
protocol.exchange();
for (int i = 0; i < n; i++)
{
for (int j = 0; j < reg[4 * i]; j++)
{
proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul();
}
protocol.counter += reg[4 * i];
}
}
template<class T>
void SubProcessor<T>::dotprods(const vector<int>& reg)
{
protocol.init_dotprod(this);
auto it = reg.begin();
while (it != reg.end())
{
auto next = it + *it;
it += 2;
while (it != next)
{
protocol.prepare_dotprod(S[*it], S[*(it + 1)]);
it += 2;
}
protocol.next_dotprod();
}
protocol.exchange();
it = reg.begin();
while (it != reg.end())
{
auto next = it + *it;
it++;
T& dest = S[*it];
dest = protocol.finalize_dotprod((next - it) / 2);
it = next;
}
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
ostream& operator<<(ostream& s,const Processor<sint, sgf2n>& P) ostream& operator<<(ostream& s,const Processor<sint, sgf2n>& P)
{ {

View File

@@ -3,6 +3,8 @@
#include "Processor/Data_Files.h" #include "Processor/Data_Files.h"
#include "Processor/Processor.h" #include "Processor/Processor.h"
#include "Processor/Instruction.hpp"
void Program::compute_constants() void Program::compute_constants()
{ {
for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++)

View File

@@ -30,33 +30,42 @@ public:
Player& P; Player& P;
ReplicatedBase(Player& P); ReplicatedBase(Player& P);
int get_n_relevant_players() { return P.num_players() - 1; }
}; };
template <class T> template <class T>
class PrepLessProtocol class ProtocolBase
{ {
public:
int counter; int counter;
public: ProtocolBase();
PrepLessProtocol(); virtual ~ProtocolBase();
virtual ~PrepLessProtocol();
void muls(const vector<int>& reg, SubProcessor<T>& proc, MAC_Check_Base<T>& MC, void muls(const vector<int>& reg, SubProcessor<T>& proc, MAC_Check_Base<T>& MC,
int size); int size);
void mulrs(const vector<int>& reg, SubProcessor<T>& proc);
void dotprods(const vector<int>& reg, SubProcessor<T>& proc);
virtual void init_mul(SubProcessor<T>* proc) = 0; virtual void init_mul(SubProcessor<T>* proc) = 0;
virtual typename T::clear prepare_mul(const T& x, const T& y) = 0; virtual typename T::clear prepare_mul(const T& x, const T& y) = 0;
virtual void exchange() = 0; virtual void exchange() = 0;
virtual T finalize_mul() = 0; virtual T finalize_mul() = 0;
virtual T get_random() = 0; void init_dotprod(SubProcessor<T>* proc) { init_mul(proc); }
void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); }
void next_dotprod() {}
T finalize_dotprod(int length);
}; };
template <class T> template <class T>
class Replicated : public ReplicatedBase, public PrepLessProtocol<T> class Replicated : public ReplicatedBase, public ProtocolBase<T>
{ {
vector<octetStream> os; vector<octetStream> os;
deque<typename T::clear> add_shares; deque<typename T::clear> add_shares;
typename T::clear dotprod_share;
public: public:
typedef ReplicatedMC<T> MAC_Check; typedef ReplicatedMC<T> MAC_Check;
typedef ReplicatedInput<T> Input; typedef ReplicatedInput<T> Input;
@@ -78,6 +87,13 @@ public:
void exchange(); void exchange();
T finalize_mul(); T finalize_mul();
void prepare_reshare(const typename T::clear& share);
void init_dotprod(SubProcessor<T>* proc);
void prepare_dotprod(const T& x, const T& y);
void next_dotprod();
T finalize_dotprod(int length);
T get_random(); T get_random();
}; };

View File

@@ -13,7 +13,7 @@
#include "GC/ReplicatedSecret.h" #include "GC/ReplicatedSecret.h"
template<class T> template<class T>
PrepLessProtocol<T>::PrepLessProtocol() : counter(0) ProtocolBase<T>::ProtocolBase() : counter(0)
{ {
} }
@@ -38,36 +38,42 @@ inline ReplicatedBase::ReplicatedBase(Player& P) : P(P)
} }
template<class T> template<class T>
PrepLessProtocol<T>::~PrepLessProtocol() ProtocolBase<T>::~ProtocolBase()
{ {
if (counter) if (counter)
cerr << "Number of multiplications: " << counter << endl; cerr << "Number of multiplications: " << counter << endl;
} }
template<class T> template<class T>
void PrepLessProtocol<T>::muls(const vector<int>& reg, void ProtocolBase<T>::muls(const vector<int>& reg,
SubProcessor<T>& proc, MAC_Check_Base<T>& MC, int size) SubProcessor<T>& proc, MAC_Check_Base<T>& MC, int size)
{ {
(void)MC; (void)MC;
assert(reg.size() % 3 == 0); proc.muls(reg, size);
int n = reg.size() / 3; }
init_mul(&proc); template<class T>
for (int i = 0; i < n; i++) void ProtocolBase<T>::mulrs(const vector<int>& reg,
for (int j = 0; j < size; j++) SubProcessor<T>& proc)
{ {
auto& x = proc.S[reg[3 * i + 1] + j]; proc.mulrs(reg);
auto& y = proc.S[reg[3 * i + 2] + j]; }
prepare_mul(x, y);
}
exchange();
for (int i = 0; i < n; i++)
for (int j = 0; j < size; j++)
{
proc.S[reg[3 * i] + j] = finalize_mul();
}
counter += n * size; template<class T>
void ProtocolBase<T>::dotprods(const vector<int>& reg,
SubProcessor<T>& proc)
{
proc.dotprods(reg);
}
template<class T>
T ProtocolBase<T>::finalize_dotprod(int length)
{
counter += length;
T res;
for (int i = 0; i < length; i++)
res += finalize_mul();
return res;
} }
template<class T> template<class T>
@@ -87,28 +93,34 @@ void Replicated<T>::init_mul()
} }
template<class T> template<class T>
typename T::clear Replicated<T>::prepare_mul(const T& x, inline typename T::clear Replicated<T>::prepare_mul(const T& x,
const T& y) const T& y)
{ {
typename T::value_type add_share = x[0] * y.sum() + x[1] * y[0]; typename T::value_type add_share = x.local_mul(y);
prepare_reshare(add_share);
return add_share;
}
template<class T>
inline void Replicated<T>::prepare_reshare(const typename T::clear& share)
{
auto add_share = share;
typename T::value_type tmp[2]; typename T::value_type tmp[2];
for (int i = 0; i < 2; i++) for (int i = 0; i < 2; i++)
tmp[i].randomize(shared_prngs[i]); tmp[i].randomize(shared_prngs[i]);
add_share += tmp[0] - tmp[1]; add_share += tmp[0] - tmp[1];
add_share.pack(os[0]); add_share.pack(os[0]);
add_shares.push_back(add_share); add_shares.push_back(add_share);
return add_share;
} }
template<class T> template<class T>
void Replicated<T>::exchange() void Replicated<T>::exchange()
{ {
P.send_relative(1, os[0]); P.pass_around(os[0], 1);
P.receive_relative(- 1, os[0]);
} }
template<class T> template<class T>
T Replicated<T>::finalize_mul() inline T Replicated<T>::finalize_mul()
{ {
T result; T result;
result[0] = add_shares.front(); result[0] = add_shares.front();
@@ -117,6 +129,34 @@ T Replicated<T>::finalize_mul()
return result; return result;
} }
template<class T>
inline void Replicated<T>::init_dotprod(SubProcessor<T>* proc)
{
init_mul(proc);
dotprod_share.assign_zero();
}
template<class T>
inline void Replicated<T>::prepare_dotprod(const T& x, const T& y)
{
dotprod_share += x.local_mul(y);
}
template<class T>
inline void Replicated<T>::next_dotprod()
{
prepare_reshare(dotprod_share);
dotprod_share.assign_zero();
}
template<class T>
inline T Replicated<T>::finalize_dotprod(int length)
{
(void) length;
this->counter++;
return finalize_mul();
}
template<class T> template<class T>
T Replicated<T>::get_random() T Replicated<T>::get_random()
{ {

View File

@@ -21,7 +21,7 @@ void ReplicatedInput<T>::reset(int player)
} }
template<class T> template<class T>
void ReplicatedInput<T>::add_mine(const typename T::clear& input) inline void ReplicatedInput<T>::add_mine(const typename T::clear& input)
{ {
auto& shares = this->shares; auto& shares = this->shares;
shares.push_back({}); shares.push_back({});
@@ -89,7 +89,7 @@ void PrepLessInput<T>::stop(int player, vector<int> targets)
} }
template<class T> template<class T>
void ReplicatedInput<T>::finalize_other(int player, T& target, inline void ReplicatedInput<T>::finalize_other(int player, T& target,
octetStream& o) octetStream& o)
{ {
typename T::value_type t; typename T::value_type t;

View File

@@ -28,7 +28,7 @@ protected:
virtual void buffer_inverses(MAC_Check_Base<T>& MC, Player& P); virtual void buffer_inverses(MAC_Check_Base<T>& MC, Player& P);
public: public:
static const int buffer_size = 1000; static const int buffer_size = 10000;
virtual ~BufferPrep() {} virtual ~BufferPrep() {}
@@ -49,6 +49,8 @@ protected:
typename T::Protocol* protocol; typename T::Protocol* protocol;
SubProcessor<T>* proc; SubProcessor<T>* proc;
int base_player;
void buffer_triples(); void buffer_triples();
void buffer_squares(); void buffer_squares();
void buffer_inverses() { throw runtime_error("not inverses in rings"); } void buffer_inverses() { throw runtime_error("not inverses in rings"); }
@@ -57,7 +59,7 @@ public:
ReplicatedRingPrep(SubProcessor<T>* proc); ReplicatedRingPrep(SubProcessor<T>* proc);
virtual ~ReplicatedRingPrep() {} virtual ~ReplicatedRingPrep() {}
void set_protocol(typename T::Protocol& protocol) { this->protocol = &protocol; } void set_protocol(typename T::Protocol& protocol);
virtual void buffer_bits(); virtual void buffer_bits();
}; };

View File

@@ -15,6 +15,16 @@ ReplicatedRingPrep<T>::ReplicatedRingPrep(SubProcessor<T>* proc) :
{ {
} }
template<class T>
void ReplicatedRingPrep<T>::set_protocol(typename T::Protocol& protocol)
{
this->protocol = &protocol;
if (proc)
base_player = proc->Proc.thread_num;
else
base_player = 0;
}
template<class T> template<class T>
void ReplicatedRingPrep<T>::buffer_triples() void ReplicatedRingPrep<T>::buffer_triples()
{ {
@@ -90,7 +100,7 @@ void BufferPrep<T>::buffer_inverses(MAC_Check_Base<T>& MC, Player& P)
MC.POpen(c_open, c, P); MC.POpen(c_open, c, P);
for (size_t i = 0; i < c.size(); i++) for (size_t i = 0; i < c.size(); i++)
if (c_open[i] != 0) if (c_open[i] != 0)
inverses.push_back({triples[i][0], triples[i][1] / c_open[i]}); inverses.push_back({{triples[i][0], triples[i][1] / c_open[i]}});
triples.clear(); triples.clear();
if (inverses.empty()) if (inverses.empty())
throw runtime_error("products were all zero"); throw runtime_error("products were all zero");
@@ -141,25 +151,12 @@ void XOR(vector<T>& res, vector<T>& x, vector<T>& y, int buffer_size,
res[i] = x[i] + y[i] - prot.finalize_mul() * two; res[i] = x[i] + y[i] - prot.finalize_mul() * two;
} }
int get_n_relevant_players(Player& P)
{
int n_relevant_players = P.num_players();
try
{
n_relevant_players = ShamirMachine::s().threshold + 1;
}
catch (...)
{
}
return n_relevant_players;
}
template<template<class U> class T> template<template<class U> class T>
void buffer_bits_spec(ReplicatedPrep<T<gfp>>& prep, vector<T<gfp>>& bits, void buffer_bits_spec(ReplicatedPrep<T<gfp>>& prep, vector<T<gfp>>& bits,
typename T<gfp>::Protocol& prot) typename T<gfp>::Protocol& prot)
{ {
(void) bits, (void) prot; (void) bits, (void) prot;
if (get_n_relevant_players(prot.P) > 10) if (prot.get_n_relevant_players() > 10)
{ {
vector<array<T<gfp>, 2>> squares(prep.buffer_size); vector<array<T<gfp>, 2>> squares(prep.buffer_size);
vector<T<gfp>> s; vector<T<gfp>> s;
@@ -189,12 +186,12 @@ void ReplicatedRingPrep<T>::buffer_bits()
auto buffer_size = this->buffer_size; auto buffer_size = this->buffer_size;
auto& bits = this->bits; auto& bits = this->bits;
auto& P = protocol->P; auto& P = protocol->P;
int n_relevant_players = get_n_relevant_players(P); int n_relevant_players = protocol->get_n_relevant_players();
vector<vector<T>> player_bits(n_relevant_players, vector<T>(buffer_size)); vector<vector<T>> player_bits(n_relevant_players, vector<T>(buffer_size));
typename T::Input input(proc, P); typename T::Input input(proc, P);
for (int i = 0; i < n_relevant_players; i++) for (int i = 0; i < P.num_players(); i++)
input.reset(i); input.reset(i);
if (P.my_num() < n_relevant_players) if (positive_modulo(P.my_num() - base_player, P.num_players()) < n_relevant_players)
{ {
SeededPRNG G; SeededPRNG G;
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
@@ -202,25 +199,28 @@ void ReplicatedRingPrep<T>::buffer_bits()
input.send_mine(); input.send_mine();
} }
for (int i = 0; i < n_relevant_players; i++) for (int i = 0; i < n_relevant_players; i++)
if (i == P.my_num()) {
int input_player = (base_player + i) % P.num_players();
if (input_player == P.my_num())
for (auto& x : player_bits[i]) for (auto& x : player_bits[i])
x = input.finalize_mine(); x = input.finalize_mine();
else else
{ {
octetStream os; octetStream os;
P.receive_player(i, os, true); P.receive_player(input_player, os, true);
for (auto& x : player_bits[i]) for (auto& x : player_bits[i])
input.finalize_other(i, x, os); input.finalize_other(input_player, x, os);
} }
}
auto& prot = *protocol; auto& prot = *protocol;
vector<T> tmp; XOR(bits, player_bits[0], player_bits[1], buffer_size, prot, proc);
XOR(tmp, player_bits[0], player_bits[1], buffer_size, prot, proc); for (int i = 2; i < n_relevant_players; i++)
for (int i = 2; i < n_relevant_players - 1; i++) XOR(bits, bits, player_bits[i], buffer_size, prot, proc);
XOR(tmp, tmp, player_bits[i], buffer_size, prot, proc); base_player++;
XOR(bits, tmp, player_bits[n_relevant_players - 1], buffer_size, prot, proc);
} }
template<> template<>
inline
void ReplicatedRingPrep<Rep3Share<gf2n>>::buffer_bits() void ReplicatedRingPrep<Rep3Share<gf2n>>::buffer_bits()
{ {
assert(protocol != 0); assert(protocol != 0);

View File

@@ -1,14 +0,0 @@
/*
* Multiplication.cpp
*
*/
#include "SPDZ.h"
#include "Processor.h"
#include "Math/Share.h"
#include "Auth/MAC_Check.h"
#include "Input.hpp"
template class SPDZ<gfp>;
template class SPDZ<gf2n>;

View File

@@ -19,7 +19,7 @@ template<class T> class ShamirInput;
class Player; class Player;
template<class U> template<class U>
class Shamir : public PrepLessProtocol<ShamirShare<U>> class Shamir : public ProtocolBase<ShamirShare<U>>
{ {
typedef ShamirShare<U> T; typedef ShamirShare<U> T;
@@ -45,6 +45,8 @@ public:
Shamir(Player& P); Shamir(Player& P);
~Shamir(); ~Shamir();
int get_n_relevant_players();
void reset(); void reset();
void init_mul(SubProcessor<T>* proc); void init_mul(SubProcessor<T>* proc);

View File

@@ -5,7 +5,7 @@
#include "Shamir.h" #include "Shamir.h"
#include "ShamirInput.h" #include "ShamirInput.h"
#include "ShamirMachine.h" #include "Machines/ShamirMachine.h"
template<class U> template<class U>
U Shamir<U>::get_rec_factor(int i, int n) U Shamir<U>::get_rec_factor(int i, int n)
@@ -31,6 +31,12 @@ Shamir<U>::~Shamir()
delete resharing; delete resharing;
} }
template<class U>
int Shamir<U>::get_n_relevant_players()
{
return ShamirMachine::s().threshold + 1;
}
template<class U> template<class U>
void Shamir<U>::reset() void Shamir<U>::reset()
{ {

View File

@@ -4,7 +4,7 @@
*/ */
#include "ShamirInput.h" #include "ShamirInput.h"
#include "ShamirMachine.h" #include "Machines/ShamirMachine.h"
template<class U> template<class U>
void ShamirInput<U>::reset(int player) void ShamirInput<U>::reset(int player)

View File

@@ -117,3 +117,17 @@ print_ln('weighted average: %s', result.reveal())
print_ln_if((sum(point[0] for point in data) == 0).reveal(), \ print_ln_if((sum(point[0] for point in data) == 0).reveal(), \
'but the inputs were invalid (weights add up to zero)') 'but the inputs were invalid (weights add up to zero)')
# permutation matrix
M = Matrix(2, 2, sfix)
M[0][0] = 0
M[1][0] = 1
M[0][1] = 1
M[1][1] = 0
# matrix multiplication
M = data * M
test(M[0][0], data[0][1].reveal())
test(M[1][1], data[1][0].reveal())

View File

@@ -5,16 +5,40 @@ protocols such as SPDZ, MASCOT, Overdrive, BMR garbled circuits
(evaluation only), Yao's garbled circuits, and computation based on (evaluation only), Yao's garbled circuits, and computation based on
semi-honest three-party replicated secret sharing (with an honest majority). semi-honest three-party replicated secret sharing (with an honest majority).
#### TL;DR #### TL;DR (Binary Distribution on Linux or Souce Distribution on macOS)
This requires `sudo` rights as well as a working toolchain installed This requires either a Linux distribution originally released 2011 or
for the first step, refer to [the requirements](#requirements) later (glibc 2.12) or macOS High Sierra or later as well as Python 2
otherwise. It will execute [the and basic command-line utilities.
Download and unpack the distribution, then execute the following from
the top folder:
```
Scripts/tldr.sh
./compile.py tutorial
echo 1 2 3 > Player-Data/Input-P0-0
echo 1 2 3 > Player-Data/Input-P1-0
Scripts/mal-rep-field.sh tutorial
```
This runs [the tutorial](Programs/Source/tutorial.mpc) with three
parties, an honest majority, and malicious security.
#### TL;DR (Source Distribution)
On Linux, this requires a working toolchain and [all
requirements](#requirements). On Ubuntu, the following might suffice:
```
apt-get install automake build-essential git libboost-dev libboost-thread-dev libsodium-dev libssl-dev libtool m4 python texinfo yasm
```
On MacOS, this requires [brew](https://brew.sh) to be installed,
which will be used for all dependencies.
It will execute [the
tutorial](Programs/Source/tutorial.mpc) with three tutorial](Programs/Source/tutorial.mpc) with three
parties, an honest majority, and malicious security. parties, an honest majority, and malicious security.
``` ```
make -j 8 mpir
make -j 8 tldr make -j 8 tldr
./compile.py tutorial ./compile.py tutorial
Scripts/setup-replicated.sh Scripts/setup-replicated.sh
@@ -95,10 +119,10 @@ phase outputs the amount of offline material required, which allows to
compute the preprocessing time for a particulor computation. compute the preprocessing time for a particulor computation.
#### Requirements #### Requirements
- GCC 4.8 or later (tested with 7.3; remove `-no-pie` from `CONFIG` for GCC 4.8) or LLVM (tested with 6.0; remove `-no-pie` from `CONFIG`) - GCC 4.8 or later (tested with 7.3) or LLVM (tested with 6.0)
- MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure)
- libsodium library, tested against 1.0.16 - libsodium library, tested against 1.0.16
- OpenSSL, tested against 1.1.0 - OpenSSL, tested against and 1.0.2 and 1.1.0
- Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.65 - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.65
- Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.65 - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.65
- CPU supporting AES-NI, PCLMUL, AVX2 - CPU supporting AES-NI, PCLMUL, AVX2

20
Scripts/build.sh Executable file
View File

@@ -0,0 +1,20 @@
#!/bin/bash
function build
{
echo ARCH = $1 >> CONFIG.mine
echo GDEBUG = >> CONFIG.mine
make clean
rm -R static
mkdir static
make -j 12 static-hm
mkdir bin
dest=bin/`uname`-$2
rm -R $dest
mv static $dest
strip $dest/*
}
build '' amd64
build '-msse4.1 -maes -mpclmul' aes
build '-msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2' avx2

View File

@@ -1,7 +1,7 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include "Math/gf2n.h" #include "Math/gf2n.h"
#include "Processor/Buffer.h" #include "Tools/Buffer.h"
using namespace std; using namespace std;

View File

@@ -1,7 +1,7 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include "Math/gfp.h" #include "Math/gfp.h"
#include "Processor/Buffer.h" #include "Tools/Buffer.h"
#include "Tools/ezOptionParser.h" #include "Tools/ezOptionParser.h"
#include "Math/Setup.h" #include "Math/Setup.h"

View File

@@ -54,4 +54,4 @@ SPDZROOT=${SPDZROOT:-.}
#. Scripts/setup.sh #. Scripts/setup.sh
mkdir logs mkdir logs 2> /dev/null

View File

@@ -32,6 +32,3 @@ $HERE/setup-ssl.sh ${players}
$SPDZROOT/Fake-Offline.x ${players} -lgp ${bits} -lg2 ${g} --default ${default} $SPDZROOT/Fake-Offline.x ${players} -lgp ${bits} -lg2 ${g} --default ${default}
for i in $(seq 0 $[players-1]) ; do
dd if=/dev/zero of=Player-Data/Private-Input-$i bs=10000 count=1
done

View File

@@ -10,4 +10,6 @@ for i in `seq 0 $[n-1]`; do
openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i"
done done
# brew-installed OpenSSL on MacOS
PATH=$PATH:/usr/local/opt/openssl/bin/
c_rehash Player-Data c_rehash Player-Data

35
Scripts/tldr.sh Executable file
View File

@@ -0,0 +1,35 @@
#!/bin/sh
if test `uname` = "Linux"; then
flags='cat /proc/cpuinfo'
elif test `uname` = Darwin; then
if ! type brew; then
echo Do you want me to install Homebrew?
echo Press RETURN to continue or any other key to abort
read ans
if test "$ans"; then
echo Aborting
exit 1
else
/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
fi
fi
make tldr
else
echo OS unknown
exit 1
fi
if test "$flags"; then
if $flags | grep -q avx2; then
cpu=avx2
elif $flags | grep -q aes; then
cpu=aes
else
cpu=amd64
fi
cp -av bin/`uname`-$cpu/* .
fi
Scripts/setup-ssl.sh

View File

@@ -3,7 +3,7 @@
* *
*/ */
#include "Buffer.h" #include "Tools/Buffer.h"
bool BufferBase::rewind = false; bool BufferBase::rewind = false;

View File

@@ -3,8 +3,8 @@
* *
*/ */
#ifndef PROCESSOR_BUFFER_H_ #ifndef TOOLS_BUFFER_H_
#define PROCESSOR_BUFFER_H_ #define TOOLS_BUFFER_H_
#include <fstream> #include <fstream>
using namespace std; using namespace std;
@@ -12,7 +12,7 @@ using namespace std;
#include "Math/Share.h" #include "Math/Share.h"
#include "Math/field_types.h" #include "Math/field_types.h"
#include "Tools/time-func.h" #include "Tools/time-func.h"
#include "config.h" #include "Processor/config.h"
#ifndef BUFFER_SIZE #ifndef BUFFER_SIZE
#define BUFFER_SIZE 101 #define BUFFER_SIZE 101
@@ -153,4 +153,4 @@ inline void Buffer<T,U>::input(U& a)
next++; next++;
} }
#endif /* PROCESSOR_BUFFER_H_ */ #endif /* TOOLS_BUFFER_H_ */

View File

@@ -35,6 +35,7 @@ void aes_128_schedule( octet* key, const octet* userkey )
__m128i *Key_Schedule = (__m128i*)key; __m128i *Key_Schedule = (__m128i*)key;
temp1 = _mm_loadu_si128((__m128i*)userkey); temp1 = _mm_loadu_si128((__m128i*)userkey);
Key_Schedule[0] = temp1; Key_Schedule[0] = temp1;
#ifdef __AES__
temp2 = _mm_aeskeygenassist_si128 (temp1 ,0x1); temp2 = _mm_aeskeygenassist_si128 (temp1 ,0x1);
temp1 = AES_128_ASSIST(temp1, temp2); temp1 = AES_128_ASSIST(temp1, temp2);
Key_Schedule[1] = temp1; Key_Schedule[1] = temp1;
@@ -65,8 +66,13 @@ void aes_128_schedule( octet* key, const octet* userkey )
temp2 = _mm_aeskeygenassist_si128 (temp1,0x36); temp2 = _mm_aeskeygenassist_si128 (temp1,0x36);
temp1 = AES_128_ASSIST(temp1, temp2); temp1 = AES_128_ASSIST(temp1, temp2);
Key_Schedule[10] = temp1; Key_Schedule[10] = temp1;
#else
(void) temp2;
throw runtime_error("need to compile with AES-NI support");
#endif
} }
#ifdef __AES__
inline void KEY_192_ASSIST(__m128i* temp1, __m128i * temp2, __m128i * temp3) inline void KEY_192_ASSIST(__m128i* temp1, __m128i * temp2, __m128i * temp3)
{ __m128i temp4; { __m128i temp4;
*temp2 = _mm_shuffle_epi32 (*temp2, 0x55); *temp2 = _mm_shuffle_epi32 (*temp2, 0x55);
@@ -222,7 +228,7 @@ void aes_256_encrypt(octet* out, const octet* in, const octet* key)
tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]);
_mm_storeu_si128 (&((__m128i*)out)[0],tmp); _mm_storeu_si128 (&((__m128i*)out)[0],tmp);
} }
#endif

View File

@@ -54,10 +54,14 @@ __attribute__((optimize("unroll-loops")))
inline __m128i aes_128_encrypt(__m128i in, const octet* key) inline __m128i aes_128_encrypt(__m128i in, const octet* key)
{ __m128i& tmp = in; { __m128i& tmp = in;
tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]);
#ifdef __AES__
int j; int j;
for(j=1; j <10; j++) for(j=1; j <10; j++)
{ tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); } { tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); }
tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]);
#else
throw runtime_error("need to compile with AES-NI support");
#endif
return tmp; return tmp;
} }
@@ -70,12 +74,17 @@ inline void ecb_aes_128_encrypt(__m128i* out, __m128i* in, const octet* key)
__m128i tmp[N]; __m128i tmp[N];
for (int i = 0; i < N; i++) for (int i = 0; i < N; i++)
tmp[i] = _mm_xor_si128 (in[i],((__m128i*)key)[0]); tmp[i] = _mm_xor_si128 (in[i],((__m128i*)key)[0]);
#ifdef __AES__
int j; int j;
for(j=1; j <10; j++) for(j=1; j <10; j++)
for (int i = 0; i < N; i++) for (int i = 0; i < N; i++)
tmp[i] = _mm_aesenc_si128 (tmp[i],((__m128i*)key)[j]); tmp[i] = _mm_aesenc_si128 (tmp[i],((__m128i*)key)[j]);
for (int i = 0; i < N; i++) for (int i = 0; i < N; i++)
out[i] = _mm_aesenclast_si128 (tmp[i],((__m128i*)key)[j]); out[i] = _mm_aesenclast_si128 (tmp[i],((__m128i*)key)[j]);
#else
(void) tmp, (void) out;
throw runtime_error("need to compile with AES-NI support");
#endif
} }
template <int N> template <int N>

View File

@@ -201,8 +201,7 @@ void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive
if (sent < len) if (sent < len)
{ {
size_t to_send = min(buffer_size, len - sent); size_t to_send = min(buffer_size, len - sent);
send(send_socket, data + sent, to_send); sent += send_non_blocking(send_socket, data + sent, to_send);
sent += to_send;
} }
// avoid extra branching, false before length received // avoid extra branching, false before length received

View File

@@ -10,10 +10,11 @@ using namespace std;
PRNG::PRNG() : cnt(0) PRNG::PRNG() : cnt(0)
{ {
#ifdef __AES__
#ifdef USE_AES #ifdef USE_AES
useC=(Check_CPU_support_AES()==0); useC=(Check_CPU_support_AES()==0);
#endif #endif
#endif
} }
void PRNG::ReSeed() void PRNG::ReSeed()

View File

@@ -16,7 +16,11 @@
#define SEED_SIZE HASH_SIZE #define SEED_SIZE HASH_SIZE
#define RAND_SIZE HASH_SIZE #define RAND_SIZE HASH_SIZE
#else #else
#ifdef __AES__
#define PIPELINES 8 #define PIPELINES 8
#else
#define PIPELINES 1
#endif
#define SEED_SIZE AES_BLK_SIZE #define SEED_SIZE AES_BLK_SIZE
#define RAND_SIZE (PIPELINES * AES_BLK_SIZE) #define RAND_SIZE (PIPELINES * AES_BLK_SIZE)
#endif #endif
@@ -38,7 +42,11 @@ class PRNG
octet random[RAND_SIZE] __attribute__((aligned (16))); octet random[RAND_SIZE] __attribute__((aligned (16)));
#ifdef USE_AES #ifdef USE_AES
#ifdef __AES__
bool useC; bool useC;
#else
const static bool useC = true;
#endif
// Two types of key schedule for the different implementations // Two types of key schedule for the different implementations
// of AES // of AES

View File

@@ -9,6 +9,8 @@
#include "GC/Instruction.hpp" #include "GC/Instruction.hpp"
#include "GC/Program.hpp" #include "GC/Program.hpp"
#include "Processor/Instruction.hpp"
namespace GC namespace GC
{ {

View File

@@ -15,7 +15,8 @@
#include "GC/Secret.h" #include "GC/Secret.h"
#include "Networking/Player.h" #include "Networking/Player.h"
#include "OT/OTExtensionWithMatrix.h" #include "OT/OTExtensionWithMatrix.h"
#include "sys/sysinfo.h"
#include <thread>
using namespace GC; using namespace GC;
@@ -66,7 +67,8 @@ public:
const Key& get_delta() { return master.delta; } const Key& get_delta() { return master.delta; }
void store_gate(const YaoGate& gate); void store_gate(const YaoGate& gate);
int get_n_worker_threads() { return max(1, get_nprocs() / master.machine.nthreads); } int get_n_worker_threads()
{ return max(1u, thread::hardware_concurrency() / master.machine.nthreads); }
int get_threshold() { return master.threshold; } int get_threshold() { return master.threshold; }
long get_gate_id() { return gate_id(thread_num); } long get_gate_id() { return gate_id(thread_num); }

Some files were not shown because too many files have changed in this diff Show More