mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
74 lines
1.7 KiB
Plaintext
74 lines
1.7 KiB
Plaintext
# This trains a network with dense layers to distinguish the digits 4 and 9.
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.
|
|
|
|
import ml
|
|
import math
|
|
import re
|
|
import util
|
|
|
|
# Configure the compiler from this program's command-line arguments:
# general options first, then the sfix precision.  Index checks on
# multi-dimensional arrays are switched off afterwards.
program.options_from_args()
sfix.set_precision_from_args(program)
MultiArray.disable_index_checks()

# Number of training examples, number of test examples, and number of
# input features per example (28 * 28).
n_examples = 11791
n_test = 1991
n_features = 28 ** 2

# program.args[1] sets the number of epochs when it is an integer.  Fall
# back to the default when it is missing (IndexError) or is not a number,
# e.g. an option such as "2dense" in that position (ValueError).  Only
# these two cases are handled so that unrelated errors stay visible.
try:
    n_epochs = int(program.args[1])
except (IndexError, ValueError):
    n_epochs = 100

N = n_examples
batch_size = 128

# A batch cannot be larger than the training set.
assert batch_size <= N
|
|
|
|
# program.args[2] sets the thread count of the ml module when it is an
# integer.  Only the parsing is guarded: a missing or non-numeric argument
# leaves the module's default untouched, while an error raised by
# ml.set_n_threads itself is no longer silently discarded.
try:
    requested_threads = int(program.args[2])
except (IndexError, ValueError):
    pass
else:
    ml.set_n_threads(requested_threads)
|
|
|
|
# Width of every hidden dense layer.
n_inner = 128

# The depth is given by an argument of the form "<k>dense", e.g. "2dense".
# If several arguments match, the last one wins.  Requiring digits before
# "dense" keeps a bare "dense" (or any non-numeric prefix) from crashing in
# int(); such arguments are ignored here.
n_dense_layers = None
for arg in program.args:
    m = re.match(r'(\d+)dense', arg)
    if m:
        n_dense_layers = int(m.group(1))

# Test for the missing case first: comparing None with an integer raises
# TypeError in Python 3, which used to hide the intended error below.
if n_dense_layers is None or n_dense_layers < 1:
    raise CompilerError('number of dense layers not specified')

if n_dense_layers == 1:
    # Single layer mapping the features directly to one output.
    layers = [ml.Dense(N, n_features, 1, activation='id')]
else:
    # Input layer, (n_dense_layers - 2) hidden layers, and a final layer
    # with a single output and identity activation.
    layers = [ml.Dense(N, n_features, n_inner, activation='relu')]
    for i in range(n_dense_layers - 2):
        layers += [ml.Dense(N, n_inner, n_inner, activation='relu')]
    layers += [ml.Dense(N, n_inner, 1, activation='id')]

# The output layer is configured from the program arguments.
layers += [ml.Output.from_args(N, program)]
|
|
|
|
# Containers sized by n_test: one secret integer label per example and a
# secret fixed-point matrix with one row of features per example.
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)

# Inputs are skipped only when both 'no_acc' and 'no_loss' were given.
need_inputs = 'no_acc' not in program.args or 'no_loss' not in program.args
if need_inputs:
    # Everything is read from party 0.  The order of these calls is the
    # order in which the values are consumed from that party's input:
    # labels and samples held by the layers first, then Y and X.
    layers[-1].Y.input_from(0)
    layers[0].X.input_from(0)
    Y.input_from(0)
    X.input_from(0)
|
|
|
|
# The optimizer is chosen and configured from the program arguments.
sgd = ml.Optimizer.from_args(program, layers)

# 'no_out' removes the last layer from the optimizer's layer list.
if 'no_out' in program.args:
    del sgd.layers[-1]

# Single-step modes, in order of precedence, each paired with the extra
# positional arguments its optimizer method takes before the batch.
single_steps = (
    ('forward', ()),
    ('backward', ()),
    ('update', (0,)),
)

for step, extra in single_steps:
    if step in program.args:
        # Run just this one step on a freshly allocated batch index array.
        getattr(sgd, step)(*extra, batch=regint.Array(batch_size))
        break
else:
    # No single-step mode requested: run the optimizer as directed by the
    # program arguments, using X and Y as the test data.
    sgd.run_by_args(program, n_epochs, batch_size, X, Y)
|