Avoid race condition depending of the order of client connections.

This commit is contained in:
Marcel Keller
2019-11-25 14:52:27 +11:00
parent 470b075803
commit 9a83cfe8f2
10 changed files with 124 additions and 130 deletions

View File

@@ -7,9 +7,9 @@
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an increasing id to identify the client, starting with 0
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
* - sends an integer input (bonus value to compare)
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
@@ -22,13 +22,13 @@
* To run with 2 parties / SPDZ engines:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./compile.py bankers_bonus
* ./Scripts/run-online bankers_bonus to run the engines.
* ./Scripts/run-online.sh bankers_bonus to run the engines.
*
* ./bankers-bonus-client.x 123 2 100 0
* ./bankers-bonus-client.x 456 2 200 0
* ./bankers-bonus-client.x 789 2 50 1
* ./bankers-bonus-client.x 0 2 100 0
* ./bankers-bonus-client.x 1 2 200 0
* ./bankers-bonus-client.x 2 2 50 1
*
* Expect winner to be second client with id 456.
* Expect winner to be second client with id 1.
*/
#include "Math/gfp.h"
@@ -46,7 +46,7 @@
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties)
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties)
{
int num_inputs = values.size();
octetStream os;
@@ -172,17 +172,15 @@ int main(int argc, char** argv)
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
octetStream os;
os.store(finish);
os.Send(sockets[i]);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);
// Run the commputation
send_private_inputs(input_values_gfp, sockets, nparties);
send_private_inputs({salary_value}, sockets, nparties);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back (client_id of winning client)

View File

@@ -9,10 +9,10 @@
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an increasing id to identify the client, starting with 0
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends an integer input (bonus value to compare)
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
@@ -24,7 +24,7 @@
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
* ./compile.py bankers_bonus_commsec
* ./Scripts/run-online bankers_bonus_commsec to run the engines.
* ./Scripts/run-online.sh bankers_bonus_commsec to run the engines.
*
* ./bankers-bonus-commsec-client.x 0 2 100 0
* ./bankers-bonus-commsec-client.x 1 2 200 0
@@ -139,7 +139,7 @@ pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t key
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties,
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties,
commsec_t commsec, vector<octet*>& keys)
{
int num_inputs = values.size();
@@ -380,20 +380,19 @@ int main(int argc, char** argv)
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
octetStream os;
os.store(finish);
os.Send(sockets[i]);
send_public_key(sts_key.client_publickey_ints, sockets[i]);
send_public_key(client_public_key_ints, sockets[i]);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);
// Send the inputs to the SPDZ Engines
send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys);
send_private_inputs({salary_value}, sockets, nparties, commseckey, session_keys);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back

View File

@@ -107,6 +107,7 @@ void ServerSocket::accept_clients()
}
data_signal.lock();
process_client(client_id);
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
@@ -157,8 +158,6 @@ void* anonymous_accept_thread(void* server_socket)
return 0;
}
int AnonymousServerSocket::global_client_socket_count = 0;
void AnonymousServerSocket::init()
{
pthread_create(&thread, 0, anonymous_accept_thread, this);
@@ -169,22 +168,12 @@ int AnonymousServerSocket::get_connection_count()
return num_accepted_clients;
}
void AnonymousServerSocket::accept_clients()
void AnonymousServerSocket::process_client(int client_id)
{
while (true)
{
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
data_signal.lock();
client_connection_queue.push(consocket);
num_accepted_clients++;
data_signal.broadcast();
data_signal.unlock();
}
if (clients.find(client_id) != clients.end())
close_client_socket(clients[client_id]);
num_accepted_clients++;
client_connection_queue.push(client_id);
}
int AnonymousServerSocket::get_connection_socket(int& client_id)
@@ -195,10 +184,9 @@ int AnonymousServerSocket::get_connection_socket(int& client_id)
while (client_connection_queue.empty())
data_signal.wait();
client_id = global_client_socket_count;
global_client_socket_count++;
int client_socket = client_connection_queue.front();
client_id = client_connection_queue.front();
client_connection_queue.pop();
int client_socket = clients[client_id];
data_signal.unlock();
return client_socket;
}

View File

