# Mirror of https://github.com/data61/MP-SPDZ.git
# (synced 2026-01-10 22:17:57 -05:00; 76 lines, 2.4 KiB, plaintext)
import util
from Compiler import types

import math
import re

# The number of bids is taken from the program name: the regex skips the
# leading non-digit characters and captures the run of digits that follows.
# If there are no such digits, 100 bids are used.
# NOTE(review): presumably `program.name` is supplied by the compiler and a
# numeric suffix in the name is how the size is chosen -- confirm.
r = re.search(r'(\D*)(\d*)', program.name)

if r.group(2):
    n_inputs = int(r.group(2))
else:
    n_inputs = 100

# Bid number i is requested from party i % n_parties.
n_parties = 2
# Largest power of two not exceeding n_inputs / 128, but never less than one:
# the exponent is truncated towards zero and the result is rounded up, so any
# n_inputs below 256 yields a single thread.
n_threads = int(math.ceil(2 ** (int(math.log(n_inputs, 2) - 7))))
# Number of times the whole computation is repeated.
n_loops = 1
n_bits = 64
# Alternative value type, currently disabled:
#value_type = types.get_sgf2nuint(n_bits)
value_type = sint

program.set_bit_length(n_bits)
# NOTE(review): 40 is presumably the statistical security parameter in bits
# -- confirm against the compiler documentation.
program.set_security(40)

# Echo the effective parameters at run time.
print_ln('n_inputs = %s, n_parties = %s, n_threads = %s, n_loops = %s, '
         'value_type = %s',
         n_inputs, n_parties, n_threads, n_loops, value_type.__name__)


@for_range(n_loops)
def f(_):
    """Read all bids, then reveal the top bidder and the second-best price.

    The bids are split into ``n_threads`` contiguous chunks.  Each chunk is
    reduced to its two highest bids by one run of the ``thread`` tape, and
    the per-chunk results are reduced once more afterwards.  The loop index
    passed in by ``for_range`` is not used.

    NOTE(review): the comments below assume that ``c.if_else(x, y)`` selects
    ``x`` when ``c`` is set and that ``util.cond_swap(c, x, y)`` returns the
    pair swapped when ``c`` is set -- confirm against the library.
    """
    Bid = types.getNamedTupleType('party', 'price')
    bids = Bid.get_array(n_inputs, value_type)

    for i in range(n_inputs):
        # i * 10 because inputs are all zero by default
        bids[i] = Bid(i, value_type.get_input_from(i % n_parties) + i * 10)
    #bids = [Bid(i, value_type(i * 10)) for i in range(n_parties)]

    def bid_sort(a, b):
        """Return the two bids as a pair, swapped depending on
        ``a.price < b.price``, with prices passed through ``hard_conv``."""
        comp = a.price < b.price
        res = util.cond_swap(comp, a, b)
        for bid in res:
            bid.price = value_type.hard_conv(bid.price)
        return res

    def first_and_second(left, right):
        """Merge two (first, second) pairs into the overall (first, second).

        Both arguments are expected to hold their higher bid at index 0.
        """
        # Does the right pair hold the overall best bid?
        top = left[0].price < right[0].price
        # cross[i]: is left[i] beaten by the bid at the other index of right?
        cross = [left[i].price < right[1-i].price for i in range(2)]
        first = top.if_else(right[0], left[0])
        # Runner-up candidates: tmp[0] applies when right[0] came first
        # (better of left[0] and right[1]), tmp[1] when left[0] came first
        # (better of left[1] and right[0]).
        tmp = [cross[i].if_else(right[1-i], left[i]) for i in (0, 1)]
        second = top.if_else(*tmp)
        for bid in (first, second):
            bid.price = value_type.hard_conv(bid.price)
        return first, second

    # Two slots per thread: the best and second-best bid of its chunk.
    results = Bid.get_array(2 * n_threads, value_type)

    def thread():
        """Reduce one chunk of bids to its two highest and store them.

        The chunk index is obtained from ``get_arg()``.

        Raises:
            Exception: if the number of bids per thread is odd, because the
                bids are paired up before the reduction.
        """
        i = get_arg()
        # NOTE: when n_inputs is not a multiple of n_threads, the bids past
        # n_threads * n_per_thread are not covered by any chunk.
        n_per_thread = n_inputs // n_threads
        if n_per_thread % 2 != 0:
            raise Exception(
                'Number of inputs per thread must be divisible by 2')
        start = i * n_per_thread
        tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1])
                  for j in range(n_per_thread // 2)]
        first, second = util.tree_reduce(first_and_second, tuples)
        results[2*i], results[2*i+1] = first, second

    # Run the same tape once per chunk and wait for all runs to finish.
    tape = program.new_tape(thread)
    threads = [program.run_tape(tape, i) for i in range(n_threads)]
    for handle in threads:
        program.join_tape(handle)

    # Combine the per-chunk winners into the overall result.
    tuples = [(results[2*i], results[2*i+1]) for i in range(n_threads)]
    first, second = util.tree_reduce(first_and_second, tuples)

    # Only the winning party index and the second-best price are revealed.
    print_ln('Winner: %s, price: %s', first.party.reveal(), second.price.reveal())