Files
MP-SPDZ/ExternalIO/bankers-bonus-commsec-client.cpp
2019-06-07 15:26:28 +10:00

408 lines
14 KiB
C++

/*
* Demonstrate external client inputting and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
* Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp
* for a simpler client with no crypto.
*
* Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on
* the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running
* the bankers_bonus_commsec.mpc program.
*
* Each connecting client:
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an integer (0 meaning more players will join this round or 1 meaning stop the round and calc the result).
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
* - share of random value [r]
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* To run with 2 parties (SPDZ engines) and 3 external clients:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
* ./compile.py bankers_bonus_commsec
* ./Scripts/run-online bankers_bonus_commsec to run the engines.
*
* ./bankers-bonus-commsec-client.x 0 2 100 0
* ./bankers-bonus-commsec-client.x 1 2 200 0
* ./bankers-bonus-commsec-client.x 2 2 50 1
*
* Expect winner to be second client with id 1.
* Note here client id must match id used in generating client key material, Client-Keys-C<client id>.
*/
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Networking/sockets.h"
#include "Networking/STS.h"
#include "Tools/int.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include <sodium.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <fstream>
// Symmetric keys for talking to one SPDZ engine: .first is the send key,
// .second is the receive key.
using keypair_t = pair< vector<octet>, vector<octet> >;
// One keypair_t per SPDZ engine, indexed by server number.
using commsec_t = vector< keypair_t >;
// Signing key material handed to the STS key exchange.
struct sign_key_container_t {
    // This client's signing secret key.
    unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES];
    // This client's signing public key.
    unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES];
    // The signing public key repacked as ints, the form in which it is sent to the engines.
    vector<int> client_publickey_ints;
    // Signing public key bytes of every SPDZ engine, indexed by server number.
    vector< vector<unsigned char> > server_publickey;
};
// Responder side of the STS key exchange with the engine `server_id`.
// Receives message 1, answers with message 2 (public key followed by signature),
// then receives and processes message 3. Two secrets are derived afterwards:
// the receive key first and the send key second. The result is (send key, receive key).
// Nothing is caught here; sts_response_role wraps this call and handles errors.
keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
    STS exchange(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
    octetStream buffer;

    // Message 1: server -> client.
    sts_msg1_t msg1;
    buffer.Receive(sockets[server_id]);
    buffer.consume(msg1.bytes, sizeof msg1.bytes);

    // Message 2: client -> server.
    sts_msg2_t msg2 = exchange.recv_msg1(msg1);
    buffer.reset_write_head();
    buffer.append(msg2.pubkey, sizeof msg2.pubkey);
    buffer.append(msg2.sig, sizeof msg2.sig);
    buffer.Send(sockets[server_id]);

    // Message 3: server -> client.
    sts_msg3_t msg3;
    buffer.Receive(sockets[server_id]);
    buffer.consume(msg3.bytes, sizeof msg3.bytes);
    exchange.recv_msg3(msg3);

    // Derivation order matters: receive key first, send key second.
    vector<unsigned char> key_for_receiving = exchange.derive_secret(crypto_secretbox_KEYBYTES);
    vector<unsigned char> key_for_sending = exchange.derive_secret(crypto_secretbox_KEYBYTES);
    return make_pair(key_for_sending, key_for_receiving);
}
// Initiator side of the STS key exchange with the engine `server_id`.
// Sends message 1 (its first 32 bytes are also printed in hex), receives message 2
// (public key followed by signature) and answers with message 3. Two secrets are
// derived afterwards: the send key first and the receive key second.
// The result is (send key, receive key).
// Nothing is caught here; sts_initiator_role wraps this call and handles errors.
keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
    STS exchange(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
    octetStream buffer;

    // Message 1: client -> server.
    sts_msg1_t msg1 = exchange.send_msg1();
    cout << "m1: ";
    for (unsigned int idx = 0; idx != 32; ++idx)
        cout << setfill('0') << setw(2) << hex << (int) msg1.bytes[idx];
    cout << dec << endl;
    buffer.reset_write_head();
    buffer.append(msg1.bytes, sizeof msg1.bytes);
    buffer.Send(sockets[server_id]);

    // Message 2: server -> client.
    sts_msg2_t msg2;
    buffer.reset_write_head();
    buffer.Receive(sockets[server_id]);
    buffer.consume(msg2.pubkey, sizeof msg2.pubkey);
    buffer.consume(msg2.sig, sizeof msg2.sig);

    // Message 3: client -> server.
    sts_msg3_t msg3 = exchange.recv_msg2(msg2);
    buffer.reset_write_head();
    buffer.append(msg3.bytes, sizeof msg3.bytes);
    buffer.Send(sockets[server_id]);

    // Derivation order matters: send key first, receive key second.
    vector<unsigned char> key_for_sending = exchange.derive_secret(crypto_secretbox_KEYBYTES);
    vector<unsigned char> key_for_receiving = exchange.derive_secret(crypto_secretbox_KEYBYTES);
    return make_pair(key_for_sending, key_for_receiving);
}
// Run the responder side of the STS exchange and return the (send, receive) keys.
// A `char const*` thrown by the exchange is reported on stderr and ends the process
// with exit status 1.
pair< vector<octet>, vector<octet> > sts_response_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
    try
    {
        return sts_response_role_exceptions(keys, sockets, server_id);
    }
    catch (char const *message)
    {
        cerr << "Error in STS: " << message << endl;
        exit(1);
    }
}
// Run the initiator side of the STS exchange and return the (send, receive) keys.
// A `char const*` thrown by the exchange is reported on stderr and ends the process
// with exit status 1.
pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
    try
    {
        return sts_initiator_role_exceptions(keys, sockets, server_id);
    }
    catch (char const *message)
    {
        cerr << "Error in STS: " << message << endl;
        exit(1);
    }
}
// Deliver the private inputs to every SPDZ engine, each masked with a random value.
//  1. From each engine, receive and decrypt one triple share per input and add the
//     shares up across engines.
//  2. Check that every summed triple (a, b, c) satisfies a * b == c; otherwise print
//     an error and exit with status 1.
//  3. Send (input + a) for every input to each engine, encrypted with that engine's
//     send key.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties,
    commsec_t commsec, vector<octet*>& keys)
{
    int input_count = values.size();
    octetStream stream;
    vector< vector<gfp> > summed_triples(input_count, vector<gfp>(3));
    vector<gfp> share(3);

    // Step 1: collect the triple shares from every engine.
    for (int party = 0; party < nparties; party++)
    {
        stream.reset_write_head();
        stream.Receive(sockets[party]);
        // Incoming data is unwrapped twice: first with the STS receive key, then
        // with the per-party session key.
        stream.decrypt_sequence(&commsec[party].second[0], 0);
        stream.decrypt(keys[party]);
        for (int input = 0; input < input_count; input++)
        {
            vector<gfp>& triple = summed_triples[input];
            for (int k = 0; k < 3; k++)
            {
                share[k].unpack(stream);
                triple[k] += share[k];
            }
        }
    }

    // Step 2: validate the reconstructed triples.
    for (int input = 0; input < input_count; input++)
    {
        vector<gfp>& triple = summed_triples[input];
        if (triple[0] * triple[1] != triple[2])
        {
            cerr << "Incorrect triple at " << input << ", aborting\n";
            exit(1);
        }
    }

    // Step 3: pack the masked inputs once, then encrypt a copy per engine.
    stream.reset_write_head();
    for (int input = 0; input < input_count; input++)
    {
        gfp masked = values[input] + summed_triples[input][0];
        masked.pack(stream);
    }
    for (int party = 0; party < nparties; party++)
    {
        octetStream outgoing = stream;
        outgoing.encrypt_sequence(&commsec[party].first[0], 0);
        outgoing.Send(sockets[party]);
    }
}
// Transmit a public key, given as a sequence of ints, unencrypted over `socket`.
// All ints are stored into one stream, in order, and sent as a single message.
void send_public_key(vector<int>& pubkey, int socket)
{
    octetStream message;
    message.reset_write_head();
    for (int word : pubkey)
        message.store(word);
    message.Send(socket);
}
// Assumes that Scripts/setup-online.sh has been run to compute prime
void initialise_fields(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cout << "loading params from: " << filename << endl;
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
// Assumes that client-setup has been run to create key pairs for clients and parties
void generate_symmetric_keys(vector<octet*>& keys, vector<int>& client_public_key_ints,
sign_key_container_t *sts_key, const string& dir_prefix, int client_no)
{
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char client_secretkey[crypto_box_SECRETKEYBYTES];
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char scalarmult_q[crypto_scalarmult_BYTES];
crypto_generichash_state h;
// read client public/secret keys + SPDZ server public keys
ifstream keyfile;
stringstream client_filename;
client_filename << dir_prefix << "Client-Keys-C" << client_no;
keyfile.open(client_filename.str().c_str());
if (keyfile.fail())
throw file_error(client_filename.str());
keyfile.read((char*)client_publickey, sizeof client_publickey);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "client public key" );
// Convert client public key unsigned char to int, reverse endianness
for(unsigned int j = 0; j < client_public_key_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8));
}
client_public_key_ints[j] = keybyte;
}
keyfile.read((char*)client_secretkey, sizeof client_secretkey);
if (keyfile.eof()) {
throw end_of_file(client_filename.str(), "client private key" );
}
keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES);
keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES);
// Convert client public key unsigned char to int, reverse endianness
sts_key->client_publickey_ints.resize(8);
for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8));
}
sts_key->client_publickey_ints[j] = keybyte;
}
for (unsigned int i = 0; i < keys.size(); i++)
{
keys[i] = new octet[crypto_generichash_BYTES];
keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public key for party " + to_string(i));
keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public signing key for party " + to_string(i));
// Derive a shared key from this server's secret key and the client's public key
// shared key = h(q || client_secretkey || server_publickey)
if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) {
cerr << "Scalar mult failed\n";
exit(1);
}
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q);
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES);
}
keyfile.close();
cout << "My public key is: ";
for (unsigned int j = 0; j < 32; j++)
cout << setfill('0') << setw(2) << hex << (int) client_publickey[j];
cout << dec << endl;
}
// Collect the result from every SPDZ engine and authenticate it.
// Each engine sends three shares; the shares are summed per position across engines.
// The summed values must satisfy sum[0] * sum[1] == sum[2], otherwise an error is
// printed and the process exits with status 1. Returns sum[0].
gfp receive_result(vector<int>& sockets, int nparties, commsec_t commsec, vector<octet*>& keys)
{
    vector<gfp> totals(3);
    octetStream stream;

    for (int party = 0; party < nparties; party++)
    {
        stream.reset_write_head();
        stream.Receive(sockets[party]);
        // Incoming data is unwrapped twice: first with the STS receive key, then
        // with the per-party session key.
        stream.decrypt_sequence(&commsec[party].second[0], 1);
        stream.decrypt(keys[party]);
        for (unsigned int slot = 0; slot < 3; slot++)
        {
            gfp share;
            share.unpack(stream);
            totals[slot] += share;
        }
    }

    if (totals[0] * totals[1] != totals[2])
    {
        cerr << "Unable to authenticate output value as correct, aborting." << endl;
        exit(1);
    }
    return totals[0];
}
int main(int argc, char** argv)
{
int my_client_id;
int nparties;
int salary_value;
int finish;
int port_base = 14000;
sign_key_container_t sts_key;
string host_name = "localhost";
if (argc < 5) {
cout << "Usage is external-client <client identifier> <number of spdz parties> "
<< "<salary to compare> <finish (0 false, 1 true)> <optional host name, default localhost> "
<< "<optional spdz party port base number, default 14000>" << endl;
exit(0);
}
my_client_id = atoi(argv[1]);
nparties = atoi(argv[2]);
salary_value = atoi(argv[3]);
finish = atoi(argv[4]);
if (argc > 5)
host_name = argv[5];
if (argc > 6)
port_base = atoi(argv[6]);
sts_key.server_publickey.resize(nparties);
for(int i = 0 ; i < nparties; i++) {
sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES);
}
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
initialise_fields(prep_data_prefix);
// Generate session keys to decrypt data sent from each spdz engine (party)
vector<octet*> session_keys(nparties);
vector<int> client_public_key_ints(8);
generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id);
// Setup connections from this client to each party socket and send the client public keys
vector<int> sockets(nparties);
// vector< pair <vector<octet>, vector <octet> > > commseckey(nparties);
commsec_t commseckey(nparties);
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send_public_key(sts_key.client_publickey_ints, sockets[i]);
send_public_key(client_public_key_ints, sockets[i]);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);
// Send the inputs to the SPDZ Engines
send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back
gfp result = receive_result(sockets, nparties, commseckey, session_keys);
cout << "Winning client id is : " << result << endl;
for (unsigned int i = 0; i < sockets.size(); i++)
close_client_socket(sockets[i]);
return 0;
}