mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
50 lines
1.2 KiB
Plaintext
50 lines
1.2 KiB
Plaintext
# This trains a convolutional neural network (two convolution layers
# followed by two dense layers) on MNIST.
#
# Positional compile-time arguments, as read below:
#   program.args[1]  second argument of Optimizer.run_by_args()
#                    (presumably the number of epochs -- confirm in ml.py)
#   program.args[2]  number of threads passed to ml.set_n_threads()

program.options_from_args()

import torchvision

# Mini-batch size; the same value must be used when converting the layers
# and when running the optimizer.
batch_size = 128

# data[0] holds the training set, data[1] the test set,
# each stored as a (labels, samples) pair.
data = []
for train in True, False:
    ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
    # normalize to [0,1] before input
    samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data += [(labels, samples)]

(train_labels, train_samples), (test_labels, test_samples) = data

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(20, 50, 5),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    # NOTE(review): the values reaching this point are averages of ReLU
    # outputs and hence already non-negative, so this ReLU looks redundant.
    # It is kept so that the converted layer list stays unchanged.
    nn.ReLU(),
    nn.Linear(800, 500),
    nn.ReLU(),
    nn.Linear(500, 10)
)

# test network: push one sample through the plain PyTorch model so that a
# shape mismatch between layers fails here rather than during conversion
ds = torchvision.datasets.MNIST(
    root='/tmp', transform=torchvision.transforms.ToTensor())
inputs = next(iter(torch.utils.data.DataLoader(ds)))[0]
print(inputs.shape)
outputs = net(inputs)

from Compiler import ml

ml.set_n_threads(int(program.args[2]))

# convert the PyTorch model and attach the secret training data to it
layers = ml.layers_from_torch(net, train_samples.shape, batch_size)
layers[0].X = train_samples
layers[-1].Y = train_labels

optimizer = ml.SGD(layers)
optimizer.run_by_args(program, int(program.args[1]), batch_size,
                      test_samples, test_labels)