# Source: MP-SPDZ/Programs/Source/torch_mnist_dense_pretrain.mpc
# Last modified 2023-02-16 12:35:18 +11:00 (73 lines, 2.1 KiB, plaintext)

# Train a dense neural network on MNIST: it is first pretrained in plaintext
# with PyTorch, then training continues under MPC with MP-SPDZ's ml module.
program.options_from_args()
import torchvision


def _secret_mnist_split(train):
    """Input one MNIST split as secret tensors supplied by player 0.

    Pixel values are divided by 255 so the samples lie in [0, 1]; labels are
    input one-hot encoded.  Returns the pair (labels, samples).
    """
    split = torchvision.datasets.MNIST(root='/tmp', train=train,
                                       download=True)
    pixels = sfix.input_tensor_via(0, split.data / 255., binary=True)
    one_hot_labels = sint.input_tensor_via(0, split.targets, binary=True,
                                           one_hot=True)
    return one_hot_labels, pixels


# Inputs are provided training split first, then test split; within each
# split the samples are input before the labels.
(training_labels, training_samples), (test_labels, test_samples) = [
    _secret_mnist_split(train) for train in (True, False)]
import torch
import torch.nn as nn

# Dense classifier: flatten each 28x28 image, apply two hidden ReLU layers of
# width 128, and emit 10 output logits.  The modules are listed in order and
# unpacked into nn.Sequential, so they become its children 0..5.
hidden_width = 128
net = nn.Sequential(*[
    nn.Flatten(),
    nn.Linear(28 * 28, hidden_width), nn.ReLU(),
    nn.Linear(hidden_width, hidden_width), nn.ReLU(),
    nn.Linear(hidden_width, 10),
])
# Pretrain the network in plaintext PyTorch: a single pass over the MNIST
# training set in batches of 128, using Adam (AMSGrad variant) and
# cross-entropy loss.  ToTensor() scales pixels to [0, 1], matching the /255
# normalisation of the secret inputs.  `transform` is also used by the
# evaluation step that follows.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()])
# download=True is a no-op when the files already exist under /tmp, but it
# stops this step from silently relying on an earlier download.
ds = torchvision.datasets.MNIST(root='/tmp', transform=transform, train=True,
                                download=True)
optimizer = torch.optim.Adam(net.parameters(), amsgrad=True)
criterion = nn.CrossEntropyLoss()
net.train()  # explicit, although a freshly built module is already in train mode
# Unpack each (inputs, labels) batch directly: no unused index, and the
# module-level `data` name is not rebound.
for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
# Evaluate the plaintext-pretrained network on the MNIST test set and print
# its accuracy as a percentage.
with torch.no_grad():
    ds = torchvision.datasets.MNIST(root='/tmp', transform=transform,
                                    train=False, download=True)
    total = correct_classified = 0
    for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128):
        outputs = net(inputs)
        # The predicted class is the index of the largest logit.  argmax
        # replaces torch.max(outputs.data, 1); the .data access is
        # unnecessary inside no_grad(), and both return the first maximum on
        # ties.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct_classified += (predicted == labels).sum().item()
    test_acc = (100 * correct_classified / total)
    print('Test accuracy of the network: %.2f %%' % test_acc)
# Convert the pretrained PyTorch model into MP-SPDZ ml layers and continue
# training it with SGD on the secret MNIST inputs.
#
# Compile-time arguments, read from program.args (index 0 is presumably the
# program name -- confirm against compile.py):
#   program.args[1]: number of MPC training epochs
#   program.args[2]: number of threads used by the ml module
from Compiler import ml

n_epochs = int(program.args[1])
n_threads = int(program.args[2])
# The same batch size is passed to layers_from_torch and to fit(). It is
# defined once so the two uses cannot drift apart.
mpc_batch_size = 128

ml.set_n_threads(n_threads)
layers = ml.layers_from_torch(net, training_samples.shape, mpc_batch_size,
                              input_via=0)
optimizer = ml.SGD(layers)
optimizer.fit(
    training_samples,
    training_labels,
    epochs=n_epochs,
    batch_size=mpc_batch_size,
    validation_data=(test_samples, test_labels),
    program=program,
    # reset=False keeps the weights imported from the PyTorch model instead
    # of re-initialising them before training.
    reset=False,
)