Files
MP-SPDZ/Programs/Source/regression.mpc

196 lines
4.9 KiB
Plaintext

import ml
import random
import re
# Global switches for this program.
# NOTE(review): use_trunc_pr presumably selects probabilistic truncation --
# confirm against the MP-SPDZ documentation.
program.use_trunc_pr = True
ml.set_n_threads(8)
debug = False

# Fixed-point format shared by secret (sfix) and clear (cfix) values.
# The first argument is the number of bits after the point (see the message
# printed for 'halfprec'); the second is presumably the total bit length --
# TODO confirm.
if 'halfprec' not in program.args:
    sfix.set_precision(16, 31)
    cfix.set_precision(16, 31)
else:
    print('8-bit precision after point')
    sfix.set_precision(8, 31)
    cfix.set_precision(8, 31)

# Keep sfloat's vlen in step with the fixed-point fractional length.
sfloat.vlen = sfix.f

# Optional: switch sfix to nearest rounding when requested on the command line.
if 'nearest' in program.args:
    sfix.round_nearest = True
# Default data set dimensions: total examples, how many of them are
# "normal", and the number of features per example.
n_examples = 227
n_normal = 84
n_features = 12634

# The 'bc' flag switches to the BC-TCGA dimensions. It is only honoured when
# more than two program arguments are present.
if len(program.args) > 2 and 'bc' in program.args:
    print('Compiling for BC-TCGA')
    n_examples = 472
    n_normal = 49
    n_features = 17814

# Everything that is not "normal" counts as positive.
n_pos = n_examples - n_normal

# Second program argument (if present): number of epochs.
n_epochs = 1
if len(program.args) > 1:
    n_epochs = int(program.args[1])

# Third program argument (if present and numeric): number of threads.
# This is deliberately best-effort -- the slot may be missing (IndexError) or
# hold a non-numeric flag (ValueError); in both cases the thread count set
# earlier stays in effect. Only these two errors are swallowed so that
# unrelated failures are no longer hidden.
try:
    ml.set_n_threads(int(program.args[2]))
except (IndexError, ValueError):
    pass

print('Using %d threads' % ml.Layer.n_threads)
# Cross-validation layout: n_fold folds, each holding out 1/n_fold of every
# class for testing.
n_fold = 5
test_share = 1. / n_fold

# Per-class bookkeeping, normal class first and positive class second.
n_ex = [n_normal, n_pos]
n_tests = [int(test_share * count) for count in n_ex]
n_train = [total - held_out for total, held_out in zip(n_ex, n_tests)]

weighted = 'weighted' in program.args

# N is the number of rows fed to the network per training pass:
#   'fast'     -> twice the smaller class
#   'mini'     -> fixed 32
#   'weighted' -> all training examples
#   otherwise  -> twice the larger class
if 'fast' in program.args:
    N = 2 * min(n_train)
elif 'mini' in program.args:
    N = 32
elif weighted:
    N = sum(n_train)
else:
    N = 2 * max(n_train)

# Total number of held-out examples per fold (both classes together).
n_test = sum(n_tests)
# Row indices into X/Y for each class: rows [n_pos, n_pos + n_normal) are the
# normal class (label 0), rows [0, n_pos) the positive class (label 1).
indices = [regint.Array(size) for size in n_ex]
indices[0].assign(list(range(n_pos, n_pos + n_normal)))
indices[1].assign(list(range(0, n_pos)))

# Row indices of the examples held out in the current fold.
test = regint.Array(n_test)

# Single dense layer with one output, optionally the quantized variant,
# followed by the output layer.
dense = (ml.QuantizedDense if 'quant' in program.args else ml.Dense)(
    N, n_features, 1)
layers = [dense, ml.Output(N, debug=debug)]

# Secret-shared labels and feature matrix.
Y = sfix.Array(n_examples)
X = sfix.Matrix(n_examples, n_features)

# Labels are read first (source 0 -- presumably party 0, confirm).
Y.input_from(0)
# Read the feature matrix from source 0. The outer loop runs over features
# and the inner one over examples, so inputs are consumed feature by feature
# (column-major with respect to X).
@for_range_opt(n_features)
def _(feature):
    @for_range_opt(n_examples)
    def _(example):
        X[example][feature] = sfix.get_input_from(0)
# Sanity check of the input layout: reveal the last feature of the first
# example and the first feature of the last example.
print_ln('X[0][%s] = %s', n_features - 1, X[0][n_features - 1].reveal())
print_ln('X[%s][0] = %s', n_examples - 1, X[n_examples - 1][0].reveal())

sgd = ml.SGD(layers, n_epochs, debug=debug, report_loss=True)

# A bare 'tol' argument selects a tolerance of 0.001 ...
if 'tol' in program.args:
    sgd.tol = 0.001

