Files
MP-SPDZ/Programs/Source/idash_predict.mpc
2020-04-02 09:09:45 +02:00

52 lines
1017 B
Plaintext

import ml
import random
import re
# Global arithmetic switches for this program. Their exact effect is
# defined by the MP-SPDZ compiler (see its documentation for
# `use_trunc_pr` and `round_nearest`).
program.use_trunc_pr = True
sfix.round_nearest = True

# Give both fixed-point types the same (16, 31) precision setting so
# values of the two types stay compatible with each other.
for fixed_point_type in (sfix, cfix):
    fixed_point_type.set_precision(16, 31)
# Positional compile-time arguments. N is later used as the number of
# rows fed through the layer, n_features as the width of each row.
N = int(program.args[1])
n_features = int(program.args[2])

# Optional `n_threads=<int>` argument, accepted at any position. When it
# is given more than once the last occurrence wins; when it is absent
# n_threads stays None.
n_threads = None
for option in program.args:
    found = re.match('n_threads=(.*)', option)
    if found is None:
        continue
    n_threads = int(found.group(1))

# Declare 1 + n_features cells of 's' memory as in use: one cell for the
# value read from address 0 plus n_features cells read from address 1 on.
program.allocated_mem['s'] = 1 + n_features
# Model parameters are taken from memory: a single value at address 0
# (used as the bias) followed by n_features values starting at address 1
# (used as the weights).
bias = sfix.load_mem(0)
weights = sfix.load_mem(1, size=n_features)
#sint.load_mem(100).reveal().print_reg()

# One dense layer mapping n_features inputs to a single output for each
# of the N rows, initialised with the loaded parameters.
dense = ml.Dense(N, n_features, 1)
dense.b[0] = bias
dense.W.assign_vector(weights)

# Sanity output: reveal the bias and the last weight so the loaded model
# can be checked against the expected one.
revealed_bias = dense.b[0].reveal()
revealed_last_weight = dense.W[n_features - 1][0][0].reveal()
print_ln('b=%s W[-1]=%s', revealed_bias, revealed_last_weight)
# Fill the layer's input matrix with values supplied by player 0.
# The outer loop runs over features (split across n_threads when that is
# set) and the inner loop over rows, so X[row][0][feature] is written
# feature by feature.
@for_range_opt_multithread(n_threads, n_features)
def _(feature):
    @for_range_opt(N)
    def _(row):
        dense.X[row][0][feature] = sfix.get_input_from(0)
# Run the forward pass over a batch of N row indices produced by
# regint.inc(N) (presumably 0 .. N-1, i.e. every row -- confirm against
# the regint documentation).
row_indices = regint.Array(N)
row_indices.assign(regint.inc(N))
dense.forward(row_indices)
# Emit one line of the form "predictions: <v0><v1>...": for every row the
# sigmoid of the layer output is revealed and compared against 0.5, and
# the comparison result is printed without a separator.
print_str('predictions: ')

@for_range(N)
def _(row):
    probability = ml.sigmoid(dense.Y[row][0][0])
    revealed_probability = probability.reveal()
    print_str('%s', revealed_probability >= 0.5)

# Terminate the predictions line.
print_ln()