mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
496 lines
15 KiB
Python
496 lines
15 KiB
Python
from random import randint
|
|
import math
|
|
|
|
#import sys
|
|
#from Test.core import *
|
|
if '_Array' not in dir():
|
|
from Compiler.types import *
|
|
from Compiler.types import _secret
|
|
from Compiler.library import *
|
|
from Compiler.program import Program
|
|
_Array = Array
|
|
|
|
SORT_BITS = []
|
|
|
|
def predefined_comparator(x, y):
    """Return the next precomputed comparison bit from SORT_BITS.

    The operands ``x`` and ``y`` are ignored.  The first call creates an
    iterator over the module-level SORT_BITS list and caches it on the
    function object; every call then consumes one bit from it.  SORT_BITS
    must therefore already hold the sorting-network bits, in the order in
    which the comparisons are performed.
    """
    bits = predefined_comparator.sort_bits_iter
    if bits is None:
        # Lazily bind the iterator so SORT_BITS can be filled after import.
        bits = predefined_comparator.sort_bits_iter = iter(SORT_BITS)
    return next(bits)


# Iterator state shared across calls; None means "not started yet".
predefined_comparator.sort_bits_iter = None
|
|
|
|
def list_comparator(x, y):
    """Compare two lists by their first elements.

    Returns the result of ``x[0] < y[0]``; the remaining elements are not
    looked at.
    """
    x_key = x[0]
    y_key = y[0]
    return x_key < y_key
|
|
|
|
def normal_comparator(x, y):
    """Return the result of ``x < y`` using the operands' own ``<``."""
    is_less = x < y
    return is_less
|
|
|
|
def bitwise_list_comparator(x, y):
    """Compare two lists by their first elements, treated as bits.

    Computes ``(1 - x[0]) * y[0]``, which for 0/1 inputs is 1 exactly when
    ``x[0]`` is 0 and ``y[0]`` is 1.
    """
    x_bit = x[0]
    y_bit = y[0]
    return (1 - x_bit) * y_bit
|
|
|
|
def bitwise_comparator(x, y):
    """Arithmetic "less than" for bits.

    Computes ``(1 - x) * y``, which for 0/1 inputs is 1 exactly when ``x``
    is 0 and ``y`` is 1.
    """
    return (1 - x) * y
|
|
|
|
def cond_swap_bit(x, y, b):
    """Conditionally swap ``x`` and ``y``: swap if b == 1.

    The swap is done arithmetically: with ``t = (x - y) * b`` the result is
    ``(x - t, y + t)``, i.e. ``(x, y)`` for ``b == 0`` and ``(y, x)`` for
    ``b == 1``.  If ``x`` is a list, the same bit is applied element-wise to
    the pairs produced by ``zip(x, y)``.

    ``None`` marks padding: if either operand is ``None`` the other one is
    returned first and ``None`` second, regardless of ``b``.
    """
    if x is None:
        return y, None
    if y is None:
        return x, None

    if not isinstance(x, list):
        delta = (x - y) * b
        return x - delta, y + delta

    # Element-wise variant: compute all deltas first, then apply them.
    deltas = [(xi - yi) * b for xi, yi in zip(x, y)]
    first = [xi - di for xi, di in zip(x, deltas)]
    second = [yi + di for yi, di in zip(y, deltas)]
    return first, second
|
|
|
|
def cond_swap(x, y, comp):
    """Order the pair ``(x, y)`` according to ``comp``.

    ``comp(x, y)`` is evaluated and the pair is swapped when it yields 0
    (the swap bit passed to cond_swap_bit is ``1 - comp(x, y)``).

    ``None`` marks padding and is never handed to ``comp``: if either
    operand is ``None`` the other one is returned first, ``None`` second.
    """
    if x is None:
        return y, None
    if y is None:
        return x, None
    in_order = comp(x, y)
    swap_bit = 1 - in_order
    return cond_swap_bit(x, y, swap_bit)