@@ -28,8 +28,7 @@ protected:
// disable copying
ServerSocket(const ServerSocket& other);
// receive id from client
int assign_client_id(int consocket);
virtual void process_client(int) {}
public:
ServerSocket(int Portnum);
@@ -55,23 +54,20 @@ public:
class AnonymousServerSocket : public ServerSocket
{
private:
// Global no. of client sockets that have been returned - used to create identifiers
static int global_client_socket_count;
// No. of accepted connections in this instance
int num_accepted_clients;
queue<int> client_connection_queue;
void process_client(int client_id);
public:
AnonymousServerSocket(int Portnum) :
ServerSocket(Portnum), num_accepted_clients(0) { };
// override so clients do not send id
void accept_clients();
void init();
virtual int get_connection_count();
// Get socket for the last client who connected
// Writes a unique client identifier (i.e. a counter) to client_id
// Get socket and id for the last client who connected
int get_connection_socket(int& client_id);
};

View File

@@ -62,6 +62,10 @@ int ExternalClients::get_client_connection(int portnum_base)
int client_id, socket;
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
external_client_sockets[client_id] = socket;
if (symmetric_client_keys[client_id] != 0)
delete symmetric_client_keys[client_id];
symmetric_client_commsec_send_keys.erase(client_id);
symmetric_client_commsec_recv_keys.erase(client_id);
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
return client_id;
}
@@ -175,3 +179,9 @@ int ExternalClients::get_party_num()
return party_num;
}
int ExternalClients::get_socket(int id)
{
if (external_client_sockets.find(id) == external_client_sockets.end())
throw runtime_error("external connection not found for id " + to_string(id));
return external_client_sockets[id];
}

View File

@@ -28,13 +28,14 @@ class ExternalClients
bool server_keys_loaded = false;
bool ed25519_keys_loaded = false;
// Maps holding per client values (indexed by unique 32-bit id)
std::map<int,int> external_client_sockets;
public:
unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES];
unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES];
// Maps holding per client values (indexed by unique 32-bit id)
std::map<int,int> external_client_sockets;
std::map<int,octet*> symmetric_client_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_send_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_recv_keys;

View File

@@ -100,11 +100,6 @@ template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
int socket_id, int message_type, const vector<int>& registers)
{
if (socket_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << socket_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();
@@ -144,7 +139,7 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
// Apply STS commsec encryption if session keys have been created.
try {
maybe_encrypt_sequence(socket_id);
socket_stream.Send(external_clients.external_client_sockets[socket_id]);
socket_stream.Send(external_clients.get_socket(socket_id));
}
catch (bad_value& e) {
cerr << "Send error thrown when writing " << m << " values of type " << reg_type << " to socket id "
@@ -157,15 +152,9 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>& registers)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
@@ -179,15 +168,9 @@ void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>&
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>& registers)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
@@ -199,14 +182,9 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int>& registers, bool read_macs)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(client_id);
@@ -251,11 +229,6 @@ void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const ve
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
throw "No socket connection exists for client";
}
// Extract client long term public key into bytes
vector<int> client_public_key (registers.size(), 0);
@@ -269,15 +242,15 @@ void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const ve
m1 = ke.send_msg1();
socket_stream.reset_write_head();
socket_stream.append(m1.bytes, sizeof m1.bytes);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.Send(external_clients.get_socket(client_id));
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
96);
socket_stream.consume(m2.pubkey, sizeof m2.pubkey);
socket_stream.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
socket_stream.reset_write_head();
socket_stream.append(m3.bytes, sizeof m3.bytes);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.Send(external_clients.get_socket(client_id));
// Use results of STS to generate send and receive keys.
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
@@ -323,11 +296,6 @@ void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const ve
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
throw "No socket connection exists for client";
}
vector<int> client_public_key (registers.size(), 0);
for(unsigned int i = 0; i < registers.size(); i++) {
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
@@ -337,16 +305,16 @@ void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const ve
// Start Station to Station Protocol for the responder
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
socket_stream.reset_read_head();
socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
32);
socket_stream.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
socket_stream.reset_write_head();
socket_stream.append(m2.pubkey, sizeof m2.pubkey);
socket_stream.append(m2.sig, sizeof m2.sig);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.Send(external_clients.get_socket(client_id));
socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
64);
socket_stream.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);

View File

