mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
63 lines
1.7 KiB
Plaintext
63 lines
1.7 KiB
Plaintext
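# this trains LeNet on CIFAR-10 with a dropout layer
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
# (NOTE: that repository is named for MNIST; the inputs read below are
# CIFAR-10-shaped, i.e. 32x32x3 samples with 10 classes)
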
# Let the program object process its arguments before anything else runs.
program.options_from_args()

from Compiler import ml

# Alias so the Keras-style interface below can be written as tf.keras.*
tf = ml

# Default thread count for the ML library.
ml.set_n_threads(36)

# Optional override: use program.args[1] as the thread count when it is
# present and parses as an integer.
try:
    ml.set_n_threads(int(program.args[1]))
except (IndexError, ValueError):
    # No such argument, or it is not a number (for example the 'torch'
    # flag checked further down) -- keep the default set above.
    pass

# Obtain the training and test sets, either straight from torchvision
# (when 'torch' is among the program arguments) or from input that has
# been prepared beforehand.
if 'torch' in program.args:
    import torchvision, numpy

    splits = []
    # training split first, then the test split
    for is_train in (True, False):
        dataset = torchvision.datasets.CIFAR10(
            root='/tmp', train=is_train, download=True)
        # rescale pixel values (assumed to be 0..255) to [-1, 1]
        scaled = dataset.data / 255 * 2 - 1
        samples = sfix.input_tensor_via(0, scaled, binary=True)
        labels = sint.input_tensor_via(
            0, dataset.targets, binary=True, one_hot=True)
        splits.append((labels, samples))

    (training_labels, training_samples), (test_labels, test_samples) = splits
else:
    # 50000 training and 10000 test samples of 32x32x3 each,
    # with one label vector of length 10 per sample
    training_samples = MultiArray([50000, 32, 32, 3], sfix)
    training_labels = MultiArray([50000, 10], sint)

    test_samples = MultiArray([10000, 32, 32, 3], sfix)
    test_labels = MultiArray([10000, 10], sint)

    # Read order matters: labels before samples, training before test.
    for labels, samples in (
            (training_labels, training_samples),
            (test_labels, test_samples)):
        labels.input_from(0)
        samples.input_from(0)

# LeNet-style network: two convolution + max-pooling stages followed by
# dropout and two dense layers, the last one with softmax over 10 outputs.
keras_layers = tf.keras.layers

layers = []
for n_filters in (20, 50):
    layers.append(
        keras_layers.Conv2D(n_filters, 5, 1, 'valid', activation='relu'))
    layers.append(keras_layers.MaxPooling2D(2))

layers += [
    keras_layers.Flatten(),
    keras_layers.Dropout(0.5),
    keras_layers.Dense(500, activation='relu'),
    keras_layers.Dense(10, activation='softmax'),
]

model = tf.keras.models.Sequential(layers)

# Adam with the AMSGrad option switched on
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)

# Train for 10 epochs in batches of 128, passing the test set as
# validation data.
fit_options = dict(
    epochs=10,
    batch_size=128,
    validation_data=(test_samples, test_labels),
)
opt = model.fit(training_samples, training_labels, **fit_options)