Files
MP-SPDZ/Programs/Source/keras_mnist_lenet_predict.mpc
Marcel Keller 78fe3d8bad Maintenance.
2024-07-09 12:19:52 +10:00

46 lines
1.2 KiB
Plaintext

# this runs prediction with LeNet (including a dropout layer) on MNIST; no training happens here
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
program.options_from_args()

# Shapes of the data held by this script: a single 28x28 sample and one
# 10-entry label row.
n_test = 1
image_shape = [28, 28]
n_classes = 10

# The training arrays are not referenced anywhere below and stay disabled.
# training_samples = MultiArray([60000, 28, 28], sfix)
# training_labels = MultiArray([60000, 10], sint)
test_samples = MultiArray([n_test] + image_shape, sfix)
test_labels = MultiArray([n_test, n_classes], sint)

# Input reading is switched off in this version; re-enable the relevant
# lines to load data via input_from(0).
# training_labels.input_from(0)
# training_samples.input_from(0)
# test_labels.input_from(0)
# test_samples.input_from(0)

from Compiler import ml
tf = ml
keras_layers = tf.keras.layers

# LeNet-style stack: two convolution + max-pooling stages, then flatten,
# dropout and two dense layers ending in a 10-way softmax.
layers = [
    keras_layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
    keras_layers.MaxPooling2D(2),
    keras_layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
    keras_layers.MaxPooling2D(2),
    keras_layers.Flatten(),
    keras_layers.Dropout(0.5),
    keras_layers.Dense(500, activation='relu'),
    keras_layers.Dense(n_classes, activation='softmax'),
]

model = tf.keras.models.Sequential(layers)
model.build(test_samples.sizes)

# 'start' is only consumed by the disabled read_from_file line below, where
# it is threaded from one variable to the next.
start = 0
for param in model.trainable_variables:
    # Every trainable variable is set to zero in this version.
    param.assign_all(0)
    # activate to use the model output by keras_mnist_lenet
    # start = param.read_from_file(start)

guesses = model.predict(test_samples)

# Reveal the predictions and the labels, printing at most the first three
# entries of each.
revealed_guesses = guesses.reveal_nested()
print_ln('guess %s', revealed_guesses[:3])
revealed_truth = test_labels.reveal_nested()
print_ln('truth %s', revealed_truth[:3])