""" Example programs used in the SPDZ tutorial at the TPMPC 2017 workshop in Bristol. """ from util import if_else program.bit_length = 32 def millionnaires(): """ Secure comparison, receiving input from each party via stdin """ print_ln("Waiting for Alice's input") alice = sint.get_input_from(0) print_ln("Waiting for Bob's input") bob = sint.get_input_from(1) b = alice < bob print_ln('The richest is: %s', b.reveal()) def naive_search(n): """ Search secret list for private input from Bob """ # hardcoded "secret" list from Alice - in a real application this should be a private input a = [sint(i) for i in range(n)] print_ln("Waiting for search input from Bob") b = sint.get_input_from(1) eq_bits = [x == b for x in a] b_in_a = sum(eq_bits) print_ln("Is b in Alice's list? %s", b_in_a.reveal()) def scalable_search(n): """ Search using SPDZ loop to avoid loop unrolling """ array = Array(n, sint) @for_range(n) def _(i): array[i] = sint(i) print_ln("Waiting for search input from Bob") b = sint.get_input_from(1) # need to use MemValue and Array inside @for_range loops, # instead of basic sint/cint registers result = MemValue(sint(0)) @for_range(100, n) def _(i): result.write(result + (array[i] == b)) print_ln("Is b in Alice's list? %s", result.reveal()) def compute_intersection(a, b): """ Naive quadratic private set intersection. Returns: secret Array with intersection (padded to len(a)), and secret Array of bits indicating whether Alice's input matches or not """ n = len(a) if n != len(b): raise CompilerError('Inconsistent lengths to compute_intersection') intersection = Array(n, sint) is_match_at = Array(n, sint) @for_range(n) def _(i): @for_range(n) def _(j): match = a[i] == b[j] is_match_at[i] += match intersection[i] = if_else(match, a[i], intersection[i]) # match * a[i] + (1 - match) * intersection[i] return intersection, is_match_at def set_intersection_example(n): """Naive private set intersection on two Arrays, followed by computing the size and average of the intersection""" a = Array(n, sint) b = Array(n, sint) print_ln('Running PSI example') @for_range(n) def _(i): a[i] = i b[i] = i + 60 intersection, is_match_at = compute_intersection(a,b) print_ln('Printing set intersection (0: not in intersection)') size = MemValue(sint(0)) total = MemValue(sint(0)) @for_range(n) def _(i): size.write(size + is_match_at[i]) total.write(total + intersection[i]) print_str('%s ', intersection[i].reveal()) print_ln('\nIntersection size: %s', size.reveal()) total_fixed = sfix() total_fixed.load_int(total.read()) print_ln('Average in intersection: %s', (total_fixed / size.read()).reveal()) millionnaires() naive_search(100) scalable_search(10000) set_intersection_example(100)