Files
MP-SPDZ/Programs/Source/torch_mnist_lenet_avgpool.mpc
Marcel Keller 6cc3fccef0 Maintenance.
2023-05-09 14:50:53 +10:00

50 lines
1.2 KiB
Plaintext

# This trains a LeNet-style convolutional neural network (convolutions
# with average pooling, followed by dense layers) on MNIST.

# Set compiler options from the compile-time arguments.
program.options_from_args()
import torchvision

# After the loop, data[0] is the training set and data[1] the test set,
# each stored as a (labels, samples) pair.
data = []
for train in (True, False):
    ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
    # normalize to [0,1] before input
    # The first argument (0) selects the inputting party; samples become
    # fixed-point values, labels integer values with one-hot encoding.
    samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data.append((labels, samples))
import torch
import torch.nn as nn
# LeNet-style network with average pooling. The shape comments assume a
# single-channel 28x28 input, which is what makes the flattened size 800.
feature_stages = [
    nn.Conv2d(1, 20, 5),     # 1 x 28 x 28 -> 20 x 24 x 24
    nn.ReLU(),
    nn.AvgPool2d(2),         # -> 20 x 12 x 12
    nn.Conv2d(20, 50, 5),    # -> 50 x 8 x 8
    nn.ReLU(),
    nn.AvgPool2d(2),         # -> 50 x 4 x 4
]
classifier_stages = [
    nn.Flatten(),            # -> 800
    # NOTE(review): this ReLU only sees averages of ReLU outputs, which are
    # already non-negative; it is kept so the layer sequence is unchanged.
    nn.ReLU(),
    nn.Linear(800, 500),
    nn.ReLU(),
    nn.Linear(500, 10),      # one output per class
]
net = nn.Sequential(*feature_stages, *classifier_stages)
# Test network: run one batch through the model in plain PyTorch
# (presumably a sanity check that the layer shapes fit together).
ds = torchvision.datasets.MNIST(
    root='/tmp',
    transform=torchvision.transforms.ToTensor(),
)
loader = torch.utils.data.DataLoader(ds)
first_batch = next(iter(loader))
inputs = first_batch[0]  # element 0 of the batch is fed to the network
print(inputs.shape)
outputs = net(inputs)
from Compiler import ml
# args[2] of the compile-time arguments is the thread count.
n_threads = int(program.args[2])
ml.set_n_threads(n_threads)

# data[0] / data[1] are the (labels, samples) pairs built for train=True
# and train=False respectively.
train_labels, train_samples = data[0]
test_labels, test_samples = data[1]
batch_size = 128

# Build the compiler-side layers from the PyTorch model, then attach the
# training samples to the first layer and the labels to the last one.
layers = ml.layers_from_torch(net, train_samples.shape, batch_size)
layers[0].X = train_samples
layers[-1].Y = train_labels

optimizer = ml.SGD(layers)
# args[1] is forwarded as the second argument of run_by_args
# (presumably the number of epochs -- TODO confirm).
n_runs = int(program.args[1])
optimizer.run_by_args(program, n_runs, batch_size,
                      test_samples, test_labels)