Files
MP-SPDZ/Programs/Source/mnist_A.mpc
Marcel Keller 78fe3d8bad Maintenance.
2024-07-09 12:19:52 +10:00

90 lines
1.8 KiB
Plaintext

# Trains a network of dense layers to tell MNIST digits 0 and 1 apart
# (the 'gisette' argument compiles for the 4/9 data set sizes instead).
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.
import ml
import math
#ml.report_progress = True
# Derive compiler options from the program arguments.
program.options_from_args()

# Passed on to the output layer as its ``approx`` setting.
approx = 3

# Pick the number of training (N) and test (n_test) examples. The first
# matching keyword wins, in the order profile > debug > gisette.
if 'profile' in program.args:
    print('Compiling for profiling')
    N, n_test = 1000, 100
elif 'debug' in program.args:
    N, n_test = 10, 10
elif 'gisette' in program.args:
    print('Compiling for 4/9')
    N, n_test = 11791, 1991
else:
    N, n_test = 12665, 2115

n_examples = N
# 28 x 28 input values per example (presumably the pixels of one image).
n_features = 28 * 28
# Optional positional arguments: [1] number of epochs, [2] batch size,
# [3] number of threads. A missing or non-numeric argument (for example a
# keyword such as 'debug' in that position) selects the default instead.
# Only IndexError/ValueError are caught so that unrelated failures and
# interrupts are not silently swallowed.
try:
    n_epochs = int(program.args[1])
except (IndexError, ValueError):
    n_epochs = 100

try:
    batch_size = int(program.args[2])
except (IndexError, ValueError):
    # Default: one batch spanning all N training examples.
    batch_size = N

assert batch_size <= N, \
    'batch size %d exceeds the %d training examples' % (batch_size, N)
ml.Layer.back_batch_size = batch_size

try:
    ml.set_n_threads(int(program.args[3]))
except (IndexError, ValueError):
    # No usable thread count given: leave the ml module's setting untouched.
    pass
# Width of the hidden layers; a debug build also shrinks the input width.
if 'debug' in program.args:
    n_inner = n_features = 10
else:
    n_inner = 128

# 'norelu' swaps the 'relu' activation for 'id' (presumably the identity).
activation = 'id' if 'norelu' in program.args else 'relu'

# Two hidden dense layers, a dense layer with a single output, and the
# output layer, all sized for N training examples.
layers = [
    ml.Dense(N, n_features, n_inner, activation=activation),
    ml.Dense(N, n_inner, n_inner, activation=activation),
    ml.Dense(N, n_inner, 1),
    ml.Output(N, approx=approx),
]

# '2dense' removes the second hidden layer after construction.
if '2dense' in program.args:
    layers.pop(1)

# Training data is read via input_from(0): labels first, then features.
layers[-1].Y.input_from(0)
layers[0].X.input_from(0)

# Test data follows in the same order (labels, then features).
Y = sint.Array(n_test)
X = sfix.Matrix(n_test, n_features)
Y.input_from(0)
X.input_from(0)

# NOTE(review): the second argument (10) is presumably the number of epochs
# per sgd.run() call, matching the ceil(n_epochs / 10) rounds of the training
# loop -- confirm against ml.SGD.
sgd = ml.SGD(layers, 10, report_loss=True)
sgd.reset()
# Train in ceil(n_epochs / 10) rounds with one sgd.run() call per round, and
# report accuracy on the training and test sets after each round.
# NOTE(review): the accuracy reports are assumed to belong to the loop body;
# confirm against the upstream file's indentation.
@for_range(int(math.ceil(n_epochs / 10)))
def _(i):
    # Timer 1 brackets the training call only, not the evaluation.
    start_timer(1)
    sgd.run(batch_size)
    stop_timer(1)

    # Accuracy on the training set held by the first and last layers.
    train_correct, train_loss = sgd.reveal_correctness(
        layers[0].X, layers[-1].Y)
    print_ln('train_acc: %s (%s/%s)',
             cfix(train_correct) / N, train_correct, N)

    # Accuracy on the separately loaded test set.
    test_correct, test_loss = sgd.reveal_correctness(X, Y)
    print_ln('acc: %s (%s/%s)',
             cfix(test_correct) / n_test, test_correct, n_test)