# ... while 'tol=<value>' sets an explicit one (the last such argument wins).
# NOTE(review): indentation was lost in the source; the message is assumed to
# be printed only when a 'tol=<value>' argument matched -- confirm.
for arg in program.args:
    tol_match = re.match('tol=(.*)', arg)
    if not tol_match:
        continue
    sgd.tol = float(tol_match.group(1))
    print('Stop with tolerance', sgd.tol)
# Running sum of the per-fold balanced test accuracies over all runs.
sum_acc = cfix.MemValue(0)


# Repeated cross-validation: 100 runs, each reshuffling both classes and then
# training/testing once per fold.
@for_range(100)
def _(i_run):
    for idx in indices:
        idx.shuffle()

    @for_range(n_fold)
    def _(i_fold):
        # Split every class into the held-out chunk for this fold (appended
        # to the shared `test` array) and the remaining training rows.
        i_test = regint.MemValue(0)
        Xs = [sfix.Matrix(n, n_features) for n in n_train]
        training = [regint.Array(n) for n in n_train]

        for label in 0, 1:
            i_train = regint.MemValue(0)

            @for_range(len(indices[label]))
            def _(i):
                # Position i belongs to chunk i / n_tests[label]; the chunk
                # whose number equals the fold index is held out.
                # NOTE(review): this relies on regint `/` being integer
                # division -- confirm; `//` may be the intended operator.
                @if_e(i / n_tests[label] == i_fold)
                def _():
                    test[i_test] = indices[label][i]
                    i_test.iadd(1)

                @else_
                def _():
                    j = indices[label][i]
                    training[label][i_train] = j
                    Xs[label][i_train] = X[j]
                    i_train.iadd(1)

        print_ln('training on %s', [list(x) for x in training])
        print_ln('testing on %s', list(test))

        if 'static' in program.args or weighted:
            # Fill the network's input/label buffers by hand.
            sgd.reset()
            if weighted:
                if 'doublenormal' in program.args:
                    factor = 2
                else:
                    factor = 1
                # Each class is weighted by the size of the other class
                # (the normal class optionally doubled).
                layers[-1].set_weights([factor * n_train[1]] * n_train[0] +
                                       [n_train[0]] * n_train[1])
                n_indices = n_train
                n = n_train[0]
            else:
                # Half of the batch per class. Floor division keeps n an int
                # under Python 3; true division would yield a float that is
                # then used as a loop bound and as an array offset.
                n = N // 2
                assert 2 * n == N
                n_indices = [n, n]
            for label, idx in enumerate(training):
                @for_range(n_indices[label])
                def _(i):
                    # i % len(idx) wraps around, so a class with fewer
                    # training rows than n_indices[label] is repeated.
                    layers[0].X[i + label * n] = X[idx[i % len(idx)]]
                    layers[-1].Y[i + label * n] = label
        else:
            sgd.reset(Xs)

        sgd.run()

        def match(x, y):
            # 1 if prediction x is on the correct side of 0.5 for label y.
            # y.v >> y.f is the integer part of the revealed label.
            return (y.v >> y.f).if_else(x > 0.5, x < 0.5)

        def get_acc(N, noise=False):
            # Count correct predictions among the first N rows currently in
            # the network buffers. Returns (correct on label 0, correct on
            # label 1). With noise=True every row's (match, label,
            # prediction) triple is printed as well.
            pos_acc = cfix.MemValue(0)
            neg_acc = cfix.MemValue(0)

            @for_range(N)
            def _(i):
                y = layers[-1].Y[i].reveal()
                x = ml.sigmoid(layers[0].Y[i][0][0]).reveal()
                m = match(x, y)
                pos_acc.iadd(m * y)
                neg_acc.iadd(m * (1 - y))
                if noise:
                    print_ln('%s %s %s', m, y, x)

            return neg_acc.read(), pos_acc.read()

        print_ln('train_acc: %s', sum(get_acc(N)) / N)

        # Load the held-out rows into the front of the network buffers and
        # evaluate on them.
        @for_range(n_test)
        def _(i):
            j = test[i]
            layers[0].X[i] = X[j]
            layers[-1].Y[i] = Y[j]

        sgd.forward(n_test)
        print_ln('test_loss: %s', sgd.layers[-1].l.reveal())

        accs = get_acc(n_test, True)
        # Balanced accuracy: mean of the per-class accuracies.
        real_accs = [x / y for x, y in zip(accs, n_tests)]
        real_acc = sum(real_accs) / 2
        print_ln('test_acc: %s (%s=%s/%s %s=%s/%s)', real_acc,
                 real_accs[0], accs[0], n_tests[0], real_accs[1], accs[1],
                 n_tests[1])
        sum_acc.iadd(real_acc)
        #print_ln('test set: %s', test)

    # Average over all folds evaluated so far, including this run.
    print_ln('average test acc: %s (%s/%s)',
             sum_acc / (n_fold * (i_run + 1)), sum_acc, (n_fold * (i_run + 1)))