mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 21:48:11 -05:00
90 lines
1.8 KiB
Plaintext
90 lines
1.8 KiB
Plaintext
# Train a network of dense layers to tell two MNIST digit classes apart
# (0 vs. 1 by default).
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.

import ml
import math

# Left disabled; uncomment to set ml.report_progress.
#ml.report_progress = True

program.options_from_args()

# Forwarded to ml.Output as its `approx` argument.
approx = 3

# Dataset sizes are fixed at compile time; a keyword anywhere in the
# program arguments selects one of the presets below.
if 'profile' in program.args:
    print('Compiling for profiling')
    N = 1000
    n_test = 100
elif 'debug' in program.args:
    N = 10
    n_test = 10
elif 'gisette' in program.args:
    # presumably the 4-vs-9 digit split, as the message says -- confirm
    # the sample counts against the data preparation scripts
    print('Compiling for 4/9')
    N = 11791
    n_test = 1991
else:
    N = 12665
    n_test = 2115

n_examples = N
n_features = 28 ** 2

# Optional positional arguments: <n_epochs> <batch_size> <n_threads>.
# A missing argument (IndexError) or a non-numeric one, such as a keyword
# like 'debug' sitting in that position (ValueError), selects the default.
# Only these two exceptions are caught so that unrelated errors surface.
try:
    n_epochs = int(program.args[1])
except (IndexError, ValueError):
    n_epochs = 100

try:
    batch_size = int(program.args[2])
except (IndexError, ValueError):
    batch_size = N

assert batch_size <= N
ml.Layer.back_batch_size = batch_size

try:
    n_threads = int(program.args[3])
except (IndexError, ValueError):
    # No thread count given: leave the ml module's setting untouched.
    pass
else:
    ml.set_n_threads(n_threads)

# 'debug' shrinks both the hidden-layer width and the feature count.
if 'debug' in program.args:
    n_inner = 10
    n_features = 10
else:
    n_inner = 128

# 'norelu' swaps the hidden-layer activation from 'relu' to 'id'.
if 'norelu' in program.args:
    activation = 'id'
else:
    activation = 'relu'

# Network: two hidden dense layers of width n_inner using the selected
# activation, a dense layer with a single output, and the output layer.
layers = [
    ml.Dense(N, n_features, n_inner, activation=activation),
    ml.Dense(N, n_inner, n_inner, activation=activation),
    ml.Dense(N, n_inner, 1),
    ml.Output(N, approx=approx),
]

# '2dense' removes the second hidden layer.
if '2dense' in program.args:
    del layers[1]

# Training labels, then training samples, are input from party 0.
# The order of the input_from calls is kept exactly as in the original.
layers[-1].Y.input_from(0)
layers[0].X.input_from(0)

# Test labels and samples, input from party 0 after the training data.
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)

# The second argument (10) has to match the divisor used when computing
# the number of iterations of the training loop.
# NOTE(review): presumably this is the number of epochs per sgd.run() call
# -- confirm against the ml.SGD documentation.
sgd = ml.SGD(layers, 10, report_loss=True)
sgd.reset()

# Training loop: ceil(n_epochs / 10) iterations, each timing one
# sgd.run(batch_size) call with timer 1 and then revealing the accuracy on
# the training data and on the test data.
# NOTE(review): indentation was lost in this copy of the file. Both
# accuracy reports are placed inside the loop body because splitting the
# training into chunks of 10 only serves periodic reporting -- confirm
# against the upstream source.
@for_range(int(math.ceil(n_epochs / 10)))
def _(i):
    start_timer(1)
    sgd.run(batch_size)
    stop_timer(1)

    # Accuracy on the training data held by the first and last layers.
    n_correct, loss = sgd.reveal_correctness(layers[0].X, layers[-1].Y)
    print_ln('train_acc: %s (%s/%s)', cfix(n_correct) / N, n_correct, N)

    # Accuracy on the separate test data.
    n_correct, loss = sgd.reveal_correctness(X, Y)
    print_ln('acc: %s (%s/%s)', cfix(n_correct) / n_test, n_correct, n_test)