Files
MP-SPDZ/Programs/Source/bankers_bonus_commsec.mpc

145 lines
4.8 KiB
Plaintext

# coding=latin1
"""
Solve Bankers bonus, aka Millionaires problem:
deduce the maximum value from a range of integer inputs.
Demonstrate clients external to computing parties supplying input and receiving
an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions).
For an implementation without communications security see bankers_bonus.mpc.
Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent
before calculating the maximum.
Note each client connects in a single thread and so is potentially blocked.
Each round / game will reset, so this runs indefinitely unless a positive
number of rounds is given in program.args[1].
"""
from Compiler.types import sint, regint, Array, Matrix, MemValue
from Compiler.instructions import listen, acceptclientconnection
from Compiler.library import print_ln, do_while, if_e, else_, for_range
from Compiler.util import if_else
# Base port handed to listen() and acceptclientconnection().
PORTNUM = 14000
# Size of the per-round client arrays; client ids must be below this value.
MAX_NUM_CLIENTS = 8
# Number of rounds to play, taken from program.args[1] when present.
# A value of 0 (the default) makes main() loop forever.
n_rounds = int(program.args[1]) if len(program.args) > 1 else 0
def accept_client():
    """
    Accept one client connection on PORTNUM and perform the crypto setup.

    Returns a pair (socket id, first value read from the socket).  main()
    compares that first value with 1 to detect the client announcing that
    it is the last participant of the round.
    """
    socket_id = regint()
    acceptclientconnection(socket_id, PORTNUM)
    finish_flag = regint.read_from_socket(socket_id)

    # Crypto setup.
    # Presumably 8 register values carrying the client's public signing key.
    signing_key = regint.read_from_socket(socket_id, 8)
    # NOTE(review): the returned public key is not used in this function; the
    # call is kept because it reads from the socket -- confirm it is still
    # required by the handshake.
    regint.read_client_public_key(socket_id)
    regint.resp_secure_socket(socket_id, *signing_key)

    return socket_id, finish_flag
def client_input(client_socket_id):
    """
    Fetch a single secret input from the client behind client_socket_id.

    Asks sint.receive_from_client for one value and returns its first
    (only requested) element.  Per the original author's note, that call
    sends a share of a random value, receives the client's input and
    deduces the resulting share.
    """
    return sint.receive_from_client(1, client_socket_id)[0]
def determine_winner(number_clients, client_values, client_ids):
    """Return the entry of client_ids paired with the largest client value.

    Entries 0 .. number_clients - 1 of client_values are compared.  The
    comparison is strict, so on equal values the earlier entry stays the
    winner.
    """
    # Running maximum and its owner are kept in one-element Arrays,
    # presumably so the @for_range body below can update them in place.
    best_value = Array(1, sint)
    best_value[0] = client_values[0]
    best_id = Array(1, sint)
    best_id[0] = client_ids[0]

    @for_range(number_clients - 1)
    def compare_next(k):
        # Entry 0 seeded the running maximum, so start comparing at 1.
        challenger = k + 1
        # sint(1) if the challenger exceeds the current maximum, else sint(0)
        beats_current = best_value[0] < client_values[challenger]
        # Select the new maximum and its client id based on that bit
        best_value[0] = if_else(beats_current, client_values[challenger], best_value[0])
        best_id[0] = if_else(beats_current, client_ids[challenger], best_id[0])

    return best_id[0]
def write_winner_to_clients(sockets, number_clients, winning_client_id):
    """Write the shared winner id and its authentication shares to the clients.

    The same three shares (winner id, random value, their product) are
    written to sockets[0] .. sockets[number_clients - 1].
    """
    # Authenticate the result with a share of a random value taken from a
    # triple.  A client can validate that
    #   sum(winning_client_id shares) * sum(random shares) == sum(auth shares)
    mask = sint.get_random_triple()[0]
    tag = winning_client_id * mask
    payload = [winning_client_id, mask, tag]

    @for_range(number_clients)
    def send_to(idx):
        sint.write_shares_to_socket(sockets[idx], payload)
def main():
    """Listen for clients and play rounds of the game.

    Every round accepts client connections until the participant set is
    complete, reads one secret input per participant, determines the
    winner and writes the result back to all participants.  The number of
    rounds is n_rounds when that is positive, otherwise rounds repeat
    forever.
    """
    # Start listening for client socket connections
    listen(PORTNUM)
    print_ln('Listening for client connections on base port %s', PORTNUM)

    def game_loop(_=None):
        """Play one round; returns True so that do_while keeps repeating."""
        print_ln('Starting a new round of the game.')

        # Socket id of each client (integer).
        client_sockets = Array(MAX_NUM_CLIENTS, regint)
        # Number of participants; stays 0 until a client sends the finish flag.
        number_clients = MemValue(regint(0))
        # Secret input of each client.
        client_values = Array(MAX_NUM_CLIENTS, sint)
        # Ids used to identify each client.
        client_ids = Array(MAX_NUM_CLIENTS, sint)
        # Marks which client slots have connected in this round.
        joined = Array(MAX_NUM_CLIENTS, regint)
        joined.assign_all(0)

        # Keep accepting connections until every expected client has joined.
        @do_while
        def client_connections():
            slot, finish_flag = accept_client()

            @if_(slot >= MAX_NUM_CLIENTS)
            def reject_out_of_range():
                print_ln('client id too high')
                crash()

            # The socket id returned by accept_client doubles as the client id.
            client_sockets[slot] = slot
            client_ids[slot] = slot
            joined[slot] = 1

            @if_(finish_flag == 1)
            def record_participant_count():
                number_clients.write(slot + 1)

            # Continue while clients are still missing or the participant
            # count has not been announced yet.
            still_missing = sum(joined) < number_clients
            count_unknown = number_clients == 0
            return still_missing + count_unknown

        @for_range(number_clients)
        def collect_input(slot):
            client_values[slot] = client_input(slot)

        winner = determine_winner(number_clients, client_values, client_ids)
        print_ln('Found winner, index: %s.', winner.reveal())
        write_winner_to_clients(client_sockets, number_clients, winner)

        return True

    if n_rounds <= 0:
        print('run forever')
        do_while(game_loop)
    else:
        print('run %d rounds' % n_rounds)
        for_range(n_rounds)(game_loop)
# Invoke main() at module level so the listening and game loop are set up.
main()