Files
MP-SPDZ/Programs/Source/tpmpc_tutorial.mpc
Marcel Keller cc0711c224 MP-SPDZ.
2018-10-11 17:20:26 +11:00

103 lines
2.7 KiB
Plaintext

"""
Example programs used in the SPDZ tutorial at the TPMPC 2017 workshop in Bristol.
"""
from util import if_else
# Set the program-wide bit length to 32. Presumably this bounds the secret
# integers used by the comparisons (`<`, `==`) below — confirm against the
# MP-SPDZ documentation.
program.bit_length = 32
def millionnaires():
    """Secure comparison of Alice's and Bob's private inputs.

    Alice is party 0 and Bob is party 1; each party is prompted in turn and
    its value is read with ``sint.get_input_from``. Only the comparison bit
    ``alice < bob`` is revealed: 1 means Bob's input is strictly larger,
    0 means Alice's input is larger or the two are equal.
    """
    prompts = ("Waiting for Alice's input", "Waiting for Bob's input")
    wealth = []
    # enumerate yields the party numbers 0 (Alice) and 1 (Bob) in order.
    for party, prompt in enumerate(prompts):
        print_ln(prompt)
        wealth.append(sint.get_input_from(party))
    alice, bob = wealth
    print_ln('The richest is: %s', (alice < bob).reveal())
def naive_search(n):
    """Check whether Bob's private input occurs in a list held by Alice.

    The list is the hard-coded secret values 0..n-1, built as a plain Python
    list, so the comparison loop is unrolled at compile time. Every element
    is compared with Bob's input (party 1). The sum of the equality bits is
    revealed, which is nonzero exactly when Bob's value is in the list.
    """
    # hardcoded "secret" list from Alice - in a real application this should be a private input
    alice_list = list(map(sint, range(n)))
    print_ln("Waiting for search input from Bob")
    needle = sint.get_input_from(1)
    # All comparisons first, then one sum of the resulting secret bits.
    hits = sum([candidate == needle for candidate in alice_list])
    print_ln("Is b in Alice's list? %s", hits.reveal())
def scalable_search(n):
    """Search a secret Array for Bob's private input using SPDZ loops.

    ``@for_range`` is used instead of a Python loop to avoid unrolling the
    search ``n`` times, as ``naive_search`` does. Alice's list is the secret
    values 0..n-1 stored in an ``Array``. Every one of the ``n`` entries is
    compared with Bob's input (party 1), and the number of matches is
    revealed.
    """
    array = Array(n, sint)

    @for_range(n)
    def _(i):
        array[i] = sint(i)

    print_ln("Waiting for search input from Bob")
    b = sint.get_input_from(1)

    # need to use MemValue and Array inside @for_range loops,
    # instead of basic sint/cint registers
    result = MemValue(sint(0))

    # Scan the whole array from index 0. A non-zero start index would skip
    # the leading entries and report those values as absent.
    @for_range(n)
    def _(i):
        result.write(result + (array[i] == b))

    print_ln("Is b in Alice's list? %s", result.reveal())
def compute_intersection(a, b):
    """Naive quadratic private set intersection.

    Every element of ``a`` is compared with every element of ``b`` inside
    nested ``@for_range`` loops, giving n * n secret equality tests.

    Returns: secret Array with intersection (padded to len(a)), and a
    secret Array whose entry i accumulates the number of elements of ``b``
    equal to ``a[i]``. This is a 0/1 bit when ``b`` has no duplicates.
    Entries with no match keep the Array's initial value (presumably 0 —
    confirm that MP-SPDZ zero-initializes fresh Arrays).

    Raises CompilerError at compile time if ``a`` and ``b`` differ in length.
    """
    size = len(a)
    if len(b) != size:
        raise CompilerError('Inconsistent lengths to compute_intersection')
    intersection = Array(size, sint)
    is_match_at = Array(size, sint)

    @for_range(size)
    def _outer(i):
        @for_range(size)
        def _inner(j):
            hit = a[i] == b[j]
            is_match_at[i] += hit
            # Oblivious select, equivalent to hit * a[i] + (1 - hit) * intersection[i]
            intersection[i] = if_else(hit, a[i], intersection[i])

    return intersection, is_match_at
def set_intersection_example(n):
    """Run a naive PSI on two demo Arrays, then print the intersection, its size and average.

    The first Array holds 0..n-1 and the second holds 60..n+59, so for
    n > 60 the intersection is 60..n-1. The average is computed in fixed
    point (``sfix``) as total / size.

    NOTE(review): for n <= 60 the intersection is empty and the average
    divides by a secret zero, so that output is not meaningful.
    """
    alice = Array(n, sint)
    bob = Array(n, sint)
    print_ln('Running PSI example')

    @for_range(n)
    def _fill(i):
        alice[i] = i
        bob[i] = i + 60

    intersection, is_match_at = compute_intersection(alice, bob)

    print_ln('Printing set intersection (0: not in intersection)')
    size = MemValue(sint(0))
    total = MemValue(sint(0))

    # MemValue accumulators, because plain registers cannot be updated
    # across @for_range iterations.
    @for_range(n)
    def _accumulate(i):
        size.write(size + is_match_at[i])
        total.write(total + intersection[i])
        print_str('%s ', intersection[i].reveal())

    print_ln('\nIntersection size: %s', size.reveal())

    # Move the integer total into fixed point so the division is fractional.
    total_fixed = sfix()
    total_fixed.load_int(total.read())
    average = total_fixed / size.read()
    print_ln('Average in intersection: %s', average.reveal())
# Run every example in sequence. The order of these calls fixes the order in
# which prompts are printed and private inputs are requested:
# - millionnaires reads from parties 0 and 1;
# - the two searches read from party 1;
# - set_intersection_example uses hard-coded data and reads no input.
millionnaires()
naive_search(100)
scalable_search(10000)
set_intersection_example(100)