Files
MP-SPDZ/Programs/Source/blink.mpc
2018-10-26 15:52:49 +11:00

103 lines
2.9 KiB
Plaintext

import math
import util
# Record linkage between two parties' secret records: every record of
# party 0 (matrix a) is compared with every record of party 1 (matrix b) by
# counting differing bits; matches are entries below a threshold.
#
# NOTE(review): n_threads is not referenced in the code visible here; the
# multithreaded loops pass n_batches as their thread count -- confirm intent.
n_threads = 64


def xor_op(x, y):
    """Return the bitwise XOR of two secret bit vectors."""
    return x ^ y


# Each record is l words of a 64-bit secret type.
n_bits = 64
# full_t also stores batch_size-lane masks later (ag, 2**batch_size - 1),
# so its width must stay >= batch_size.
full_t = sbits.get_type(64)
sbits.n = n_bits

# Optional compile-time argument: number of batches of batch_size records.
if len(program.args) > 1:
    n_batches = int(program.args[1])
else:
    n_batches = 78

# Number of lanes processed together in one bit-sliced comparison.
batch_size = 64
n = n_batches * batch_size
l = 16

a = Matrix(n, l, full_t)
b = Matrix(n, l, full_t)

# A mismatch count is a popcount over a whole record, i.e. over
# l * full_t.n bits. Size the counter from the record width, not from the
# lane count batch_size (they only coincide while both are 64). The extra
# bit keeps the original sizing.
t = sbitint.get_type(int(math.ceil(math.log(l * full_t.n, 2))) + 1)
matches = Matrix(n, n, t.bit_type)
mismatches = Matrix(n, n, t)
threshold = MemValue(t(10))

for i in range(n):
    for j in range(l):
        a[i][j] = full_t.get_input_from(0)
        b[i][j] = full_t.get_input_from(1)

# test, create match between a[0] and b[1] but no match for a[1]
# (the inputs read above are overwritten with this fixed test data)
a.assign_all(0)
b.assign_all(0)
a[0][0] = -1
b[1][0] = -1
a[1][1] = -1
@for_range_multithread(n_batches, 1, n)
def _(i):
    # Pass 1: mismatches[i][j] = number of differing bits between record a[i]
    # and record b[j], computed batch_size columns j at a time (bit-sliced).
    print_ln('%s', i)
    # The bit-sliced copy of a[i] does not depend on the batch j, so build it
    # once per row instead of once per batch. Each word is loaded once and
    # replicated into all batch_size lanes.
    av = sbitintvec.from_matrix([a[i][kk]] * batch_size for kk in range(l))

    @for_range_parallel(100, n_batches)
    def _(j):
        j = j * batch_size
        # lane k holds record b[j + k]
        bv = sbitintvec.from_matrix((b[j + k][kk] for k in range(batch_size))
                                    for kk in range(l))
        res = xor_op(av, bv).popcnt()
        mismatches[i].set_range(j, (t(x) for x in res.elements()))
@for_range_multithread(n_batches, 8, n)
def _(i):
    # Pass 2: matches[i][j] = mismatches[i][j] compared against threshold via
    # less_than, batch_size columns at a time.
    print_ln('%s', i)
    # The threshold vector is identical for every batch j; build it once per
    # row instead of once per batch.
    vv = sbitintvec([threshold.read()] * batch_size)

    @for_range_parallel(100, n_batches)
    def _(j):
        j = j * batch_size
        v = sbitintvec(mismatches[i].get_range(j, batch_size))
        # NOTE(review): the extra argument 10 is passed through to less_than
        # (presumably a bit length); t is wider than 10 bits for the default
        # l, so confirm this does not truncate large counts.
        matches[i].set_range(j, v.less_than(vv, 10).elements())
# Pass 3: keep only the best (smallest mismatch count) column per row.
# Each thread handles one batch of batch_size consecutive rows.
# mg[batch]: for every column j, the bit-sliced mismatch counts of the batch
#            rows (t.n registers, one per bit position, lane k = row offset k).
# ag[batch]: for every column j, a lane mask from the argmin reduction.
mg = MultiArray([n_batches, n, t.n], full_t)
ag = Matrix(n_batches, n, full_t)
@for_range_multithread(n_batches, 1, n_batches)
def _(i):
    m = mg[i]
    # NOTE: this local `a` (lane masks) shadows the global input matrix `a`.
    a = ag[i]
    # From here on, i is the first row of this batch.
    i = i * batch_size
    print_ln('best %s', i)
    # Bit-slice column j of mismatches for rows i .. i + batch_size - 1.
    @for_range(n)
    def _(j):
        m[j].assign(sbitintvec(mismatches[i + k][j]
                               for k in range(batch_size)).v)
    m = [sbitintvec.from_vec(m[j]) for j in range(n)]
    # Combine two candidates (count, mask list): select the side whose count
    # is smaller according to less_than (the second side otherwise), and pad
    # the mask list with zeros for the other side's columns.
    # NOTE(review): presumably util.if_else selects per lane, since c comes
    # from a bit-sliced comparison -- confirm.
    def reducer(a, b):
        c = a[0].less_than(b[0])
        return util.if_else(c, (a[0], a[1] + [0] * len(b[1])),
                            (b[0], [0] * len(a[1]) + b[1]))
    # Every column starts with an all-ones lane mask. After the reduction,
    # mm[1] has one entry per column; assuming tree_reduce combines
    # neighbours in order, entry j marks the lanes (rows) for which column j
    # won -- TODO confirm.
    mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in m))
    a.assign(mm[1])
    # matches[i + k][j] &= lane k of the mask for column j.
    @for_range_parallel(100, len(a))
    def _(j):
        x = a[j]
        pm = sbitintvec(matches[i + k][j] for k in range(batch_size))
        x = sbitintvec.from_vec([x])
        for k, y in enumerate((pm & x).elements()):
            matches[i + k][j] = y
def test(result, expected):
    """Reveal a secret result and print it next to the expected value.

    This only prints; it does not compare or abort on a mismatch.
    """
    print_ln('%s ?= %s', result.reveal(), expected)


# Expected values follow from the fixed test data: a[0] and b[1] have word 0
# all ones, a[1] has word 1 all ones, every other word is zero.
test(matches[0][1], 1)
test(matches[0][0], 0)
test(matches[1][0], 0)
test(matches[1][1], 0)
# NOTE(review): if `+` on the bit type of matches is XOR (as for sbits), this
# is the parity of row 2, not a count. Before the argmin step row 2 has n - 1
# below-threshold entries, which is odd for the default n, so this check
# would pass even without that step -- confirm.
test(sum(matches[2]), 1)
test(mismatches[0][1], 0)
test(mismatches[0][0], 64)
test(mismatches[1][0], 64)
test(mismatches[1][1], 128)

# A record is l words of full_t.n bits; batch_size is the lane count and only
# coincides with the word width while both are 64.
print_ln('%sx%s linkage of %s bits', n, n, l * full_t.n)