mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 22:17:57 -05:00
37 lines
882 B
Plaintext
37 lines
882 B
Plaintext
# Compile-time options set on the MP-SPDZ `program` object.
program.use_edabit(True)
program.set_bit_length(32)

# First program argument: number of database rows.
n = int(program.args[1])

# Optional second program argument: number of threads.
# n_threads stays None when it is not supplied, which selects the
# single-threaded reduction further down.
n_threads = None
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    # Argument missing or not an integer: keep the best-effort fallback
    # to None, but no longer swallow unrelated exceptions such as
    # KeyboardInterrupt or SystemExit.
    pass

# Secret-shared n x 4 database and a secret 4-element sample to compare
# against it.
db = MultiArray([n, 4], sint)
sample = sint.Array(4)

# NOTE(review): presumably the database comes from party 0 and the sample
# from party 1 -- confirm against the MP-SPDZ input_from documentation.
db.input_from(0)
sample.input_from(1)
|
|
|
|
def match(db_entry, sample):
    """Return the sum of squared element-wise differences.

    Pairs the elements of ``db_entry`` and ``sample`` positionally (stopping
    at the shorter of the two) and adds up the square of each difference,
    i.e. the squared Euclidean distance between the two vectors.
    """
    # Compute each difference once so it is only subtracted a single time
    # before being squared.
    diffs = (a - b for a, b in zip(db_entry, sample))
    return sum(d * d for d in diffs)
|
|
|
|
from Compiler import util


def pairwise_min(x, y):
    """Return the smaller of two values via the left operand's ``min``."""
    return x.min(y)


if n_threads is None:
    # Single-threaded: tree-reduce the distances of all n rows to their
    # minimum.
    res = util.tree_reduce(
        pairwise_min,
        (match(db[i], sample) for i in range(n)),
    )
else:
    # One slot per thread for that thread's partial minimum.
    tmp = sint.Array(n_threads)

    @multithread(n_threads, n)
    def _(base, size):
        # Each invocation covers rows base .. base + size - 1 and stores
        # its minimum in the slot selected by get_arg().
        # NOTE(review): presumably get_arg() yields the thread index in
        # [0, n_threads) -- confirm against the MP-SPDZ documentation.
        tmp[get_arg()] = util.tree_reduce(
            pairwise_min,
            (match(db[base + i], sample) for i in range(size)),
        )

    # Combine the per-thread partial minima.
    res = util.tree_reduce(pairwise_min, tmp)

print_ln('result: %s', res.reveal())
|