Files
MP-SPDZ/Programs/Source/torch_cifar_lenet_pretrain.mpc
2023-02-16 12:35:18 +11:00

82 lines
2.2 KiB
Plaintext

# This MP-SPDZ program trains LeNet on CIFAR-10, starting from a model
# that is first pretrained in cleartext.
#
# Positional program arguments, as read by this script:
#   program.args[1] -- number of epochs for the final training run (required)
#   program.args[2] -- number of threads handed to Compiler.ml (optional)
program.options_from_args()
from Compiler import ml

# The thread count is optional: leave the setting untouched when the
# argument is absent (IndexError) or is not an integer (ValueError), but
# let any other error surface instead of silently swallowing it.
try:
    ml.set_n_threads(int(program.args[2]))
except (IndexError, ValueError):
    pass
import torchvision
import numpy


def get_data(train, transform=None):
    """Return a CIFAR-10 split stored under /tmp, downloading it if needed.

    train: True selects the training split, False the test split.
    transform: optional torchvision transform applied to every image.
    """
    return torchvision.datasets.CIFAR10(
        root='/tmp', train=train, download=True, transform=transform)


# Input both splits through input_tensor_via(0, ...): the training split
# first, then the test split.  Within a split the samples are input before
# the labels; this order is kept deliberately.
data = []
for is_training_split in (True, False):
    ds = get_data(is_training_split)
    # rescale the raw pixel values to [-1,1] before input
    pixels = ds.data / 255 * 2 - 1
    samples = sfix.input_tensor_via(0, pixels, binary=True)
    labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
    data.append((labels, samples))

training_labels, training_samples = data[0]
test_labels, test_samples = data[1]
import torch
import torch.nn as nn

# LeNet-style network: two convolution/ReLU/pooling stages, a flatten
# step, then two fully connected layers producing 10 class scores.
# For 3x32x32 inputs the second pooling stage yields 50 maps of 5x5,
# i.e. the 1250 features expected by the first linear layer.
# Layers are instantiated in the order in which they appear in the model.
_conv_stages = [
    nn.Conv2d(3, 20, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(20, 50, 5), nn.ReLU(), nn.MaxPool2d(2),
]
_dense_stages = [
    nn.ReLU(),
    nn.Linear(1250, 500), nn.ReLU(),
    nn.Linear(500, 10),
]
net = nn.Sequential(*_conv_stages, nn.Flatten(), *_dense_stages)
# Cleartext pretraining: one pass over the CIFAR-10 training split in
# batches of 128.  Each image is turned into a tensor and then mapped
# through 2 * x - 1 (to [-1,1], assuming ToTensor yields values in [0,1]).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    lambda x: 2 * x - 1,
])
ds = get_data(train=True, transform=transform)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), amsgrad=True)

for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128):
    # standard step: clear old gradients, forward, backward, update
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    loss.backward()
    optimizer.step()
# Report how well the cleartext-pretrained network does on the test split.
# Gradients are not needed for evaluation, hence no_grad().
with torch.no_grad():
    ds = get_data(False, transform)
    correct_classified = 0
    total = 0
    for inputs, labels in torch.utils.data.DataLoader(ds, batch_size=128):
        # index of the highest score along dimension 1 is the predicted class
        predicted = torch.max(net(inputs).data, 1)[1]
        correct_classified += (predicted == labels).sum().item()
        total += labels.size(0)
    test_acc = 100 * correct_classified / total
    print('Cleartext test accuracy of the network: %.2f %%' % test_acc)
# Convert the pretrained torch model into Compiler.ml layers.  The third
# argument (128) equals the batch size used for fitting below.
layers = ml.layers_from_torch(
    net, training_samples.shape, 128, input_via=0)

# Continue training with Compiler.ml's SGD for the number of epochs given
# as the first program argument, validating on the test split.
# NOTE(review): reset=False presumably keeps the imported weights instead
# of re-initializing them -- confirm against the Compiler.ml documentation.
optimizer = ml.SGD(layers)
fit_options = dict(
    epochs=int(program.args[1]),
    batch_size=128,
    validation_data=(test_samples, test_labels),
    program=program,
    reset=False,
)
optimizer.fit(training_samples, training_labels, **fit_options)