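# MP-SPDZ/Programs/Source/benchmark_net.mpc
# Last change: Marcel Keller, commit c597554af9 "ATLAS." (2021-08-06 18:25:27 +10:00)
#
# Benchmarks secure inference of one sample through one of the SecureNN
# networks A-D, selected on the command line.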

import ml
import util
import math
import sys
# The network letter is mandatory; the thread count and the 'full' flag are
# optional (see the parsing below).  Abort with a usage message otherwise.
if len(program.args) < 2:
    print('Usage: %s <net> [<n_threads>] [full]' % program.args[0],
          file=sys.stderr)
    print('<net> refers to the letter naming in SecureNN.', file=sys.stderr)
    # sys.exit instead of the site-provided exit(), which is intended for the
    # interactive interpreter and is missing when Python runs with -S.
    sys.exit(1)
# Apply compiler options given on the command line, then force CISC mode.
# NOTE(review): presumably the compiler's CISC instruction-merging mode --
# confirm against the MP-SPDZ compiler options.
program.options_from_args()
program.options.cisc = True

# Optional third argument: number of threads.  If it is missing, or is not an
# integer (e.g. 'full' passed in that position), fall back to None.  Only these
# two failures are expected; anything else should propagate.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    n_threads = None
# Class-level settings, so they apply to every layer built afterwards.
ml.Layer.n_threads = n_threads
ml.FixConv2d.use_conv2ds = True

# 'full' anywhere among the arguments selects the wider fixed-point setting.
# NOTE(review): set_precision presumably takes (fractional bits, total bits).
if 'full' in program.args:
    sfix.set_precision(12, 63)
else:
    sfix.set_precision(12, 31)
# One builder per SecureNN network letter.  Builders are lambdas so that only
# the requested network's layers are actually constructed.
_NETWORK_BUILDERS = {
    # Fully connected 784-128-128-10 with square activations.
    'A': lambda: [
        ml.Dense(1, 784, 128),
        ml.Square([1, 128]),
        ml.Dense(1, 128, 128),
        ml.Square([1, 128]),
        ml.Dense(1, 128, 10),
        ml.Argmax((1, 10)),
    ],
    # Two 5x5 VALID conv / max-pool / ReLU stages (16 channels), then
    # dense 256-100-10.
    'B': lambda: [
        ml.FixConv2d([1, 28, 28, 1], (16, 5, 5, 1), (16,), [1, 24, 24, 16],
                     (1, 1), 'VALID'),
        ml.MaxPool([1, 24, 24, 16]),
        ml.Relu([1, 12, 12, 16]),
        ml.FixConv2d([1, 12, 12, 16], (16, 5, 5, 16), (16,), [1, 8, 8, 16],
                     (1, 1), 'VALID'),
        ml.MaxPool([1, 8, 8, 16]),
        ml.Relu([1, 4, 4, 16]),
        ml.Dense(1, 256, 100),
        ml.Relu([1, 100]),
        ml.Dense(1, 100, 10),
        ml.Argmax((1, 10)),
    ],
    # Like B but with 20 and 50 channels, then dense 800-500-10.
    'C': lambda: [
        ml.FixConv2d([1, 28, 28, 1], (20, 5, 5, 1), (20,), [1, 24, 24, 20],
                     (1, 1), 'VALID'),
        ml.MaxPool([1, 24, 24, 20]),
        ml.Relu([1, 12, 12, 20]),
        ml.FixConv2d([1, 12, 12, 20], (50, 5, 5, 20), (50,), [1, 8, 8, 50],
                     (1, 1), 'VALID'),
        ml.MaxPool([1, 8, 8, 50]),
        ml.Relu([1, 4, 4, 50]),
        ml.Dense(1, 800, 500),
        ml.Relu([1, 500]),
        ml.Dense(1, 500, 10),
        ml.Argmax((1, 10)),
    ],
    # A single 5x5 stride-2 convolution (default padding; 28 -> 14 suggests
    # 'SAME' -- NOTE(review): confirm), then dense 980-100-10.
    'D': lambda: [
        ml.FixConv2d([1, 28, 28, 1], (5, 5, 5, 1), (5,), [1, 14, 14, 5],
                     (2, 2)),
        ml.Relu([1, 14, 14, 5]),
        ml.Dense(1, 980, 100),
        ml.Relu([1, 100]),
        ml.Dense(1, 100, 10),
        ml.Argmax((1, 10)),
    ],
}

net_name = program.args[1]
if net_name not in _NETWORK_BUILDERS:
    raise Exception('unknown network: ' + net_name)
layers = _NETWORK_BUILDERS[net_name]()
# Attach the selected layers to an optimizer so its forward pass can be run.
opt = ml.Optimizer()
opt.layers = layers
# Party 0 supplies every layer's data (NOTE(review): presumably the trained
# weights/biases -- confirm in ml.Layer.input_from).  Party 1 supplies the
# input of the first layer, i.e. the sample to classify.  The order of these
# calls is kept as-is on purpose.
for layer in layers:
    layer.input_from(0)
layers[0].X.input_from(1)
# Time only the forward pass under timer 1.
# NOTE(review): the argument 1 is presumably the number of samples.
start_timer(1)
opt.forward(1)
stop_timer(1)
# Reveal and print the first output entry of the last layer.  Every network
# above ends in an Argmax layer, so this is the predicted class.
print_ln('guess %s', layers[-1].Y[0].reveal())