mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-04-20 03:01:31 -04:00
84 lines
2.2 KiB
Plaintext
84 lines
2.2 KiB
Plaintext
import ml
|
|
import util
|
|
import math
|
|
import sys
|
|
|
|
# At least the network letter must be supplied; program.args[0] is what the
# usage line prints as the program name.
if len(program.args) < 2:
    print('Usage: %s <net> <n_threads>' % program.args[0],
          file=sys.stderr)
    print('<net> refers to the letter naming in SecureNN.', file=sys.stderr)
    # Use sys.exit() rather than the site-provided exit(): the latter is a
    # convenience for interactive sessions and is not guaranteed to exist.
    sys.exit(1)
|
|
|
|
# Pick up compiler options passed as program arguments, then switch on the
# `cisc` option regardless of what was passed.
program.options_from_args()
program.options.cisc = True

# The thread count is optional. A missing third argument raises IndexError and
# a non-numeric one (for instance the 'full' flag) raises ValueError; both
# fall back to None. Anything else (e.g. KeyboardInterrupt) is allowed to
# propagate instead of being silently swallowed.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    n_threads = None

# Class-level settings shared by every layer built afterwards.
ml.Layer.n_threads = n_threads
ml.FixConv2d.use_conv2ds = True
|
|
|
|
# The first precision argument is always 12; the second is 63 when 'full'
# appears anywhere among the program arguments and 31 otherwise.
# NOTE(review): presumably (fractional bits, total bits) -- confirm against
# the sfix.set_precision documentation.
sfix.set_precision(12, 63 if 'full' in program.args else 31)
|
|
|
|
# Each builder constructs its layers only when called, so just the network
# selected on the command line is ever instantiated.

def _network_a():
    """Return the layers of network A: dense layers with square activations."""
    return [
        ml.Dense(1, 784, 128),
        ml.Square([1, 128]),
        ml.Dense(1, 128, 128),
        ml.Square([1, 128]),
        ml.Dense(1, 128, 10),
        ml.Argmax((1, 10)),
    ]


def _network_b():
    """Return the layers of network B: two 16-channel conv stages, two dense."""
    return [
        ml.FixConv2d([1, 28, 28, 1], (16, 5, 5, 1), (16,), [1, 24, 24, 16], (1, 1), 'VALID'),
        ml.MaxPool([1, 24, 24, 16]),
        ml.Relu([1, 12, 12, 16]),
        ml.FixConv2d([1, 12, 12, 16], (16, 5, 5, 16), (16,), [1, 8, 8, 16], (1, 1), 'VALID'),
        ml.MaxPool([1, 8, 8, 16]),
        ml.Relu([1, 4, 4, 16]),
        ml.Dense(1, 256, 100),
        ml.Relu([1, 100]),
        ml.Dense(1, 100, 10),
        ml.Argmax((1, 10)),
    ]


def _network_c():
    """Return the layers of network C: 20- and 50-channel conv stages, two dense."""
    return [
        ml.FixConv2d([1, 28, 28, 1], (20, 5, 5, 1), (20,), [1, 24, 24, 20], (1, 1), 'VALID'),
        ml.MaxPool([1, 24, 24, 20]),
        ml.Relu([1, 12, 12, 20]),
        ml.FixConv2d([1, 12, 12, 20], (50, 5, 5, 20), (50,), [1, 8, 8, 50], (1, 1), 'VALID'),
        ml.MaxPool([1, 8, 8, 50]),
        ml.Relu([1, 4, 4, 50]),
        ml.Dense(1, 800, 500),
        ml.Relu([1, 500]),
        ml.Dense(1, 500, 10),
        ml.Argmax((1, 10)),
    ]


def _network_d():
    """Return the layers of network D: one stride-2 conv stage, two dense."""
    return [
        ml.FixConv2d([1, 28, 28, 1], (5, 5, 5, 1), (5,), [1, 14, 14, 5], (2, 2)),
        ml.Relu([1, 14, 14, 5]),
        ml.Dense(1, 980, 100),
        ml.Relu([1, 100]),
        ml.Dense(1, 100, 10),
        ml.Argmax((1, 10)),
    ]


# Maps the <net> letter given as the first program argument to its builder.
_NETWORK_BUILDERS = {
    'A': _network_a,
    'B': _network_b,
    'C': _network_c,
    'D': _network_d,
}

_build_layers = _NETWORK_BUILDERS.get(program.args[1])
if _build_layers is None:
    raise Exception('unknown network: ' + program.args[1])
layers = _build_layers()
|
|
|
|
# Wrap the selected layers in an optimizer object so a forward pass can be run.
opt = ml.Optimizer()
opt.layers = layers

# Every layer reads its input with argument 0, while the first layer's X is
# read with argument 1.
# NOTE(review): presumably the argument is the party number, i.e. party 0
# provides the model and party 1 the sample -- confirm against Compiler.ml.
for net_layer in layers:
    net_layer.input_from(0)
first_layer = layers[0]
first_layer.X.input_from(1)

# Only the forward pass itself is measured, under timer 1.
start_timer(1)
opt.forward(1)
stop_timer(1)

# Reveal the first entry of the final layer's output and print it.
revealed_guess = layers[-1].Y[0].reveal()
print_ln('guess %s', revealed_guess)
|