Files
MP-SPDZ/Programs/Source/vickrey.mpc
2021-07-02 15:50:34 +10:00

76 lines
2.4 KiB
Plaintext

import util
from Compiler import types
import math
import re
# The number of bids is taken from the digits that directly follow the
# leading non-digit part of the program name; without such digits the
# default of 100 is used.
# Raw string: \D and \d are regex character classes, not string escapes.
r = re.search(r'(\D*)(\d*)', program.name)
if r.group(2):
    n_inputs = int(r.group(2))
else:
    n_inputs = 100

n_parties = 2
# One thread below 256 inputs, then twice as many each time the input
# count doubles. For fewer than 128 inputs the power of two is a fraction
# (or 1), which ceil() rounds up to a single thread.
# math.log2 is used instead of math.log(x, 2) because the latter is a
# floating-point quotient that is not guaranteed to be exact for powers
# of two, and the result is truncated by int() right away.
n_threads = int(math.ceil(2 ** (int(math.log2(n_inputs) - 7))))
n_loops = 1
n_bits = 64
#value_type = types.get_sgf2nuint(n_bits)
value_type = sint

program.set_bit_length(n_bits)
program.set_security(40)

print_ln('n_inputs = %s, n_parties = %s, n_threads = %s, n_loops = %s, '
         'value_type = %s',
         n_inputs, n_parties, n_threads, n_loops, value_type.__name__)
@for_range(n_loops)
def f(_):
    """Run one second-price auction over ``n_inputs`` bids.

    The bids are split evenly over ``n_threads`` tapes. Every tape reduces
    its slice to a (highest, second highest) pair, the pairs are reduced
    once more in the main thread, and finally the party of the highest bid
    and the price of the second-highest bid are revealed.

    Raises:
        Exception: if ``n_inputs`` is not a multiple of ``n_threads`` or
            the per-thread share is odd (bids are processed in pairs).
    """
    Bid = types.getNamedTupleType('party', 'price')
    bids = Bid.get_array(n_inputs, value_type)

    for i in range(n_inputs):
        # i * 10 because inputs are all zero by default
        bids[i] = Bid(i, value_type.get_input_from(i % n_parties) + i * 10)
    #bids = [Bid(i, value_type(i * 10)) for i in range(n_parties)]

    def bid_sort(a, b):
        """Conditionally swap two bids depending on ``a.price < b.price``.

        NOTE(review): first_and_second() treats element 0 of every pair as
        the higher bid, so util.cond_swap presumably returns (b, a) when
        the condition holds -- confirm against util.
        """
        comp = a.price < b.price
        res = util.cond_swap(comp, a, b)
        for bid in res:
            # NOTE(review): hard_conv presumably re-casts the price to
            # value_type after the arithmetic swap -- confirm.
            bid.price = value_type.hard_conv(bid.price)
        return res

    def first_and_second(left, right):
        """Merge two (higher, lower) pairs into the overall top two bids."""
        # Is the best bid on the right strictly better than the one on
        # the left? Ties keep the left bid.
        top = left[0].price < right[0].price
        # cross[i] compares left[i] with the opposite-rank bid right[1-i].
        cross = [left[i].price < right[1-i].price for i in range(2)]
        first = top.if_else(right[0], left[0])
        # tmp[0]: runner-up if right[0] wins (better of left[0], right[1]).
        # tmp[1]: runner-up if left[0] wins (better of left[1], right[0]).
        tmp = [cross[i].if_else(right[1-i], left[i]) for i in (0, 1)]
        second = top.if_else(*tmp)
        for bid in (first, second):
            bid.price = value_type.hard_conv(bid.price)
        return first, second

    # Two slots (highest, second highest) per thread.
    results = Bid.get_array(2 * n_threads, value_type)

    def thread():
        """Reduce this thread's slice of the bids to its top two."""
        i = get_arg()
        n_per_thread = n_inputs // n_threads
        # Without this check the trailing n_inputs % n_threads bids would
        # never be looked at by any thread and silently drop out of the
        # auction.
        if n_inputs % n_threads != 0:
            raise Exception('Number of inputs must be divisible by the '
                            'number of threads (%d)' % n_threads)
        if n_per_thread % 2 != 0:
            raise Exception('Number of inputs must be divisible by 2')
        start = i * n_per_thread
        # Order neighbouring bids pairwise, then merge the pairs.
        tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1])
                  for j in range(n_per_thread // 2)]
        first, second = util.tree_reduce(first_and_second, tuples)
        results[2*i], results[2*i+1] = first, second

    tape = program.new_tape(thread)
    threads = [program.run_tape(tape, i) for i in range(n_threads)]
    for t in threads:
        program.join_tape(t)

    # Merge the per-thread winners into the global top two.
    tuples = [(results[2*i], results[2*i+1]) for i in range(n_threads)]
    first, second = util.tree_reduce(first_and_second, tuples)
    print_ln('Winner: %s, price: %s', first.party.reveal(), second.price.reveal())