Files
MP-SPDZ/Programs/Source/mnist_49.mpc
Marcel Keller 32950fe8d4 Maintenance.
2021-11-04 16:24:34 +11:00

74 lines
1.7 KiB
Plaintext

# This trains a network of dense layers to distinguish the MNIST digits 4 and 9.
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.
import ml
import math
import re
import util
# Compile-time configuration: take compiler options and the sfix precision
# from the command-line arguments, and turn off MultiArray index checks.
program.options_from_args()
sfix.set_precision_from_args(program)
MultiArray.disable_index_checks()

# Data set dimensions (see the data preparation link at the top of the file).
n_examples = 11791
n_test = 1991
n_features = 28 ** 2

# Optional first argument: number of epochs.  Fall back to 100 when the
# argument is absent (IndexError) or is not an integer (ValueError), e.g.
# because an option keyword was given in that position.  Only these two
# errors are caught so that interrupts and genuine failures still surface.
try:
    n_epochs = int(program.args[1])
except (IndexError, ValueError):
    n_epochs = 100

N = n_examples
batch_size = 128

# Sanity check on the constants above: a batch cannot exceed the data set.
assert batch_size <= N

# Optional second argument: number of threads handed to the ml module.
# When it is absent or not an integer, the ml module's setting is left alone.
try:
    ml.set_n_threads(int(program.args[2]))
except (IndexError, ValueError):
    pass
# Width of the hidden dense layers.
n_inner = 128

# The number of dense layers is given by an argument of the form "<k>dense"
# (for example "2dense").  If several arguments match, the last one wins.
n_dense_layers = None
for arg in program.args:
    m = re.match('(.*)dense', arg)
    if m:
        n_dense_layers = int(m.group(1))

# Without a "<k>dense" argument the count is still None.  Comparing None with
# an integer raises TypeError in Python 3, so report the intended compiler
# error explicitly before any numeric comparison takes place.
if n_dense_layers is None:
    raise CompilerError('number of dense layers not specified')

if n_dense_layers == 1:
    # Single layer mapping the features directly to one output.
    layers = [ml.Dense(N, n_features, 1, activation='id')]
elif n_dense_layers > 1:
    # Input layer, (n_dense_layers - 2) hidden layers, then one output unit.
    layers = [ml.Dense(N, n_features, n_inner, activation='relu')]
    for i in range(n_dense_layers - 2):
        layers.append(ml.Dense(N, n_inner, n_inner, activation='relu'))
    layers.append(ml.Dense(N, n_inner, 1, activation='id'))
else:
    # Zero or negative counts are rejected with the same error as before.
    raise CompilerError('number of dense layers not specified')

# The output layer is chosen from the command-line arguments.
layers.append(ml.Output.from_args(N, program))
# Containers for the test set: n_test labels and an
# n_test x n_features sample matrix.
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)

# Input is read from party 0 unless both 'no_acc' and 'no_loss' are given.
# The order of the four reads is significant and must not be changed:
# output-layer Y, first-layer X (presumably the training labels and
# samples -- confirm against the data preparation), then test Y and test X.
if 'no_acc' not in program.args or 'no_loss' not in program.args:
    layers[-1].Y.input_from(0)
    layers[0].X.input_from(0)
    Y.input_from(0)
    X.input_from(0)
# Build the optimizer for the layer stack from the command-line arguments.
sgd = ml.Optimizer.from_args(program, layers)

# 'no_out' removes the last entry of the optimizer's layer list
# (presumably the output layer appended last -- confirm in ml.Optimizer).
if 'no_out' in program.args:
    del sgd.layers[-1]

# A single phase can be requested with 'forward', 'backward' or 'update'
# (checked in that order); without any of them the full run is executed.
single_phase_requested = (
    'forward' in program.args
    or 'backward' in program.args
    or 'update' in program.args
)
if not single_phase_requested:
    sgd.run_by_args(program, n_epochs, batch_size, X, Y)
elif 'forward' in program.args:
    sgd.forward(batch=regint.Array(batch_size))
elif 'backward' in program.args:
    sgd.backward(batch=regint.Array(batch_size))
else:
    # Only 'update' can be left at this point.
    sgd.update(0, batch=regint.Array(batch_size))