# this trains network C (LeNet) from SecureNN
# see https://github.com/csiro-mlai/mnist-mpc for data preparation

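# Command-line arguments (all optional), assuming the usual MP-SPDZ workflow
# of appending them to the compile.py invocation after the program name:
#   program.args[1]  number of epochs (default 100)
#   program.args[2]  batch size (default min(N, 128))
#   program.args[3]  number of threads
# Flag-style arguments such as 'profile', 'debug*', 'savemem', 'batchnorm',
# 'dropout*', 'no_relu', 'no_acc' and 'no_loss' toggle the variants below;
# further options are consumed by sfix.set_precision_from_args and
# ml.Optimizer.from_args.
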
import ml
import math
import re
import util

# read compile-time options from the command-line arguments
program.options_from_args()
# set the fixed-point precision from the arguments, adapting the ring size if needed
sfix.set_precision_from_args(program, adapt_ring=True)
# disable bounds checks on MultiArray accesses
MultiArray.disable_index_checks()

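# Select the number of training (N) and test (n_test) examples; the debug and
# profile variants use reduced sizes so compilation and test runs stay fast.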
if 'profile' in program.args:
    print('Compiling for profiling')
    N = 1000
    n_test = 100
elif 'debug' in program.args:
    N = 100
    n_test = 100
elif 'debug1000' in program.args:
    N = 1000
    n_test = 1000
elif 'debug5000' in program.args:
    N = 5000
    n_test = 5000
else:
    N = 60000
    n_test = 10000

n_examples = N
n_features = 28 ** 2  # 28x28 grayscale MNIST images

try:
    n_epochs = int(program.args[1])
except:
    n_epochs = 100

try:
    batch_size = int(program.args[2])
except:
    batch_size = min(N, 128)

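# 'savemem' shrinks the layer buffers to a single batch; otherwise N (the
# buffer size used in the layer shapes below) is capped at max(1000, batch_size),
# while n_examples keeps the full dataset size for the input and output layers.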
if 'savemem' in program.args:
    N = batch_size
else:
    N = min(N, max(1000, batch_size))

try:
    ml.set_n_threads(int(program.args[3]))
except:
    pass

ml.Layer.back_batch_size = batch_size

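# LeNet shape flow (per buffer of N examples):
#   28x28x1 --conv 5x5, 20 maps--> 24x24x20 --maxpool 2x2--> 12x12x20
#   --conv 5x5, 50 maps--> 8x8x50 --maxpool 2x2--> 4x4x50 = 800
#   --dense--> 500 --dense--> 10 class scores
# The first layer takes the full dataset (n_examples) as input and writes
# batches of size N to its output buffer.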
layers = [
    ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'),
    ml.MaxPool([N, 24, 24, 20]),
    ml.Relu([N, 12, 12, 20]),
    ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'),
    ml.MaxPool([N, 8, 8, 50]),
    ml.Relu([N, 4, 4, 50]),
    ml.Dense(N, 800, 500),
    ml.Relu([N, 500]),
    ml.Dense(N, 500, 10),
]

layers += [ml.MultiOutput.from_args(program, n_examples, 10)]

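# 'batchnorm' inserts a BatchNorm layer directly after each convolution (just
# before the corresponding MaxPool); inserting at the larger index first keeps
# the smaller index valid. The assert rules out combining it with dropout.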
if 'batchnorm' in program.args:
    for arg in program.args:
        assert not arg.startswith('dropout')
    layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args))
    layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args))

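# Optional dropout: index 8 puts it between Relu([N, 500]) and the final
# Dense(N, 500, 10); index 6 puts it on the 800-dim flattened conv output,
# before Dense(N, 800, 500). alpha presumably sets the dropout probability
# (note the power-of-two values).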
if 'dropout' in program.args or 'dropout2' in program.args:
    layers.insert(8, ml.Dropout(N, 500))
elif 'dropout.25' in program.args:
    layers.insert(8, ml.Dropout(N, 500, alpha=0.25))
elif 'dropout.125' in program.args:
    layers.insert(8, ml.Dropout(N, 500, alpha=0.125))

if 'dropout2' in program.args:
    layers.insert(6, ml.Dropout(N, 800, alpha=0.125))
elif 'dropout1' in program.args:
    layers.insert(6, ml.Dropout(N, 800, alpha=0.5))

if 'no_relu' in program.args:
    # drop all activation layers; rebuild the list rather than removing while
    # iterating, which would skip elements
    layers = [x for x in layers if not isinstance(x, ml.Relu)]

print(layers)

# test data: labels (one column per class) and flattened images
Y = sint.Matrix(n_test, 10)
X = sfix.Matrix(n_test, n_features)

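# Unless both 'no_acc' and 'no_loss' are given, the training and test data
# are read from player 0's input.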
if not ('no_acc' in program.args and 'no_loss' in program.args):
    layers[-1].Y.input_from(0)
    layers[0].X.input_from(0)
    Y.input_from(0)
    X.input_from(0)

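# Optimizer.from_args selects the training algorithm and its options from the
# remaining arguments; run_by_args then trains for n_epochs with the given
# batch size and evaluates accuracy/loss on the test data (X, Y).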
optim = ml.Optimizer.from_args(program, layers)
optim.summary()
optim.run_by_args(program, n_epochs, batch_size, X, Y, acc_batch_size=N)