Files
MP-SPDZ/Programs/Source/tutorial.mpc
2023-12-14 12:17:54 +11:00

140 lines
3.1 KiB
Plaintext

# sint: secret integers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint
# you can assign public numbers to sint
# a and b are the operands of the integer examples in this tutorial
a = sint(1)
b = sint(2)
def test(actual, expected):
    """Reveal a secret value and print it next to the expected result.

    :param actual: value to check; it has to offer ``reveal()``
    :param expected: reference value shown in the same output line
    """
    # you can reveal a number in order to print it
    revealed = actual.reveal()
    print_ln('expected %s, got %s', expected, revealed)
# private inputs are read from Player-Data/Input-P<i>-0
# or from standard input if using command-line option -I
# see https://mp-spdz.readthedocs.io/en/latest/io.html for more options
# a plain Python loop is used here because the player number
# is a compile-time number
for player in (0, 1):
    received = sint.get_input_from(player).reveal()
    print_ln('got %s from player %s', received, player)
# some arithmetic works as expected
# (a is 1 and b is 2 at this point)
test(a + b, 3)
test(a * b, 2)
test(a - b, -1)
# Division can mean different things in different domains
# and there has to be a specified bit length in some,
# so we use int_div() for integer division.
# k-bit division requires (4k+1)-bit computation.
# Here b is divided by a; the second argument (15) is the bit length.
test(b.int_div(a, 15), 2)
# comparisons produce 1 for true and 0 for false
# (the results are secret as well, test() reveals them for printing)
test(a < b, 1)
test(a <= b, 1)
test(a >= b, 0)
test(a > b, 0)
test(a == b, 0)
test(a != b, 1)
# if_else() can be used instead of branching
# let's find out the larger number
# a < b is 1 here, so the first argument (b) is selected
test((a < b).if_else(b, a), 2)
# arrays and loops work as follows
# a is rebound from the secret integer to an array of 100 secret integers
a = Array(100, sint)


@for_range(100)
def f(i):
    # element i receives the product i * (i - 1) of two secret integers
    current = sint(i)
    previous = sint(i - 1)
    a[i] = current * previous


# check the element written for i = 99
test(a[99], 99 * 98)
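# if you use loops, use Array to store results
# don't do this (the assignment inside f only binds a name local to f,
# so the code after the loop never sees the value):
# @for_range(100)
# def f(i):
#     a = sint(i)
# test(a, 99)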
# sfix: fixed-point numbers
# see also https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix
# set the precision after the dot and in total
# (here: 16 after the dot, 31 in total)
sfix.set_precision(16, 31)
# and the output precision in decimal digits
# (here: 4 decimal digits)
print_float_precision(4)
# you can do all basic arithmetic with sfix, including division
# a and b are rebound to fixed-point values for the following examples
a = sfix(2)
b = sfix(-0.1)
test(a + b, 1.9)
test(a - b, 2.1)
test(a * b, -0.2)
test(a / b, -20)
# comparisons again produce 1 for true and 0 for false
test(a < b, 0)
test(a <= b, 0)
test(a >= b, 1)
test(a > b, 1)
test(a == b, 0)
test(a != b, 1)
# a < b is 0 here, so if_else() selects its second argument (b)
test((a < b).if_else(a, b), -0.1)
# now let's do a computation with private inputs
# party 0 supplies three weights and party 1 supplies three values
# we want to compute the weighted mean
print_ln('Party 0: please input three numbers not adding up to zero')
print_ln('Party 1: please input any three numbers')
# one row per data point: column 0 comes from party 0, column 1 from party 1
data = Matrix(3, 2, sfix)


# use @for_range_opt for balanced optimization
# but use Python loops if compile-time numbers are needed (e.g., for players)
@for_range_opt(3)
def _(row):
    for player in range(2):
        data[row][player] = sfix.get_input_from(player)
# compute weighted average
# column 0 of data holds the weights, column 1 the values to be averaged
weight_total = sum(point[0] for point in data)
# NOTE: the quotient is computed before the weights are checked;
# it is only revealed in the branch where the total is known to be non-zero
result = sum(point[0] * point[1] for point in data) / weight_total


# branching is supported also depending on revealed secret data
# with garbled circuits this triggers an interruption of the garbling
# weight_total is reused here instead of summing the weights a second time
@if_e((weight_total != 0).reveal())
def _():
    print_ln('weighted average: %s', result.reveal())


@else_
def _():
    print_ln('your inputs made no sense')
# permutation matrix
# zeros on the diagonal and ones off the diagonal:
# multiplying with it from the right swaps the two columns
M = Matrix(2, 2, sfix)
M[0][0] = 0
M[1][0] = 1
M[0][1] = 1
M[1][1] = 0
# matrix multiplication
# M is rebound from the permutation matrix to the product
M = data * M
# each entry of the product has to match the entry of data
# in the same row but the other column
test(M[0][0], data[0][1].reveal())
test(M[1][1], data[1][0].reveal())