Files
MP-SPDZ/Programs/Source/torch_mnist_dense.mpc
2025-05-30 13:35:02 +10:00

59 lines
1.4 KiB
Plaintext

# Train a dense (fully connected) neural network on MNIST.
#
# Compile-time arguments read further down in this script:
# program.args[1] is the number of epochs and program.args[2] is the
# number of threads.
#
# NOTE(review): going by their names, the two calls below configure the
# compilation (compiler options and sfix precision) from the arguments --
# confirm the accepted options in the MP-SPDZ documentation.
program.options_from_args()
sfix.set_precision_from_args(program)
import torchvision
# Input both MNIST splits via party 0: the training split first, then the
# test split. Within each split the samples are input before the labels.
data = []
for is_training in (True, False):
    ds = torchvision.datasets.MNIST(
        root='/tmp',
        train=is_training,
        download=True,
    )
    # normalize to [0,1] before input
    samples = sfix.input_tensor_via(0, ds.data / 255)
    # labels are input as one-hot vectors
    labels = sint.input_tensor_via(0, ds.targets, one_hot=True)
    data.append((labels, samples))

# each entry of data is a (labels, samples) pair
training_set, test_set = data
training_labels, training_samples = training_set
test_labels, test_samples = test_set
import torch
import torch.nn as nn
# Plaintext PyTorch model: a flattened 28x28 input goes through two hidden
# linear layers of width 128, each followed by ReLU, and a final linear
# layer with 10 outputs.
n_pixels = 28 * 28
n_hidden = 128
n_classes = 10
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(n_pixels, n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden, n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden, n_classes),
)
# test network: run the first batch of MNIST through the plaintext model
# before it is converted for secure training
ds = torchvision.datasets.MNIST(
    root='/tmp',
    transform=torchvision.transforms.ToTensor(),
)
loader = torch.utils.data.DataLoader(ds)
first_batch = next(iter(loader))
# element 0 of the batch is what gets fed to the network
inputs = first_batch[0]
print(inputs.shape)
outputs = net(inputs)
from Compiler import ml
# the number of threads is read from program.args[2]
n_threads = int(program.args[2])
ml.set_n_threads(n_threads)

# convert the PyTorch model to MP-SPDZ layers; the same value 128 is passed
# here and as the training batch size below
batch_size = 128
layers = ml.layers_from_torch(net, training_samples.shape, batch_size)

optimizer = ml.SGD(layers)

# the number of epochs is read from program.args[1]
n_epochs = int(program.args[1])
optimizer.fit(
    training_samples,
    training_labels,
    epochs=n_epochs,
    batch_size=batch_size,
    validation_data=(test_samples, test_labels),
    program=program,
)
# store secret model for use in torch_mnist_dense_test:
# every trainable variable of the optimizer is written to file
for trainable in optimizer.trainable_variables:
    trainable.write_to_file()

# output to be used in Scripts/torch_mnist_lenet_import.py
# NOTE(review): going by the method name, this reveals the trained model
# and emits it in binary form -- confirm the output location in the docs
optimizer.reveal_model_to_binary()