mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
103 lines
2.9 KiB
Plaintext
103 lines
2.9 KiB
Plaintext
import math
|
|
import util
|
|
|
|
# Worker-thread count.
# NOTE(review): not referenced by the visible code of this file; the
# for_range_multithread calls below pass n_batches as their thread count.
n_threads = 64


def xor_op(x, y):
    """Return the bitwise XOR of *x* and *y* (used to expose differing bits)."""
    return x ^ y


# Width of one input register, and the secret-bits type of that width.
n_bits = 64
full_t = sbits.get_type(n_bits)
sbits.n = n_bits

# Number of batches: taken from program.args[1] when supplied, else 78.
n_batches = int(program.args[1]) if len(program.args) > 1 else 78

batch_size = 64              # records handled together in one vectorised step
n = n_batches * batch_size   # records per party
l = 16                       # full_t registers per record
|
|
# Secret inputs: n records per party, each record being l registers of full_t.
a = Matrix(n, l, full_t)
b = Matrix(n, l, full_t)

# Signed integer type for Hamming distances.
# A distance ranges over 0 .. full_t.n * l (every bit of a record differs),
# so the magnitude needs (full_t.n * l).bit_length() bits, plus one sign bit
# because sbitint is signed.  Note that ceil(log2(max)) + 1 would be one bit
# short whenever max is a power of two (1024 would presumably be read back as
# -1024 and then compare as below the threshold).  int.bit_length() is exact
# and avoids floating-point rounding in math.log.
t = sbitint.get_type((full_t.n * l).bit_length() + 1)

# matches[i][j]: secret match bit for the pair (a[i], b[j]).
# mismatches[i][j]: Hamming distance between a[i] and b[j], as type t.
matches = Matrix(n, n, t.bit_type)
mismatches = Matrix(n, n, t)
# A pair counts as a match when its distance is strictly below this value.
threshold = MemValue(t(10))
|
|
|
|
# Read the secret inputs: party 0 supplies every register of `a`, party 1
# every register of `b`, interleaved register by register.  This is a plain
# Python loop, so it runs at compile time (presumably emitting n * l input
# pairs; the order here fixes the order in which each party must provide data).
# NOTE(review): the statements that follow in this file zero `a` and `b` with
# assign_all(0), so these input values are discarded by the test fixture.
for i in range(n):
    for j in range(l):
        a[i][j] = full_t.get_input_from(0)
        b[i][j] = full_t.get_input_from(1)
|
|
|
|
# Test fixture: replace the inputs with known data.
# Record a[0] equals b[1] (all ones in register 0), so that pair should match.
# Record a[1] has all ones in register 1; it is at least 64 bits away from
# every b record, so it should match nothing.
a.assign_all(0)
b.assign_all(0)
a[0][0] = b[1][0] = -1
a[1][1] = -1
|
|
|
|
# Hamming distance between every record of `a` and every record of `b`.
# Outer loop: one iteration per record i of `a` (n iterations); the arguments
# are presumably (threads, parallelism, iterations) per MP-SPDZ's
# for_range_multithread -- confirm against the library docs.
@for_range_multithread(n_batches, 1, n)
def _(i):
    # Progress output.
    print_ln('%s', i)
    # Inner loop over the n_batches batches of `b`; the 100 is presumably the
    # number of iterations allowed to run in parallel -- confirm.
    @for_range_parallel(100, n_batches)
    def _(j):
        # From here on, j is the first `b` record of this batch.
        j = j * batch_size
        # Record a[i] replicated batch_size times so it lines up with the
        # batch_size `b` records below.
        # NOTE(review): av depends only on i, yet it is rebuilt (and a[i][kk]
        # reloaded batch_size times) in every inner iteration; hoisting it
        # may be possible -- verify register scoping first.
        av = sbitintvec.from_matrix((a[i][kk] for _ in range(batch_size)) \
                                    for kk in range(l))
        # Records b[j] .. b[j + batch_size - 1].
        bv = sbitintvec.from_matrix((b[j + k][kk] for k in range(batch_size)) \
                                    for kk in range(l))
        # Count of differing bits for each of the batch_size pairs.
        res = xor_op(av, bv).popcnt()
        # Store the distances as type t in row i, columns j .. j+batch_size-1.
        mismatches[i].set_range(j, (t(x) for x in res.elements()))
