# Source file: MP-SPDZ/Programs/Source/keras_mnist_lenet_avgpool.mpc
# Last commit: 6cc3fccef0 "Maintenance." by Marcel Keller,
#              2023-05-09 14:50:53 +10:00
# (Original listing: 73 lines, 1.9 KiB, plain text.)

# This trains LeNet on MNIST with average pooling, a dropout layer before
# the dense layers and optional batch normalization, using MP-SPDZ's
# Keras-like interface (Compiler.ml).
# See https://github.com/csiro-mlai/mnist-mpc for data preparation.
#
# Arguments inspected via program.args:
#   torch      - obtain MNIST through torchvision (downloaded to /tmp)
#                instead of reading pre-prepared input
#   batchnorm  - insert BatchNormalization after each convolution

program.options_from_args()

if 'torch' in program.args:
    # torchvision is only needed on this path, so it is imported here.
    import torchvision
    data = []
    # Load the training split first, then the test split.
    for train in True, False:
        ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
        # normalize to [0,1] before input
        samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
        labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
        data += [(labels, samples)]
    # Unpack only after the loop: data now holds one (labels, samples) pair
    # per split, in (train, test) order.
    (training_labels, training_samples), (test_labels, test_samples) = data
else:
    # Fixed MNIST dimensions: 60000 training / 10000 test images of 28x28,
    # with 10-wide label rows (presumably one-hot, matching the torch path).
    training_samples = sfix.Tensor([60000, 28, 28])
    training_labels = sint.Tensor([60000, 10])
    test_samples = sfix.Tensor([10000, 28, 28])
    test_labels = sint.Tensor([10000, 10])
    # All four tensors are read from party 0. The read order (training
    # labels, training samples, test labels, test samples) is significant.
    # NOTE(review): presumably it must match the layout produced by the
    # mnist-mpc preparation scripts; confirm before reordering.
    training_labels.input_from(0)
    training_samples.input_from(0)
    test_labels.input_from(0)
    test_samples.input_from(0)
from Compiler import ml
# Alias so the model definition reads like regular Keras code.
tf = ml

# Evaluate the batchnorm switch once; it applies after both convolutions.
use_batchnorm = 'batchnorm' in program.args

# Feature extractor: two ReLU convolutions (20 then 50 filters), each
# optionally followed by batch normalization and then average pooling.
# NOTE(review): positional args presumably follow Keras, i.e.
# Conv2D(filters, kernel_size, strides, padding) and AveragePooling2D(pool_size).
layers = [
    tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
]
if use_batchnorm:
    layers += [tf.keras.layers.BatchNormalization()]
layers += [
    tf.keras.layers.AveragePooling2D(2),
    tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
]
if use_batchnorm:
    layers += [tf.keras.layers.BatchNormalization()]
# Second pooling, then the classifier: flatten, dropout(0.5), a 500-unit
# ReLU layer and a 10-unit softmax output.
layers += [
    tf.keras.layers.AveragePooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(500, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
]

model = tf.keras.models.Sequential(layers)
# Adam optimizer with the AMSGrad variant enabled.
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)
# Train for 10 epochs with batches of 128 samples, passing the test split
# as validation data.
opt = model.fit(
    training_samples,
    training_labels,
    epochs=10,
    batch_size=128,
    validation_data=(test_samples, test_labels),
)

# After training, write out every trainable variable of the model.
for var in model.trainable_variables:
    var.write_to_file()