mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-07 20:53:55 -05:00
144 lines
4.8 KiB
Plaintext
144 lines
4.8 KiB
Plaintext
# coding: latin-1
|
|
"""
|
|
Solve Bankers bonus, aka Millionaires problem.
|
|
i.e. deduce the maximum among the clients' private inputs.
|
|
|
|
Demonstrates clients external to the computing parties supplying input and receiving an authenticated result. See bankers-bonus-client.cpp for the client (and setup instructions).
|
|
|
|
Wait for clients (whose ids must be below MAX_NUM_CLIENTS) to join the game and for a client to send the finish flag
|
|
before calculating the maximum.
|
|
|
|
Note each client connects in a single thread and so is potentially blocked.
|
|
|
|
Each round / game will reset, and so this can run indefinitely.
|
|
"""
|
|
|
|
from Compiler.types import sint, regint, Array, MemValue
|
|
from Compiler.library import print_ln, do_while, for_range
|
|
from Compiler.util import if_else
|
|
|
|
# Base port on which the computing parties listen for client connections.
PORTNUM = 14000
# Size of the per-round client tables; client ids must be below this value.
MAX_NUM_CLIENTS = 8
# Number of game rounds to compile; 0 means the rounds repeat forever.
n_rounds = 0
# Number of threads used when receiving the client inputs.
n_threads = 2

# Optional compile-time arguments:
#   program.args[1] -- number of rounds (overrides n_rounds)
#   program.args[2] -- a non-zero integer switches program.active on
if len(program.args) > 1:
    n_rounds = int(program.args[1])
    if len(program.args) > 2:
        program.active = int(program.args[2]) != 0
|
|
|
|
def accept_client():
    """Accept the next client connection on PORTNUM.

    Returns a pair (socket id, flag) where flag is the first regint read
    from the new connection; main() treats a value of 1 as the client's
    "finish" (last client) flag.
    """
    socket_id = accept_client_connection(PORTNUM)
    return socket_id, regint.read_from_socket(socket_id)
|
|
|
|
def close_connections(number_clients):
    """Close the client connections with socket ids 0 .. number_clients - 1.

    The loop is a for_range loop, so number_clients may be a run-time value
    (main passes a MemValue).
    """
    def close_one(socket_id):
        closeclientconnection(socket_id)

    for_range(number_clients)(close_one)
|
|
|
|
def client_input(t, client_socket_id):
    """Receive one secret input of type t from the client on client_socket_id.

    The share handling (sending the client a share of a random value and
    deducing the share of its input) happens inside t.receive_from_client,
    which is asked for a single value; that first value is returned.
    """
    received = t.receive_from_client(1, client_socket_id)
    first = received[0]
    return first
|
|
|
|
|
|
def determine_winner(number_clients, client_values, client_ids):
    """Return the (secret) id of the client holding the largest value.

    number_clients -- number of valid entries at the start of both arrays
    client_values  -- secret array of client inputs (sint or sfix in main)
    client_ids     -- sint array of client ids, parallel to client_values

    A candidate replaces the current maximum only when it is strictly
    larger, so on ties the earlier client wins.
    """
    # One-element arrays serve as mutable cells the run-time loop updates.
    best_value = Array(1, client_values.value_type)
    best_value[0] = client_values[0]
    best_id = Array(1, sint)
    best_id[0] = client_ids[0]

    @for_range(number_clients - 1)
    def _(k):
        idx = k + 1
        candidate = client_values[idx]
        # Secret comparison result: 1 if the candidate is larger, else 0.
        larger = best_value[0] < candidate
        best_value[0] = if_else(larger, candidate, best_value[0])
        best_id[0] = if_else(larger, client_ids[idx], best_id[0])

    # NOTE(review): this reveals the maximum input to the computing parties.
    print_ln('maximum: %s', best_value[0].reveal())
    return best_id[0]
|
|
|
|
|
|
def write_winner_to_clients(sockets, number_clients, winning_client_id):
    """Send the winning client id to all clients who joined the round.

    Only the sub-array of the first number_clients sockets is addressed.
    The secret id is handed to sint.reveal_to_clients, which is responsible
    for letting each client authenticate the result (the client side lives
    in bankers-bonus-client.cpp).
    """
    recipients = sockets.get_sub(number_clients)
    sint.reveal_to_clients(recipients, [winning_client_id])
|
|
|
|
def main():
    """Listen for clients and run rounds of the game.

    Each round waits until some client has sent the finish flag and every
    client id below the resulting client count has connected. It then
    computes the maximum of the clients' inputs (once as sint, once as
    sfix), sends the winning client id back to the clients and closes their
    connections. Runs n_rounds rounds, or forever when n_rounds is 0.
    """
    # Start listening for client socket connections
    listen_for_clients(PORTNUM)
    print_ln('Listening for client connections on base port %s', PORTNUM)

    def game_loop(_=None):
        print_ln('Starting a new round of the game.')

        # Clients socket id (integer).
        client_sockets = Array(MAX_NUM_CLIENTS, regint)
        # Number of clients; set once a client sends the finish flag.
        number_clients = MemValue(regint(0))
        # Client ids to identify clients
        client_ids = Array(MAX_NUM_CLIENTS, sint)
        # Keep track of which client ids have connected
        seen = Array(MAX_NUM_CLIENTS, regint)
        seen.assign_all(0)

        # Loop round waiting for each client to connect
        @do_while
        def client_connections():
            client_id, last = accept_client()

            @if_(client_id >= MAX_NUM_CLIENTS)
            def _():
                print_ln('client id too high')
                crash()

            client_sockets[client_id] = client_id
            client_ids[client_id] = client_id
            seen[client_id] = 1

            # The client sending the finish flag fixes the number of clients.
            @if_(last == 1)
            def _():
                number_clients.write(client_id + 1)

            # Keep looping until the finish flag has been received and all
            # clients 0 .. number_clients - 1 have connected. Only ids below
            # number_clients are counted, so an extra connection with a
            # larger id cannot end the wait while a required client is
            # still missing (its socket would later be read for input).
            n = number_clients.read()
            n_joined = sum(seen[i] * (regint(i) < n)
                           for i in range(MAX_NUM_CLIENTS))
            return (n_joined < n) + (n == 0)

        def type_run(t):
            # Clients secret input.
            client_values = t.Array(MAX_NUM_CLIENTS)

            @for_range_multithread(n_threads, 1, number_clients)
            def _(client_id):
                client_values[client_id] = client_input(t, client_id)

            winning_client_id = determine_winner(number_clients, client_values,
                                                 client_ids)

            # The winner is only revealed to the clients, not printed here.
            write_winner_to_clients(client_sockets, number_clients,
                                    winning_client_id)

        type_run(sint)
        type_run(sfix)

        close_connections(number_clients)

        # A constant true result keeps do_while repeating rounds forever.
        return True

    if n_rounds > 0:
        print('run %d rounds' % n_rounds)
        for_range(n_rounds)(game_loop)
    else:
        print('run forever')
        do_while(game_loop)
|
|
|
|
# Entry point: emit the client listener setup and the game loop(s) built in main().
main()
|