|
|
|
|
# Turn distances into match bits: matches[i][j] = mismatches[i][j] < threshold.
@for_range_multithread(n_batches, 8, n)
def _(i):
    # Progress output.
    print_ln('%s', i)
    @for_range_parallel(100, n_batches)
    def _(j):
        # From here on, j is the first column of this batch.
        j = j * batch_size
        # batch_size distances of row i starting at column j.
        v = sbitintvec(mismatches[i].get_range(j, batch_size))
        # The threshold repeated once per lane.
        vv = sbitintvec([threshold.read()] * batch_size)
        # NOTE(review): the 10 is not the threshold (that is vv); it is
        # presumably a bit-length hint in the style of sint.less_than(other,
        # bit_length) -- confirm whether sbitintvec honours or ignores it.
        matches[i].set_range(j, v.less_than(vv, 10).elements())
|
|
|
|
# Per-batch scratch space for the best-candidate selection below.
# mg[g][j]: t.n registers of full_t -- the `.v` bit registers of an
#           sbitintvec built from column j of batch g's distances.
# ag[g][j]: one full_t register per `b` record j (the selection mask).
mg = MultiArray([n_batches, n, t.n], full_t)
ag = Matrix(n_batches, n, full_t)
|
|
|
|
# Keep, for each `a` record, only the match bit of its smallest-distance
# `b` candidate.  One iteration per batch of batch_size `a` records.
@for_range_multithread(n_batches, 1, n_batches)
def _(i):
    m = mg[i]
    # Shadows the global input matrix `a` inside this function.
    a = ag[i]
    # From here on, i is the first `a` record of this batch.
    i = i * batch_size
    print_ln('best %s', i)
    # For every `b` record j, pack the distances from this batch's batch_size
    # `a` records into one sbitintvec and store its bit registers in m[j].
    @for_range(n)
    def _(j):
        m[j].assign(sbitintvec(mismatches[i + k][j]
                               for k in range(batch_size)).v)
    # Reload them as n sbitintvec values, one per `b` record.
    m = [sbitintvec.from_vec(m[j]) for j in range(n)]
    def reducer(a, b):
        # a, b: (distances, indicators) pairs; indicators is a list with one
        # entry per `b` record covered by that pair so far.
        # c: result of a < b (presumably lane-wise, one bit per `a` record).
        c = a[0].less_than(b[0])
        # Where a is strictly smaller keep a's distances/indicators and zero
        # b's; otherwise (ties included) keep b's and zero a's.
        return util.if_else(c, (a[0], a[1] + [0] * len(b[1])),
                            (b[0], [0] * len(a[1]) + b[1]))
    # Each `b` record starts with an all-ones indicator (2**batch_size - 1,
    # presumably one bit per `a` record of the batch); reduce over all n.
    mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in m))
    # mm[1] holds n entries; entry j presumably masks the `a` records of
    # this batch whose minimum distance was attained at `b` record j.
    a.assign(mm[1])
    # AND each match bit with that mask, clearing non-selected candidates.
    @for_range_parallel(100, len(a))
    def _(j):
        x = a[j]
        pm = sbitintvec(matches[i + k][j] for k in range(batch_size))
        x = sbitintvec.from_vec([x])
        for k, y in enumerate((pm & x).elements()):
            matches[i + k][j] = y
|
|
|
|
def test(result, expected):
    """Print the revealed *result* beside the *expected* plaintext value.

    Nothing is asserted: both values are printed as ``<result> ?= <expected>``
    for manual comparison.
    """
    revealed = result.reveal()
    print_ln('%s ?= %s', revealed, expected)
|
|
|
|
# Self-checks against the planted test data; each prints "<value> ?= <expected>".
# a[0] equals b[1], so that is the only pair expected to survive as a match
# in rows 0 and 1.
test(matches[0][1], 1)
test(matches[0][0], 0)
test(matches[1][0], 0)
test(matches[1][1], 0)
# Row 2: a[2] is all zeros, as is every b record except b[1], so many
# candidates fall under the threshold and the best-candidate selection should
# leave exactly one.
# NOTE(review): if `+` on secret bits is XOR (as for MP-SPDZ sbits), this sum
# is the row's parity, not its count -- 3 surviving matches would also print 1.
# Confirm, and consider counting with an integer type instead.
test(sum(matches[2]), 1)

# Hamming distances: a[0]/b[1] identical; a[0]/b[0] and a[1]/b[0] differ in
# one 64-bit register; a[1]/b[1] differ in two registers.
test(mismatches[0][1], 0)
test(mismatches[0][0], 64)
test(mismatches[1][0], 64)
test(mismatches[1][1], 128)
|
|
|
|
# Summary line.  A record is l registers of full_t.n bits each.  Reporting
# l * batch_size was only correct because batch_size and the register width
# are both 64; use the register width so the label stays right if either changes.
print_ln('%sx%s linkage of %s bits', n, n, l * full_t.n)
|