@@ -23,17 +23,20 @@ from Compiler.util import if_else
PORTNUM = 14000
MAX_NUM_CLIENTS = 8
def accept_client_input():
"""
Wait for socket connection, send share of random value, receive input and deduce share.
Expect 3 inputs: unique id, bonus value and flag to indicate end of this round.
"""
def accept_client():
client_socket_id = regint()
acceptclientconnection(client_socket_id, PORTNUM)
client_inputs = sint.receive_from_client(3, client_socket_id)
last = regint.read_from_socket(client_socket_id)
return client_socket_id, last
return client_socket_id, client_inputs[0], client_inputs[1], client_inputs[2]
def client_input(client_socket_id):
"""
Send share of random value, receive input and deduce share.
"""
client_inputs = sint.receive_from_client(1, client_socket_id)
return client_inputs[0]
def determine_winner(number_clients, client_values, client_ids):
@@ -87,16 +90,30 @@ def main():
client_values = Array(MAX_NUM_CLIENTS, sint)
# Client ids to identity client
client_ids = Array(MAX_NUM_CLIENTS, sint)
# Keep track of received inputs
seen = Array(MAX_NUM_CLIENTS, regint)
seen.assign_all(0)
# Loop round waiting for each client to connect
@do_while
def client_connections():
client_sockets[number_clients], client_ids[number_clients], client_values[number_clients], finish = accept_client_input()
number_clients.write(number_clients+1)
client_id, last = accept_client()
@if_(client_id >= MAX_NUM_CLIENTS)
def _():
print_ln('client id too high')
crash()
client_sockets[client_id] = client_id
client_ids[client_id] = client_id
seen[client_id] = 1
@if_(last == 1)
def _():
number_clients.write(client_id + 1)
# continue while both expressions are false
return (number_clients >= MAX_NUM_CLIENTS) + finish.reveal() == 0
return (sum(seen) < number_clients) + (number_clients == 0)
@for_range(number_clients)
def _(client_id):
client_values[client_id] = client_input(client_id)
winning_client_id = determine_winner(number_clients, client_values, client_ids)

View File

@@ -29,23 +29,26 @@ n_rounds = 0
if len(program.args) > 1:
n_rounds = int(program.args[1])
def accept_client_input():
"""
Wait for socket connection and read for client public key.
send share of random value, receive input and deduce share.
Expect 3 inputs: unique id, bonus value and flag to indicate end of this round.
"""
def accept_client():
client_socket_id = regint()
acceptclientconnection(client_socket_id, PORTNUM)
last = regint.read_from_socket(client_socket_id)
# Crypto setup
public_signing_key = regint.read_from_socket(client_socket_id, 8)
public_key = regint.read_client_public_key(client_socket_id)
regint.resp_secure_socket(client_socket_id,*public_signing_key)
client_inputs = sint.receive_from_client(3, client_socket_id)
return client_socket_id, last
return client_socket_id, client_inputs[0], client_inputs[1], client_inputs[2]
def client_input(client_socket_id):
"""
Send share of random value, receive input and deduce share.
"""
client_inputs = sint.receive_from_client(1, client_socket_id)
return client_inputs[0]
def determine_winner(number_clients, client_values, client_ids):
@@ -98,16 +101,30 @@ def main():
client_values = Array(MAX_NUM_CLIENTS, sint)
# Client ids to identity client
client_ids = Array(MAX_NUM_CLIENTS, sint)
# Keep track of received inputs
seen = Array(MAX_NUM_CLIENTS, regint)
seen.assign_all(0)
# Loop round waiting for each client to connect
@do_while
def client_connections():
client_sockets[number_clients], client_ids[number_clients], client_values[number_clients], finish = accept_client_input()
number_clients.write(number_clients+1)
client_id, last = accept_client()
@if_(client_id >= MAX_NUM_CLIENTS)
def _():
print_ln('client id too high')
crash()
client_sockets[client_id] = client_id
client_ids[client_id] = client_id
seen[client_id] = 1
@if_(last == 1)
def _():
number_clients.write(client_id + 1)
# continue while both expressions are false
return (number_clients >= MAX_NUM_CLIENTS) + finish.reveal() == 0
return (sum(seen) < number_clients) + (number_clients == 0)
@for_range(number_clients)
def _(client_id):
client_values[client_id] = client_input(client_id)
winning_client_id = determine_winner(number_clients, client_values, client_ids)

View File

@@ -141,9 +141,9 @@ class octetStream
}
template<class T>
void Send(T& socket_num) const;
void Send(T socket_num) const;
template<class T>
void Receive(T& socket_num);
void Receive(T socket_num);
void ReceiveExpected(int socket_num, size_t expected);
// In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES
@@ -255,7 +255,7 @@ inline size_t octetStream::get_int(int n_bytes)
template<class T>
inline void octetStream::Send(T& socket_num) const
inline void octetStream::Send(T socket_num) const
{
send(socket_num,len,LENGTH_SIZE);
send(socket_num,data,len);
@@ -263,7 +263,7 @@ inline void octetStream::Send(T& socket_num) const
template<class T>
inline void octetStream::Receive(T& socket_num)
inline void octetStream::Receive(T socket_num)
{
size_t nlen=0;
receive(socket_num,nlen,LENGTH_SIZE);