Files
MP-SPDZ/Programs/Source/bankers_bonus.mpc
Marcel Keller 6cc3fccef0 Maintenance.
2023-05-09 14:50:53 +10:00

144 lines
4.8 KiB
Plaintext

# coding: latin-1
"""
Solve the Banker's Bonus problem (a variant of the Millionaires' problem):
determine which client supplied the maximum value.
Demonstrates clients external to the computing parties supplying input and
receiving an authenticated result. See bankers-bonus-client.cpp for the
client (and setup instructions).
Each round accepts clients until one sends the finish ("last client") flag and
every client up to that one has connected (at most MAX_NUM_CLIENTS), then
computes the maximum, once over integer (sint) and once over fixed-point (sfix)
inputs.
Note that each client connects in a single thread and so is potentially blocked.
Each round/game resets; unless a positive round count is given as program
argument 1, this runs indefinitely.
"""
from Compiler.types import sint, regint, Array, MemValue
from Compiler.library import print_ln, do_while, for_range
from Compiler.util import if_else
# Network / game parameters.
PORTNUM = 14000
MAX_NUM_CLIENTS = 8
n_threads = 2

# Optional compile-time arguments (program.args[0] is not consulted here):
#   program.args[1] -> number of rounds to play; 0 (the default) makes main()
#                      loop forever instead.
#   program.args[2] -> integer flag assigned to program.active
#                      (presumably selects active/malicious security — confirm
#                      against the MP-SPDZ docs).
n_rounds = int(program.args[1]) if len(program.args) > 1 else 0
if len(program.args) > 2:
    program.active = bool(int(program.args[2]))
def accept_client():
    """Accept one client connection on PORTNUM.

    Returns a pair (socket id, first regint read from that socket). The caller
    treats the regint as a "this is the last client" flag.
    """
    socket_id = accept_client_connection(PORTNUM)
    return socket_id, regint.read_from_socket(socket_id)
def close_connections(number_clients):
    """Close the client connections with ids 0 .. number_clients - 1.

    NOTE(review): this assumes the socket ids handed out by
    accept_client_connection are consecutive and start at 0 — confirm.
    """
    @for_range(number_clients)
    def close_one(socket_id):
        closeclientconnection(socket_id)
def client_input(t, client_socket_id):
    """
    Obtain one secret input of type t (e.g. sint or sfix) from a client.

    The exchange itself (sending the client a share of a random value and
    deducing a share of its input) is performed by t.receive_from_client.
    """
    received = t.receive_from_client(1, client_socket_id)
    return received[0]
def determine_winner(number_clients, client_values, client_ids):
    """Return the secret id of the client whose input is largest.

    Performs an oblivious linear scan. Entry 0 seeds the running best; each
    later entry replaces it only where the secret comparison says it is
    strictly larger, so on ties the earliest client wins. The maximum value
    itself is revealed and printed.

    :param number_clients: number of populated entries in the arrays
    :param client_values: secret inputs (array of sint or sfix)
    :param client_ids: sint array of client ids aligned with client_values
    :returns: sint holding the winning client id
    """
    # One-element arrays so the loop body can carry state across iterations.
    best = Array(1, client_values.value_type)
    best_id = Array(1, sint)
    best[0] = client_values[0]
    best_id[0] = client_ids[0]

    @for_range(number_clients - 1)
    def scan(k):
        idx = k + 1
        candidate = client_values[idx]
        # Secret bit: sint(1) if candidate beats the current best, else sint(0).
        beats = best[0] < candidate
        best[0] = if_else(beats, candidate, best[0])
        best_id[0] = if_else(beats, client_ids[idx], best_id[0])

    print_ln('maximum: %s', best[0].reveal())
    return best_id[0]
def write_winner_to_clients(sockets, number_clients, winning_client_id):
    """Send the secret winning client id to every client that joined the game.

    Only the first number_clients entries of sockets are addressed.
    sint.reveal_to_clients delivers the value in authenticated form.
    NOTE(review): the original comment says the client validates the result
    with a random triple, roughly sum(winning_client_id) * sum(rnd_from_triple)
    == sum(auth_result) over the received shares. Confirm against
    bankers-bonus-client.cpp.
    """
    joined_sockets = sockets.get_sub(number_clients)
    sint.reveal_to_clients(joined_sockets, [winning_client_id])
def main():
    """Listen for clients and play one or more rounds of the game.

    Each round accepts client connections until a client sends the "last
    client" flag and every client id up to that one has been seen. It then
    collects an integer (sint) and a fixed-point (sfix) input from each client,
    computes the winner for each and sends it back to the clients. Finally it
    closes the connections. Runs n_rounds rounds if n_rounds > 0, otherwise
    loops forever."""
    # Start listening for client socket connections
    listen_for_clients(PORTNUM)
    print_ln('Listening for client connections on base port %s', PORTNUM)

    # The unused parameter absorbs the round index that for_range passes;
    # do_while calls the function with no arguments.
    def game_loop(_=None):
        print_ln('Starting a new round of the game.')

        # Clients socket id (integer).
        client_sockets = Array(MAX_NUM_CLIENTS, regint)
        # Number of clients (0 until a client sends the "last" flag).
        number_clients = MemValue(regint(0))
        # Client ids used to identify each client
        client_ids = Array(MAX_NUM_CLIENTS, sint)
        # Keep track of which client ids have connected this round
        seen = Array(MAX_NUM_CLIENTS, regint)
        seen.assign_all(0)

        # Loop round waiting for each client to connect
        @do_while
        def client_connections():
            client_id, last = accept_client()
            # Guard the array writes below: ids index arrays of size
            # MAX_NUM_CLIENTS.
            @if_(client_id >= MAX_NUM_CLIENTS)
            def _():
                print_ln('client id too high')
                crash()
            # The socket id is used directly as the client id / array index.
            client_sockets[client_id] = client_id
            client_ids[client_id] = client_id
            seen[client_id] = 1
            # The client that sends last == 1 fixes the number of clients.
            @if_(last == 1)
            def _():
                number_clients.write(client_id + 1)
            # Keep looping while fewer clients have been seen than announced,
            # or no "last" flag has arrived yet (`+` acts as OR on 0/1 values).
            return (sum(seen) < number_clients) + (number_clients == 0)

        # Run one comparison with input type t (sint or sfix).
        def type_run(t):
            # Clients secret input.
            client_values = t.Array(MAX_NUM_CLIENTS)

            # Receive each client's input across n_threads threads.
            # NOTE(review): the second argument is presumably the per-thread
            # parallelism — confirm against the for_range_multithread docs.
            @for_range_multithread(n_threads, 1, number_clients)
            def _(client_id):
                client_values[client_id] = client_input(t, client_id)

            winning_client_id = determine_winner(number_clients, client_values,
                                                 client_ids)

            # Debug aid (disabled): would reveal the winner to the parties.
            # print_ln('Found winner, index: %s.', winning_client_id.reveal())

            write_winner_to_clients(client_sockets, number_clients,
                                    winning_client_id)

        # Each round plays once with integer and once with fixed-point inputs.
        type_run(sint)
        type_run(sfix)

        close_connections(number_clients)

        # Truthy return keeps do_while(game_loop) running indefinitely.
        return True

    # Python print(): output appears when this file is compiled, not at run time.
    if n_rounds > 0:
        print('run %d rounds' % n_rounds)
        for_range(n_rounds)(game_loop)
    else:
        print('run forever')
        do_while(game_loop)

main()