|
|
|
|
def odd_even_merge(a, comp):
    """Merge, in place, a list whose two halves are each already sorted.

    The even-indexed and odd-indexed sub-sequences are merged recursively
    and then combined with one layer of compare-and-swap operations
    (``cond_swap`` with comparator ``comp``).

    ``a`` must be a list whose length is a power of two; otherwise an
    ``Exception`` is raised.  Lists of length 0 or 1 are left untouched.
    """
    size = len(a)
    if size & (size - 1) != 0:
        raise Exception('Length must be a power of 2')
    if size <= 1:
        # Also covers the empty list, which passes the power-of-two test
        # above and would otherwise recurse forever on empty slices.
        return
    if size == 2:
        a[0], a[1] = cond_swap(a[0], a[1], comp)
        return

    even = a[::2]
    odd = a[1::2]
    odd_even_merge(even, comp)
    odd_even_merge(odd, comp)

    # First and last elements are final; interior neighbours get one
    # compare-and-swap each.
    a[0] = even[0]
    for i in range(1, size // 2):
        a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp)
    a[-1] = odd[-1]
|
|
|
|
def odd_even_merge_sort(a, comp=bitwise_comparator):
    """Sort the list ``a`` in place with an odd-even merge sorting network.

    Both halves are sorted recursively and then combined with
    ``odd_even_merge``.  ``comp`` is the comparator handed to the
    compare-and-swap operations.

    Raises ``CompilerError`` if an odd length greater than one is
    encountered at any recursion level, i.e. if ``len(a)`` is not a power
    of two.  Lists of length 0 or 1 are left untouched.
    """
    size = len(a)
    if size <= 1:
        # The empty list is even-length and would otherwise split into two
        # empty halves forever.
        return
    if size % 2 != 0:
        raise CompilerError('Length of list must be power of two')

    half = size // 2
    lower = a[:half]
    upper = a[half:]
    odd_even_merge_sort(lower, comp)
    odd_even_merge_sort(upper, comp)
    # Write the sorted halves back before merging in place.
    a[:] = lower + upper
    odd_even_merge(a, comp)
|
|
|
|
def merge(a, b, comp):
    """ General length merge (pads to power of 2)

    Merges the two sorted lists ``a`` and ``b`` with ``odd_even_merge`` and
    returns the merged list.  ``None`` is used as padding: each list is
    first padded to a power-of-two length, then the shorter one is padded
    to the length of the longer one, so that the concatenation handed to
    ``odd_even_merge`` has power-of-two length.  The padding is stripped
    from the result.

    Note: as before, ``a`` and ``b`` are padded in place, so the caller's
    lists may come back with trailing ``None`` entries.
    """
    while len(a) & (len(a)-1) != 0:
        a.append(None)
    while len(b) & (len(b)-1) != 0:
        b.append(None)
    if len(a) < len(b):
        a += [None] * (len(b) - len(a))
    elif len(b) < len(a):
        # Pad b up to len(a); the previous expression (len(b) - len(b))
        # was always zero, leaving the combined length not a power of two.
        b += [None] * (len(a) - len(b))
    t = a + b
    if t:
        odd_even_merge(t, comp)
    # Filter by identity rather than list.remove(None), which compares
    # elements with == and therefore invokes the elements' own __eq__.
    return [v for v in t if v is not None]
|
|
|
|
def sort(a, comp):
    """Sort ``a`` in place: pad to a power of 2, sort, remove the padding.

    ``None`` entries are appended until the length is a power of two, the
    list is sorted with ``odd_even_merge_sort`` using ``comp``, and the
    list is then truncated back to its original length.
    """
    original_len = len(a)
    # Smallest power of two >= original_len (0 stays 0).
    if original_len:
        padded_len = 1 << (original_len - 1).bit_length()
    else:
        padded_len = 0
    for _ in range(padded_len - original_len):
        a.append(None)
    odd_even_merge_sort(a, comp)
    del a[original_len:]
|
|
|
|
def recursive_merge(a, comp):
    """Merge a list of sorted lists (initially sorted by size) into one.

    Works in place on ``a``: the two first (smallest) lists are merged with
    ``merge``, removed, and the result is re-inserted before the first
    remaining list that is at least as long (or appended at the end).
    This repeats until a single list remains in ``a``.
    """
    while len(a) != 1:
        merged = merge(a[0], a[1], comp)
        del a[:2]
        # Keep ``a`` ordered by length so the smallest lists merge first.
        for position, candidate in enumerate(a):
            if len(candidate) >= len(merged):
                a.insert(position, merged)
                break
        else:
            a.append(merged)
