# sint: secret integers

# you can assign public numbers to sint
a = sint(1)
b = sint(2)


def test(actual, expected):
    """Reveal the secret value `actual` and print it next to `expected`.

    Nothing is asserted; the two values are only printed side by side
    so they can be compared by eye in the program output.
    """
    # a secret value has to be revealed before it can be printed
    print_ln('expected %s, got %s', expected, actual.reveal())


# some arithmetic works as expected
test(a + b, 3)
test(a * b, 2)
test(a - b, -1)

# but division doesn't, don't do the following
# test(b / a, 2)

# comparisons produce 1 for true and 0 for false
test(a < b, 1)
test(a <= b, 1)
test(a >= b, 0)
test(a > b, 0)
test(a == b, 0)
test(a != b, 1)

# if_else() can be used instead of branching
# let's find out the larger number
test((a < b).if_else(b, a), 2)

# arrays and loops work as follows
a = Array(100, sint)


@for_range(100)
def f(i):
    a[i] = sint(i) * sint(i - 1)


test(a[99], 99 * 98)

# if you use loops, use Array to store results
# don't do this
# @for_range(100)
# def f(i):
#     a = sint(i)
# test(a, 99)

# sfix: fixed-point numbers

# set the precision after the dot and in total
sfix.set_precision(16, 32)

# you can do all basic arithmetic with sfix, including division
a = sfix(2)
b = sfix(-0.1)

test(a + b, 1.9)
test(a - b, 2.1)
test(a * b, -0.2)
test(a / b, -20)
test(a < b, 0)
test(a <= b, 0)
test(a >= b, 1)
test(a > b, 1)
test(a == b, 0)
test(a != b, 1)

test((a < b).if_else(a, b), -0.1)

# now let's do a computation with private inputs
# party 0 supplies three numbers, which are used as the weights
# party 1 supplies three numbers, which are the values being averaged
# we want to compute the weighted mean

print_ln('Party 0: please input three numbers not adding up to zero')
print_ln('Party 1: please input any three numbers')

# row i holds one data point: column 0 is party 0's input (the weight),
# column 1 is party 1's input (the value)
data = Matrix(3, 2, sfix)

# use Python loops for compile-time optimization
# every input is read as a secret integer and then converted to sfix
for i in range(3):
    for j in range(2):
        data[i][j] = sfix.from_sint(sint.get_input_from(j))

# compute weighted average
weight_total = sum(point[0] for point in data)
result = sum(point[0] * point[1] for point in data) / weight_total

# the following only works with arithmetic circuits
# @if_e((weight_total != 0).reveal())
# def _():
#     print_ln('weighted average: %s', result.reveal())
# @else_
# def _():
#     print_ln('your inputs made no sense')

# so we output even an invalid result (the weights adding up to zero)
print_ln('weighted average: %s', result.reveal())

# but we warn the user
# note that we don't reveal the weight sum, only the comparison
# the sum computed above is reused so that the value checked here is
# exactly the divisor used for the result
print_ln_if((weight_total == 0).reveal(),
            'but the inputs were invalid (weights add up to zero)')