mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 22:17:57 -05:00
103 lines
2.7 KiB
Plaintext
103 lines
2.7 KiB
Plaintext
"""
|
|
Example programs used in the SPDZ tutorial at the TPMPC 2017 workshop in Bristol.
|
|
"""
|
|
|
|
from util import if_else
|
|
|
|
# NOTE(review): presumably the bit length the compiler assumes for secret
# integers in comparisons (<, ==) used below -- confirm against MP-SPDZ docs.
program.bit_length = 32
|
|
|
|
def millionnaires():
    """ Secure comparison, receiving input from each party via stdin

    Party 0 (Alice) and party 1 (Bob) each supply one secret integer.  The
    revealed result is 1 when Alice's value is strictly smaller than Bob's and
    0 otherwise, i.e. the party number of the richer party (a tie is reported
    as Alice, party 0).
    """
    print_ln("Waiting for Alice's input")
    alice_wealth = sint.get_input_from(0)

    print_ln("Waiting for Bob's input")
    bob_wealth = sint.get_input_from(1)

    # Secret comparison bit; only the final bit is ever revealed.
    bob_is_richer = alice_wealth < bob_wealth
    print_ln('The richest is: %s', bob_is_richer.reveal())
|
|
|
|
def naive_search(n):
    """ Search secret list for private input from Bob

    Uses an ordinary Python comprehension, so the n secret comparisons are
    emitted one by one at compile time (compare scalable_search, which uses a
    runtime loop instead).

    :param n: length of the hard-coded secret list 0, 1, ..., n - 1
    """
    # Stand-in for Alice's list: the secret values 0 .. n - 1.  A real
    # application would read these as private input from Alice instead.
    haystack = [sint(value) for value in range(n)]

    print_ln("Waiting for search input from Bob")
    needle = sint.get_input_from(1)

    # Each comparison yields a secret 0/1 bit; their sum is the number of
    # matching entries, which is the only value revealed.
    match_bits = [entry == needle for entry in haystack]
    match_count = sum(match_bits)
    print_ln("Is b in Alice's list? %s", match_count.reveal())
|
|
|
|
def scalable_search(n):
    """ Search using SPDZ loop to avoid loop unrolling

    Same computation as naive_search, but the list is kept in a secret Array
    and both loops are @for_range runtime loops, so the loop bodies are not
    unrolled n times at compile time.

    :param n: length of the hard-coded secret list 0, 1, ..., n - 1
    """
    array = Array(n, sint)

    @for_range(n)
    def _(i):
        array[i] = sint(i)

    print_ln("Waiting for search input from Bob")
    b = sint.get_input_from(1)

    # need to use MemValue and Array inside @for_range loops,
    # instead of basic sint/cint registers
    result = MemValue(sint(0))

    # Compare Bob's value against EVERY entry, starting at index 0; a start
    # offset here would silently skip the leading entries of the list.
    @for_range(n)
    def _(i):
        result.write(result + (array[i] == b))

    print_ln("Is b in Alice's list? %s", result.reveal())
|
|
|
|
def compute_intersection(a, b):
    """ Naive quadratic private set intersection.

    Every element of ``a`` is compared with every element of ``b`` (n * n
    secret equality tests) inside nested @for_range runtime loops.

    :param a: secret Array holding Alice's set
    :param b: secret Array holding Bob's set; must have the same length as a
    :returns: ``(intersection, is_match_at)``.  ``intersection`` is a secret
        Array (padded to len(a)) holding a[i] at every position i whose
        element also occurs in b.  ``is_match_at`` is a secret Array whose
        entry i counts the matches of a[i] in b, which is a bit when b has no
        duplicates.
    :raises CompilerError: if a and b have different lengths
    """
    size = len(a)
    if len(b) != size:
        raise CompilerError('Inconsistent lengths to compute_intersection')
    common = Array(size, sint)
    hits = Array(size, sint)

    # NOTE(review): both Arrays are read before they are first written below,
    # so this relies on freshly allocated Arrays starting out as zero --
    # confirm.
    @for_range(size)
    def _(row):
        @for_range(size)
        def _(col):
            same = a[row] == b[col]
            hits[row] += same
            # Oblivious select: same * a[row] + (1 - same) * common[row]
            common[row] = if_else(same, a[row], common[row])

    return common, hits
|
|
|
|
def set_intersection_example(n):
    """Naive private set intersection on two Arrays, followed by computing the size and average of the intersection

    Alice's set is 0 .. n-1 and Bob's set is 60 .. n+59, so for n > 60 the
    intersection is 60 .. n-1.  For n <= 60 the intersection is empty and
    the average divides by a zero size, so that output is meaningless.
    """
    alice = Array(n, sint)
    bob = Array(n, sint)
    print_ln('Running PSI example')

    @for_range(n)
    def _(k):
        alice[k] = k
        bob[k] = k + 60

    common, hits = compute_intersection(alice, bob)

    print_ln('Printing set intersection (0: not in intersection)')
    # Running secret totals are kept in MemValues so the runtime loop below
    # can update them.
    count = MemValue(sint(0))
    element_sum = MemValue(sint(0))

    @for_range(n)
    def _(k):
        count.write(count + hits[k])
        element_sum.write(element_sum + common[k])
        print_str('%s ', common[k].reveal())

    print_ln('\nIntersection size: %s', count.reveal())

    # Load the integer sum into an sfix so the division is (presumably) done
    # in secret fixed-point rather than integer arithmetic.
    sum_fixed = sfix()
    sum_fixed.load_int(element_sum.read())
    mean = sum_fixed / count.read()
    print_ln('Average in intersection: %s', mean.reveal())
|
|
|
|
|
|
|
|
# Run every example in sequence.  Private inputs are consumed in this order:
# one value each from Alice and Bob for millionnaires(), then one value from
# Bob for each of the two searches.  The PSI example uses hard-coded sets.
millionnaires()
naive_search(100)
scalable_search(10000)
set_intersection_example(100)
|
|
|