Files
MP-SPDZ/Programs/Source/torch_cifar_lenet.mpc
2023-02-16 12:35:18 +11:00

58 lines
1.3 KiB
Plaintext

# Train LeNet on CIFAR-10.
#
# Command-line arguments read by this script (via program.args):
#   program.args[1] -- number of training epochs (required, read further down)
#   program.args[2] -- number of threads for the ml module (optional)
program.options_from_args()
from Compiler import ml

# The thread count is optional: a missing or non-integer argument leaves the
# ml module's default untouched.  Only the argument parsing is guarded, so a
# genuine failure inside ml.set_n_threads() (or Ctrl-C) is no longer hidden.
try:
    n_threads = int(program.args[2])
except (IndexError, ValueError):
    n_threads = None
if n_threads is not None:
    ml.set_n_threads(n_threads)
import torchvision, numpy
# Load both CIFAR-10 splits (training first, then test) and feed them into
# the computation.  The first argument 0 to input_tensor_via() is the
# inputting party (see the MP-SPDZ documentation for input_tensor_via).
splits = []
for is_training in (True, False):
    dataset = torchvision.datasets.CIFAR10(
        root='/tmp', train=is_training, download=True)
    # normalize the pixel values to [-1,1] before input
    scaled_pixels = dataset.data / 255 * 2 - 1
    secret_samples = sfix.input_tensor_via(0, scaled_pixels, binary=True)
    secret_labels = sint.input_tensor_via(
        0, dataset.targets, binary=True, one_hot=True)
    splits.append((secret_labels, secret_samples))
training_split, test_split = splits
training_labels, training_samples = training_split
test_labels, test_samples = test_split
import torch
import torch.nn as nn
# LeNet-style network in PyTorch; the stages are applied in list order.
lenet_stages = [
    # convolution 1: 3 input channels -> 20 feature maps, 5x5 kernel
    nn.Conv2d(3, 20, 5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    # convolution 2: 20 -> 50 feature maps, 5x5 kernel
    nn.Conv2d(20, 50, 5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    # NOTE(review): this ReLU directly follows ReLU + max-pooling, so its
    # input should already be non-negative; kept to leave the layer list
    # unchanged.
    nn.ReLU(),
    # 1250 = 50 feature maps * 5 * 5 for 32x32 inputs
    nn.Linear(1250, 500),
    nn.ReLU(),
    # one output per class
    nn.Linear(500, 10),
]
net = nn.Sequential(*lenet_stages)
# test network: push one batch from the DataLoader through the plaintext
# PyTorch model and print the input shape
to_tensor = torchvision.transforms.ToTensor()
ds = torchvision.datasets.CIFAR10(root='/tmp', transform=to_tensor)
loader = torch.utils.data.DataLoader(ds)
first_batch = next(iter(loader))
inputs = first_batch[0]
print(inputs.shape)
outputs = net(inputs)
# Convert the PyTorch model into ml layers and train it with SGD.
# The same value (128) is passed to layers_from_torch() and to fit().
batch_size = 128
layers = ml.layers_from_torch(net, training_samples.shape, batch_size)
optimizer = ml.SGD(layers)

# the number of epochs comes from the first command-line argument
n_epochs = int(program.args[1])
fit_options = dict(
    epochs=n_epochs,
    batch_size=batch_size,
    validation_data=(test_samples, test_labels),
    program=program,
)
optimizer.fit(training_samples, training_labels, **fit_options)