mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 21:48:11 -05:00
52 lines
1017 B
Plaintext
52 lines
1017 B
Plaintext
import ml
|
|
import random
|
|
import re
|
|
|
|
# Global numeric configuration; set before any fixed-point value is created.
# NOTE(review): use_trunc_pr / round_nearest presumably select the
# truncation and rounding behaviour of sfix -- confirm against MP-SPDZ docs.
program.use_trunc_pr = True
sfix.round_nearest = True

# Both fixed-point types get the same (16, 31) precision, sfix first.
# Presumably (fractional bits, total bits) -- TODO confirm.
for _fixed_type in (sfix, cfix):
    _fixed_type.set_precision(16, 31)
|
|
|
|
# Positional compile-time arguments: args[1] -> N, args[2] -> n_features.
N = int(program.args[1])
n_features = int(program.args[2])

# Optional "n_threads=<int>" argument; every argument is scanned and the
# last match wins. Stays None when no such argument is given.
n_threads = None
for option in program.args:
    found = re.match('n_threads=(.*)', option)
    if found is None:
        continue
    n_threads = int(found.group(1))
|
|
|
|
# The values read below occupy 1 + n_features cells: address 0, then
# n_features cells starting at address 1.
# NOTE(review): setting allocated_mem['s'] presumably keeps the compiler
# from allocating over those cells -- confirm against MP-SPDZ docs.
_model_cells = 1 + n_features
program.allocated_mem['s'] = _model_cells

# b: single value at address 0; W: vector of n_features values from address 1.
b = sfix.load_mem(0)
W = sfix.load_mem(1, size=n_features)
|
|
|
|
#sint.load_mem(100).reveal().print_reg()
|
|
|
|
# Build the dense layer and install the values loaded from memory into it.
dense = ml.Dense(N, n_features, 1)
dense.b[0] = b
dense.W.assign_vector(W)

# Debug output: reveal b and the W entry at index n_features - 1, in that
# order, then print them on one line.
_revealed_b = dense.b[0].reveal()
_revealed_last_w = dense.W[n_features - 1][0][0].reveal()
print_ln('b=%s W[-1]=%s', _revealed_b, _revealed_last_w)
|
|
|
|
# Fill dense.X from sfix.get_input_from(0): the outer loop runs over the
# n_features feature indices, the inner loop over the N rows.
# NOTE(review): the outer loop is handed n_threads; how threading affects
# the order in which inputs are consumed is not visible here.
@for_range_opt_multithread(n_threads, n_features)
def _(feature):
    @for_range_opt(N)
    def _(row):
        dense.X[row][0][feature] = sfix.get_input_from(0)
|
|
|
|
# Index array of length N handed to the forward pass.
# NOTE(review): regint.inc(N) presumably yields 0..N-1, i.e. every row is
# selected -- confirm against MP-SPDZ docs.
batch = regint.Array(N)
_row_indices = regint.inc(N)
batch.assign(_row_indices)

dense.forward(batch)
|
|
|
|
# Print one value per row after the 'predictions: ' prefix, all on the same
# line with no separator, then end the line.
print_str('predictions: ')

@for_range(N)
def _(row):
    # Apply the sigmoid to the layer output, reveal it, and print the result
    # of comparing the revealed value against 0.5.
    probability = ml.sigmoid(dense.Y[row][0][0])
    above_threshold = probability.reveal() >= 0.5
    print_str('%s', above_threshold)

print_ln()
|