mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
134 lines
2.7 KiB
Plaintext
134 lines
2.7 KiB
Plaintext
# sint: secret integers

# public numbers can be assigned to sint
a, b = sint(1), sint(2)
|
|
|
|
def test(actual, expected):
    """Reveal ``actual`` and print it next to the ``expected`` value.

    No comparison is performed here; the printed line has to be checked
    by the reader.
    """
    # a number has to be revealed before it can be printed
    revealed = actual.reveal()
    print_ln('expected %s, got %s', expected, revealed)
|
|
|
|
# addition, multiplication and subtraction behave as expected
test(a + b, 3)
test(a * b, 2)
test(a - b, -1)

# division does not, so don't do the following
# test(b / a, 2)

# a comparison yields 1 for true and 0 for false
test(a < b, 1)
test(a <= b, 1)
test(a >= b, 0)
test(a > b, 0)
test(a == b, 0)
test(a != b, 1)

# if_else() can be used instead of branching;
# here it selects the larger of the two numbers
test((a < b).if_else(b, a), 2)
|
|
|
|
# arrays and loops work as follows

# note: this rebinds ``a`` from a single sint to an array of 100 sint
a = Array(100, sint)

# the decorated function is the loop body and receives the loop index
@for_range(100)
def f(i):
    a[i] = sint(i) * sint(i - 1)

# the last entry was written with i = 99
test(a[99], 99 * 98)
|
|
|
|
# if you use loops, use Array to store results
# don't do this:
# @for_range(100)
# def f(i):
#     a = sint(i)
# test(a, 99)
|
|
|
|
# sfix: fixed-point numbers

# set the precision after the dot (16) and in total (32)
sfix.set_precision(16, 32)

# and the output precision in decimal digits
print_float_precision(4)

# all basic arithmetic works with sfix, including division
a, b = sfix(2), sfix(-0.1)

test(a + b, 1.9)
test(a - b, 2.1)
test(a * b, -0.2)
test(a / b, -20)

# comparisons again yield 1 for true and 0 for false
test(a < b, 0)
test(a <= b, 0)
test(a >= b, 1)
test(a > b, 1)
test(a == b, 0)
test(a != b, 1)

# a < b is false here, so if_else() returns its second argument
test((a < b).if_else(a, b), -0.1)
|
|
|
|
# now let's do a computation with private inputs:
# party 0 supplies three weights and party 1 supplies three values,
# and we want to compute the weighted mean of the values

print_ln('Party 0: please input three numbers not adding up to zero')
print_ln('Party 1: please input any three numbers')

# row i holds (input of party 0, input of party 1)
data = Matrix(3, 2, sfix)

# use Python loops for compile-time optimization
for i in range(3):
    for j in range(2):
        # column j is read from party j
        data[i][j] = sfix.from_sint(sint.get_input_from(j))

# compute weighted average: column 0 holds the weights
weight_total = sum(point[0] for point in data)
result = sum(point[0] * point[1] for point in data) / weight_total

# the following only works with arithmetic circuits

# @if_e((weight_total != 0).reveal())
# def _():
#     print_ln('weighted average: %s', result.reveal())
# @else_
# def _():
#     print_ln('your inputs made no sense')

# so we output even an invalid result (the weights adding up to zero)
print_ln('weighted average: %s', result.reveal())

# but we warn the user
# note that we don't reveal the weight sum, only the comparison;
# weight_total is reused instead of summing the weights a second time
print_ln_if((weight_total == 0).reveal(),
            'but the inputs were invalid (weights add up to zero)')
|
|
|
|
# permutation matrix (ones on the anti-diagonal)
M = Matrix(2, 2, sfix)
M[0][0] = 0
M[0][1] = 1
M[1][0] = 1
M[1][1] = 0

# matrix multiplication; M is rebound to the product data * M
M = data * M

# the product has the two columns of data swapped
test(M[0][0], data[0][1].reveal())
test(M[1][1], data[1][0].reveal())
|