# Mirror of https://github.com/data61/MP-SPDZ.git
# (synced 2026-04-20 03:01:31 -04:00; original listing: 52 lines, 1.1 KiB, plaintext)
# Trains a one-output dense layer followed by ml.Output with ml.SGD, on
# samples that party 0 supplies as secret fixed-point inputs.
#
# Values read from program.args:
#   [1]  row count of the first sample matrix  (X_normal)
#   [2]  row count of the second sample matrix (X_pos)
#   [3]  number of features per row
#   'debug' anywhere in the list turns on loss reporting and the per-sample
#   printout after training.
import ml
import random

# Numeric configuration comes first, ahead of allocating any sfix data.
program.use_trunc_pr = True
sfix.round_nearest = True
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)
# Tie the sfloat vlen setting to the fixed-point fractional length chosen above.
sfloat.vlen = sfix.f

n_epochs = 200

n_normal = int(program.args[1])
n_pos = int(program.args[2])
n_features = int(program.args[3])
debug = 'debug' in program.args

n_examples = n_normal + n_pos
# Twice the larger group; used as the size of both layers below.
# NOTE(review): presumably the size of the label-balanced training set that
# sgd.reset() builds from the two matrices -- confirm against ml.SGD.reset.
N = 2 * max(n_normal, n_pos)

X_normal = sfix.Matrix(n_normal, n_features)
X_pos = sfix.Matrix(n_pos, n_features)


def _load_column(matrix, n_rows, col):
    # Fill column `col` of `matrix` with the next n_rows inputs from party 0.
    @for_range_opt(n_rows)
    def _(row):
        matrix[row][col] = sfix.get_input_from(0)


# Inputs arrive feature by feature: for each feature, first all X_normal
# rows, then all X_pos rows. The party's input stream must follow this order.
@for_range_opt(n_features)
def _(feature):
    _load_column(X_normal, n_normal, feature)
    _load_column(X_pos, n_pos, feature)


dense = ml.Dense(N, n_features, 1)
layers = [dense, ml.Output(N)]

sgd = ml.SGD(layers, n_epochs, report_loss=debug)
# The two matrices are handed over as a list, one entry per group.
sgd.reset([X_normal, X_pos])
sgd.run()

if debug:
    out_layer = layers[-1]

    # Reveal, per sample, the output layer's Y entry next to the sigmoid of
    # its X entry.
    @for_range(N)
    def _(k):
        y_open = out_layer.Y[k].reveal()
        sigmoid_open = ml.sigmoid(out_layer.X[k]).reveal()
        print_ln('%s %s', y_open, sigmoid_open)

trained = layers[0]

# Store the bias with address argument 0 and the weight vector with address
# argument 1.
# NOTE(review): presumably so later code can find them at fixed addresses --
# confirm with whoever consumes this memory.
trained.b[0].store_in_mem(0)
trained.W.get_vector().store_in_mem(1)

# b[0] is deliberately read again here rather than reused from above, so the
# sequence of memory accesses matches the original program.
bias_open = trained.b[0].reveal()
last_weight_open = trained.W[n_features - 1][0][0].reveal()
print_ln('b=%s W[-1]=%s', bias_open, last_weight_open)