|
|
|
|
def random_perm(n):
    """ Generate a random permutation of length n

    WARNING: randomness fixed at compile-time, this is NOT secure

    Returns a list containing 0..n-1 in random order, produced by swapping
    each position (from the last down to index 1) with a ``randint``-chosen
    position at or below it.  Raises ``CompilerError`` unless the program
    was started with the insecure option set.
    """
    if not Program.prog.options.insecure:
        raise CompilerError('no secure implementation of Waksman permution, '
                            'use --insecure to activate')
    perm = list(range(n))
    for hi in reversed(range(1, n)):
        lo = randint(0, hi)
        perm[hi], perm[lo] = perm[lo], perm[hi]
    return perm
|
|
|
|
def inverse(perm):
    """Return the inverse of the permutation ``perm``.

    ``perm`` is a sequence of indices; the result ``r`` satisfies
    ``r[perm[i]] == i`` for every position ``i``.
    """
    result = [None] * len(perm)
    for source, target in enumerate(perm):
        result[target] = source
    return result
|
|
|
|
def configure_waksman(perm):
    """Compute the switch settings of a Waksman network for ``perm``.

    ``perm`` is a list of indices of length n.  The result is a list of
    rows of length n: the first row is the input-switch bits followed by
    the output-switch bits of the outermost layer, and each further row is
    the concatenation of the corresponding rows of the two half-size
    sub-networks, computed recursively.  For n == 2 the single row is the
    tuple ``(perm[0], perm[0])``.

    NOTE(review): only n == 2 is handled as a base case, so the input
    length is presumably expected to be a power of two >= 2.
    """
    n = len(perm)
    if n == 2:
        return [(perm[0], perm[0])]

    half = n // 2
    I = [None] * half    # input switch bits
    O = [None] * half    # output switch bits
    p0 = [None] * half   # permutation routed through the first sub-network
    p1 = [None] * half   # permutation routed through the second sub-network

    inv_perm = [0] * n
    for source, target in enumerate(perm):
        inv_perm[target] = source

    # Walk the cycles formed by alternating between the partner output of a
    # switch and the partner input of a switch until every switch is set.
    while None in O:
        j = 2 * O.index(None)
        O[j // 2] = 0
        j0 = j
        while True:
            i = inv_perm[j]
            p0[i // 2] = j // 2
            I[i // 2] = i % 2
            O[j // 2] = j % 2

            i ^= 1              # other input of the same input switch
            j = perm[i]
            p1[i // 2] = j // 2
            j ^= 1              # other output of the same output switch

            if j == j0:
                break
        if None not in p0 and None not in p1:
            break

    assert sorted(p0) == list(range(half))
    assert sorted(p1) == list(range(half))
    p0_config = configure_waksman(p0)
    p1_config = configure_waksman(p1)
    return [I + O] + [a + b for a, b in zip(p0_config, p1_config)]
|
|
|
|
def waksman(a, config, depth=0, start=0, reverse=False):
    """Apply a Waksman network to the list ``a`` in place (recursive form).

    config is a list of log_2(n) configuration lists for the sub-networks.
    ``depth`` selects the row of ``config`` for this layer and ``start`` is
    the offset of this sub-network's bits within that row.  With
    ``reverse`` set, the two halves of the row swap roles (the bits used on
    the way in are used on the way out and vice versa).
    """
    n = len(a)
    layer = config[depth]
    if n == 2:
        a[0], a[1] = cond_swap_bit(a[0], a[1], layer[start])
        return

    half = n // 2
    if reverse:
        in_offset, out_offset = start + half, start
    else:
        in_offset, out_offset = start, start + half

    # Input switches: split each pair between the two sub-networks.
    a0 = [0] * half
    a1 = [0] * half
    for i in range(half):
        a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], layer[i + in_offset])

    waksman(a0, config, depth+1, start, reverse)
    waksman(a1, config, depth+1, start + half, reverse)

    # Output switches: recombine the sub-network outputs pairwise.
    for i in range(half):
        a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], layer[i + out_offset])
|
|
|
|
|
|
WAKSMAN_FUNCTIONS = {}
|
|
def iter_waksman(a, config, reverse=False):
    """ Iterative Waksman algorithm, compilable for large inputs. Input
    must be an Array.

    Applies the network described by ``config`` to ``a`` in place, one
    layer per round: first ``logn`` rounds "inwards", then ``logn - 1``
    rounds "outwards".  ``config`` is addressed as a flat array with one
    row of ``n`` entries per depth (row ``d`` starts at
    ``config.address + d*n``).  ``reverse`` selects which half of each
    block's configuration bits is used in which direction.

    NOTE(review): ``logn = int(math.log(n, 2))`` presumably assumes that
    ``n`` is a power of two -- confirm with callers. """
    n = len(a)

    # Loop state kept in MemValues so it can be updated inside the
    # @for_range bodies below.
    depth = MemValue(0)
    nblocks = MemValue(1)
    size = MemValue(0)
    # Scratch array receiving each round's output before it is copied back.
    a2 = Array(n, a[0].reg_type)

    def create_round_fn(n, reg_type, inwards):
        # Round functions are cached in the module-level WAKSMAN_FUNCTIONS
        # dict, keyed by (n, reg_type, inwards, reverse).
        if (n, reg_type, inwards, reverse) in WAKSMAN_FUNCTIONS:
            return WAKSMAN_FUNCTIONS[(n, reg_type, inwards, reverse)]

        def do_round(size, config_address, a_address, a2_address):
            # View the three memory regions as length-n Arrays.
            A = Array(n, reg_type, a_address)
            A2 = Array(n, reg_type, a2_address)
            C = Array(n, reg_type, config_address)
            outwards = 1 - inwards

            sizeval = size
            @for_range_parallel(200, n//2)
            def f(k):
                # Split switch index k into block index i and offset j
                # within a block of 2*sizeval elements starting at base.
                j = cint(k) % sizeval
                i = (cint(k) - j)//sizeval
                base = 2*i*sizeval

                # inwards:  read adjacent pair (base+2j, base+2j+1),
                #           write to (base+j, base+j+sizeval).
                # outwards: read (base+j, base+j+sizeval),
                #           write to adjacent pair (base+2j, base+2j+1).
                in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards)
                out1, out2 = (base+j+j*outwards), (base+j+j*outwards+1*outwards+sizeval*inwards)

                # Pick the switch bit: first half of the block's config for
                # (inwards, not reverse) and (outwards, reverse), second
                # half otherwise.
                if inwards:
                    if reverse:
                        c = C[base + j + sizeval]
                    else:
                        c = C[base + j]
                else:
                    if reverse:
                        c = C[base + j]
                    else:
                        c = C[base + j + sizeval]
                A2[out1], A2[out2] = cond_swap_bit(A[in1], A[in2], c)

        fn = function_block(do_round)
        WAKSMAN_FUNCTIONS[(n, reg_type, inwards, reverse)] = fn
        return fn

    # Convenience wrapper: look up (or create) the round function for the
    # requested direction and call it.
    do_round = lambda size, ca, aa, aa2, inwards: \
        create_round_fn(n, a[0].reg_type, inwards)(size, ca, aa, aa2)

    logn = int(math.log(n,2))

    # going into middle of network
    @for_range(logn)
    def f(i):
        size.write(n//(2*nblocks))
        conf_address = MemValue(config.address + depth.read()*n)
        do_round(size, conf_address, a.address, a2.address, 1)

        # Copy the round's output back into a.
        @for_range(n)
        def _(i):
            a[i] = a2[i]

        nblocks.write(nblocks*2)
        depth.write(depth+1)

    # Undo the last doubling/increment and step back one more layer before
    # the outward pass (presumably because the innermost layer has already
    # been applied by the inward pass -- confirm).
    nblocks.write(nblocks//4)
    depth.write(depth-2)

    # and back out
    @for_range(logn-1)
    def f(i):
        size.write(n//(2*nblocks))
        conf_address = MemValue(config.address + depth.read()*n)
        do_round(size, conf_address, a.address, a2.address, 0)

        # Copy the round's output back into a.
        @for_range(n)
        def _(i):
            a[i] = a2[i]

        nblocks.write(nblocks//2)
        depth.write(depth-1)
|
|
|
|
## going into middle of network
|
|
#while nblocks < n:
|
|
# #for i in range(n):
|
|
# # config_array[i] = config[depth][i].read()
|
|
#
|
|
# size.write(n/(2*nblocks))
|
|
# conf_address = config.address + depth*n
|
|
# do_round_in(size, conf_address, a.address, a2.address)
|
|
#
|
|
# for i in range(n):
|
|
# a[i] = a2[i]
|
|
#
|
|
# nblocks *= 2
|
|
# depth += 1
|
|
#
|
|
#nblocks /= 4
|
|
#depth -= 2
|
|
## and back out
|
|
#while nblocks > 0:
|
|
# #for i in range(n):
|
|
# # config_array[i] = config[depth][i].read()
|
|
#
|
|
# size.write(n/(2*nblocks))
|
|
# conf_address = config.address + depth*n
|
|
# do_round_out(size, conf_address, a.address, a2.address)
|
|
#
|
|
# for i in range(n):
|
|
# a[i] = a2[i]
|
|
#
|
|
# nblocks /= 2
|
|
# depth -= 1
|
|
|
|
def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
    """Shuffle the list ``x`` in place with the recursive Waksman network.

    If no ``config`` is given, one is generated from a random permutation
    and its bits are converted with ``value_type.bit_type``; a supplied
    ``config`` is used as is.  The network is applied twice with the same
    configuration.

    Raises ``CompilerError`` if ``len(x)`` is not a power of two.
    """
    n = len(x)
    if n & (n - 1):
        raise CompilerError('shuffle requires n a power of 2')
    if config is None:
        config = configure_waksman(random_perm(n))
        for depth, row in enumerate(config):
            config[depth] = [value_type.bit_type(bit) for bit in row]
    # Two passes over the same network configuration.
    for _ in range(2):
        waksman(x, config, reverse=reverse)
|
|
|
|
|
|
def config_shuffle(n, value_type):
    """ Compute config for oblivious shuffling.

    Take mod 2 for active sec.

    Draws a random permutation of length ``n``, pads it with fixed points
    up to the next power of two if necessary, computes the Waksman switch
    bits and returns them as a flat Array of ``value_type.reg_type`` with
    one row of ``len(perm)`` entries per network depth.  For ``n > 1024``
    the bits are registered via ``get_program().public_input`` and loaded
    with ``public_input()`` in a run-time loop instead of being stored one
    by one. """
    perm = random_perm(n)
    if n & (n - 1):
        # pad permutation to power of 2
        padded = 2 ** int(math.ceil(math.log(n, 2)))
        perm.extend(range(n, padded))
    rows = configure_waksman(perm)
    width = len(perm)
    # 2-D array, stored row by row
    config = Array(len(rows) * width, value_type.reg_type)

    if n > 1024:
        for row in rows:
            for bit in row:
                get_program().public_input(bit)

        @for_range(sum(len(row) for row in rows))
        def _(i):
            config[i] = public_input()
        return config

    for depth, row in enumerate(rows):
        for offset, bit in enumerate(row):
            config[depth * width + offset] = bit
    return config
|
|
|
|
def shuffle(x, config=None, value_type=sgf2n, reverse=False):
    """ Simulate secure shuffling with Waksman network for 2 players.
    WARNING: This is not a properly secure implementation but has roughly the right complexity.

    Returns the network switching config so it may be re-used later.

    ``x`` may be a Python list of values, a Python list of lists, or an
    Array; it is shuffled in place by applying ``iter_waksman`` twice with
    the same config.  If ``config`` is None a fresh one is generated with
    ``config_shuffle``.  ``len(x)`` must be a power of two (asserted).
    Raises ``CompilerError`` for any other type of ``x``. """
    n = len(x)
    m = 2**int(math.ceil(math.log(n, 2)))
    assert n == m, 'only working for powers of two'
    if config is None:
        config = config_shuffle(n, value_type)

    if isinstance(x, list):
        if isinstance(x[0], list):
            # List of lists: shuffle column by column, i.e. for every index
            # i the values x[0][i], x[1][i], ... are permuted together.
            length = len(x[0])
            # NOTE(review): this requires len(x) == len(x[0]); confirm that
            # the restriction is intended.
            assert len(x) == length
            for i in range(length):
                xi = Array(m, value_type.reg_type)
                for j in range(n):
                    xi[j] = x[j][i]
                # NOTE(review): n == m is asserted above, so this padding
                # loop never runs.
                for j in range(n, m):
                    xi[j] = value_type(0)
                iter_waksman(xi, config, reverse=reverse)
                iter_waksman(xi, config, reverse=reverse)
                for j, y in enumerate(xi):
                    x[j][i] = y
        else:
            # Flat list: copy into an Array, shuffle, copy back.
            xa = Array(m, value_type.reg_type)
            for i in range(n):
                xa[i] = x[i]
            # NOTE(review): as above, range(n, m) is empty here.
            for i in range(n, m):
                xa[i] = value_type(0)
            iter_waksman(xa, config, reverse=reverse)
            iter_waksman(xa, config, reverse=reverse)
            x[:] = xa
    elif isinstance(x, Array):
        # NOTE(review): config has already been assigned above and n == m
        # is asserted, so this condition cannot hold at this point.
        if len(x) != m and config is None:
            raise CompilerError('Non-power of 2 Array input not yet supported')
        iter_waksman(x, config, reverse=reverse)
        iter_waksman(x, config, reverse=reverse)
    else:
        raise CompilerError('Invalid type for shuffle:', type(x))

    return config
|
|
|
|
def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None):
    """ Shuffle a list of ORAM entries.

    Randomly permutes the first "perm_size" entries, leaving the rest (empty
    entry padding) in the same position.

    The entries are transposed into one Array per field, every field Array
    is shuffled with the same config, and each ``x[i]`` is then rebuilt via
    ``entry_cls``.  ``perm_size`` defaults to ``len(x)`` and is only used
    when a new config has to be generated.  Returns the config used.

    Raises ``CompilerError`` if ``len(x)`` is not a power of two. """
    n_entries = len(x)
    n_fields = len(x[0])
    if n_entries & (n_entries - 1):
        raise CompilerError('Entries must be padded to power of two length.')
    if perm_size is None:
        perm_size = n_entries

    # One Array per entry field (column-wise copy of the entries).
    columns = [Array(n_entries, value_type.reg_type) for _ in range(n_fields)]
    for row in range(n_entries):
        for col, value in enumerate(x[row]):
            columns[col][row] = value.read() if isinstance(value, MemValue) else value

    if config is None:
        config = config_shuffle(perm_size, value_type)
    for column in columns:
        shuffle(column, config, value_type, reverse)

    for row in range(n_entries):
        x[row] = entry_cls(columns[col][row] for col in range(n_fields))
    return config
|
|
|
|
|
|
def sort_zeroes(bits, x, n_ones, value_type):
    """ Return Array of values in "x" where the corresponding bit in "bits" is
    a 0.

    The total number of zeroes in "bits" must be known.
    "bits" and "x" must be Arrays.

    Both Arrays are first shuffled in place with the same freshly generated
    config, then the shuffled bits are revealed one by one and the matching
    values are collected.

    NOTE(review): the result Array is sized by ``n_ones`` although the
    text above speaks of the number of zeroes -- presumably the caller is
    expected to pass the number of zero bits; confirm. """
    config = config_shuffle(len(x), value_type)
    shuffle(bits, config=config, value_type=value_type)
    shuffle(x, config=config, value_type=value_type)
    result = Array(n_ones, value_type.reg_type)

    # sz counts the values collected so far; last_x holds the value most
    # recently written to result.
    sz = MemValue(0)
    last_x = MemValue(value_type(0))
    @for_range(len(bits))
    def f(i):
        found = (bits[i].reveal() == 0)
        szval = sz.read()
        # Writes x[i] when the bit is 0, otherwise re-writes last_x; the
        # slot is only advanced when a zero bit was found.
        # NOTE(review): once sz has reached n_ones, any further iteration
        # writes to result[n_ones], one past the end of result -- verify
        # that this is harmless.
        result[szval] = last_x + (x[i] - last_x) * found
        sz.write(sz + found)
        last_x.write(result[szval])
    return result
|