From da72a0eed49b53c602cadd565777668f7bc33361 Mon Sep 17 00:00:00 2001
From: Marcel Bischoff <65973015+marcelbischoff@users.noreply.github.com>
Date: Sun, 13 Dec 2020 23:45:55 -0500
Subject: [PATCH] Big MNIST model with PIL augmentation and load/save (#160)

* 2serious

* load/save

* fixing GPU

* added DEBUG

* needs BatchNorm or doesn't learn anything

* old file not needed

* added conv biases

* added extra/training.py and checkpoint

* assert in test only

* save

* padding

* num_classes

* checkpoint

* checkpoints for padding

* training was broken

* merge

* rotation augmentation

* more aug

* needs testing

* streamline augment, augment is fast thus bicubic

* tidying up
---
 examples/serious_mnist.py | 155 +++++++++++++++++++++++++++++---------
 extra/augment.py          |  40 ++++++++++
 extra/training.py         |  50 ++++++++++++
 test/test_mnist.py        |  55 ++------------
 4 files changed, 216 insertions(+), 84 deletions(-)
 create mode 100644 extra/augment.py
 create mode 100644 extra/training.py

diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py
index 882dba11bd..56cc432c15 100644
--- a/examples/serious_mnist.py
+++ b/examples/serious_mnist.py
@@ -1,55 +1,138 @@
 #!/usr/bin/env python
-# see https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
+#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
 import os
 import sys
 sys.path.append(os.getcwd())
 sys.path.append(os.path.join(os.getcwd(), 'test'))
-from tinygrad.tensor import Tensor
+import numpy as np
+from tinygrad.tensor import Tensor, GPU
 from tinygrad.nn import BatchNorm2D
-import tinygrad.optim as optim
 from extra.utils import get_parameters
+from test_mnist import fetch_mnist
+from extra.training import train, evaluate
+import tinygrad.optim as optim
+from extra.augment import augment_img
+GPU = os.getenv("GPU", None) is not None
+QUICK = os.getenv("QUICK", None) is not None
+DEBUG = os.getenv("DEBUG", None) is not None
 
-# TODO: abstract this generic trainer out of the test
-from test_mnist import train as train_on_mnist
+class SqueezeExciteBlock2D:
+  def __init__(self, filters):
+    self.filters = filters
+    self.weight1 = Tensor.uniform(self.filters, self.filters//32)
+    self.bias1 = Tensor.uniform(1,self.filters//32)
+    self.weight2 = Tensor.uniform(self.filters//32, self.filters)
+    self.bias2 = Tensor.uniform(1, self.filters)
 
-GPU = os.getenv("GPU") is not None
+  def __call__(self, input):
+    se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
+    se = se.reshape(shape=(-1, self.filters))
+    se = se.dot(self.weight1) + self.bias1
+    se = se.relu()
+    se = se.dot(self.weight2) + self.bias2
+    se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
+    se = input.mul(se)
+    return se
 
-class SeriousModel:
+class ConvBlock:
+  def __init__(self, h, w, inp, filters=128, conv=3):
+    self.h, self.w = h, w
+    self.inp = inp
+    #init weights
+    self.cweights = [Tensor.uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
+    self.cbiases = [Tensor.uniform(1, filters, 1, 1) for i in range(3)]
+    #init layers
+    self._bn = BatchNorm2D(128, training=True)
+    self._seb = SqueezeExciteBlock2D(filters)
+
+  def __call__(self, input):
+    x = input.reshape(shape=(-1, self.inp, self.w, self.h))
+    for cweight, cbias in zip(self.cweights, self.cbiases):
+      x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
+    x = self._bn(x)
+    x = self._seb(x)
+    return x
+
+class BigConvNet:
   def __init__(self):
-    self.blocks = 3
-    self.block_convs = 3
+    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
+    self.weight1 = Tensor.uniform(128,10)
+    self.weight2 = Tensor.uniform(128,10)
 
-    # TODO: raise back to 128 when it's fast
-    self.chans = 32
+  def parameters(self):
+    if DEBUG: #keeping this for a moment
+      pars = [par for par in get_parameters(self) if par.requires_grad]
+      no_pars = 0
+      for par in pars:
+        print(par.shape)
+        no_pars += np.prod(par.shape)
+      print('no of parameters', no_pars)
+      return pars
+    else:
+      return get_parameters(self)
 
-    self.convs = [Tensor.uniform(self.chans, self.chans if i > 0 else 1, 3, 3) for i in range(self.blocks * self.block_convs)]
-    self.cbias = [Tensor.uniform(1, self.chans, 1, 1) for i in range(self.blocks * self.block_convs)]
-    self.bn = [BatchNorm2D(self.chans, training=True) for i in range(3)]
-    self.fc1 = Tensor.uniform(self.chans, 10)
-    self.fc2 = Tensor.uniform(self.chans, 10)
+  def save(self, filename):
+    with open(filename+'.npy', 'wb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        np.save(f, par.cpu().data)
+
+  def load(self, filename):
+    with open(filename+'.npy', 'rb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        try:
+          par.cpu().data[:] = np.load(f)
+          if GPU:
+            par.cuda()
+        except:
+          print('Could not load parameter')
 
   def forward(self, x):
-    x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
-    for i in range(self.blocks):
-      for j in range(self.block_convs):
-        #print(i, j, x.shape, x.sum().cpu())
-        # TODO: should padding be used?
-        x = x.conv2d(self.convs[i*3+j]).add(self.cbias[i*3+j]).relu()
-      x = self.bn[i](x)
-      if i > 0:
-        x = x.avg_pool2d(kernel_size=(2,2))
-    # TODO: Add concat support to concat with max_pool2d
-    x1 = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, x.shape[1]))
-    x2 = x.max_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, x.shape[1]))
-    x = x1.dot(self.fc1) + x2.dot(self.fc2)
-    return x.logsoftmax()
+    x = self.conv[0](x)
+    x = self.conv[1](x)
+    x = x.avg_pool2d(kernel_size=(2,2))
+    x = self.conv[2](x)
+    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
+    return xo.logsoftmax()
+
 
 if __name__ == "__main__":
-  model = SeriousModel()
-  params = get_parameters(model)
-  if GPU:
-    [x.cuda_() for x in params]
-  optimizer = optim.Adam(params, lr=0.001)
-  train_on_mnist(model, optimizer, steps=1875, BS=32, gpu=GPU)
+  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
+  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
+  BS = 32
+  lmbd = 0.00025
+  lossfn = lambda out,y: out.mul(y).mean() + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
+  X_train, Y_train, X_test, Y_test = fetch_mnist()
+  steps = len(X_train)//BS
+  np.random.seed(1337)
+  if QUICK:
+    steps = 1
+    X_test, Y_test = X_test[:BS], Y_test[:BS]
+
+  model = BigConvNet()
+
+  if len(sys.argv) > 1:
+    try:
+      model.load(sys.argv[1])
+      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
+      evaluate(model, X_test, Y_test, BS=BS)
+    except:
+      print('could not load weights "'+sys.argv[1]+'".')
+
+  if GPU:
+    params = get_parameters(model)
+    [x.cuda_() for x in params]
+
+  for lr, epochs in zip(lrs, epochss):
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    for epoch in range(1,epochs+1):
+      #first epoch without augmentation
+      X_aug = X_train if epoch == 1 else augment_img(X_train)
+      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, gpu=GPU, BS=BS)
+      accuracy = evaluate(model, X_test, Y_test, BS=BS)
+      model.save('examples/checkpoint'+str("%.0f" % (accuracy*1.0e6)))
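
Note on the checkpoint format (illustrative sketch, not part of the patch): BigConvNet.save() writes one np.save record per parameter in get_parameters() order, and load() reads them back in the same order, so a checkpoint written by the loop above can be reloaded into a freshly constructed model for evaluation or further training. Run from the repository root; the checkpoint name below is hypothetical:

    import os, sys
    sys.path.append(os.getcwd())
    sys.path.append(os.path.join(os.getcwd(), 'test'))
    from test_mnist import fetch_mnist
    from extra.training import evaluate
    from examples.serious_mnist import BigConvNet

    _, _, X_test, Y_test = fetch_mnist()
    model = BigConvNet()
    model.load('examples/checkpoint982100')  # hypothetical name; load() appends '.npy' itself
    evaluate(model, X_test, Y_test, BS=32)
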
diff --git a/extra/augment.py b/extra/augment.py
new file mode 100644
index 0000000000..c68205717c
--- /dev/null
+++ b/extra/augment.py
@@ -0,0 +1,40 @@
+import numpy as np
+from PIL import Image
+import os
+import sys
+sys.path.append(os.getcwd())
+sys.path.append(os.path.join(os.getcwd(), 'test'))
+from test_mnist import fetch_mnist
+from tqdm import trange
+
+def augment_img(X, rotate=10, px=3):
+  Xaug = np.zeros_like(X)
+  for i in trange(len(X)):
+    im = Image.fromarray(X[i])
+    im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
+    w, h = X.shape[1:]
+    #upper left, lower left, lower right, upper right
+    quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
+    im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
+    Xaug[i] = im
+  return Xaug
+
+if __name__ == "__main__":
+  from test_mnist import fetch_mnist
+  import matplotlib.pyplot as plt
+  X_train, Y_train, X_test, Y_test = fetch_mnist()
+  X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
+  fig, a = plt.subplots(2,len(X))
+  Xaug = augment_img(X)
+  for i in range(len(X)):
+    a[0][i].imshow(X[i], cmap='gray')
+    a[1][i].imshow(Xaug[i],cmap='gray')
+    a[0][i].axis('off')
+    a[1][i].axis('off')
+  plt.show()
+
+  #create some nice gifs for doc?!
+  for i in range(10):
+    im = Image.fromarray(X_train[7353+i])
+    im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
+    im.save("aug"+str(i)+".gif", save_all=True, append_images=im_aug, duration=100, loop=0)
diff --git a/extra/training.py b/extra/training.py
new file mode 100644
index 0000000000..3ce050679f
--- /dev/null
+++ b/extra/training.py
@@ -0,0 +1,50 @@
+import os
+import numpy as np
+from tqdm import trange
+from extra.utils import get_parameters
+from tinygrad.tensor import Tensor, GPU
+
+def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
+  if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
+  if num_classes is None: num_classes = Y_train.max().astype(int)+1
+  losses, accuracies = [], []
+  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
+    samp = np.random.randint(0, X_train.shape[0], size=(BS))
+
+    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
+    Y = Y_train[samp]
+    y = np.zeros((len(samp),num_classes), np.float32)
+    # correct loss for NLL, torch NLL loss returns one per row
+    y[range(y.shape[0]),Y] = -1.0*num_classes
+    y = Tensor(y, gpu=gpu)
+
+    # network
+    out = model.forward(x)
+
+    # NLL loss function
+    loss = lossfn(out, y)
+    optim.zero_grad()
+    loss.backward()
+    optim.step()
+
+    cat = np.argmax(out.cpu().data, axis=1)
+    accuracy = (cat == Y).mean()
+
+    # printing
+    loss = loss.cpu().data
+    losses.append(loss)
+    accuracies.append(accuracy)
+    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+
+def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
+  def numpy_eval(num_classes):
+    Y_test_preds_out = np.zeros((len(Y_test),num_classes))
+    for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
+      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
+    Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
+    return (Y_test == Y_test_preds).mean()
+
+  if num_classes is None: num_classes = Y_test.max().astype(int)+1
+  accuracy = numpy_eval(num_classes)
+  print("test set accuracy is %f" % accuracy)
+  return accuracy
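
Aside on the target encoding in train() (illustrative, not part of the patch): with the default lossfn, the label matrix sets y[i, Y[i]] = -1.0*num_classes and zeros elsewhere, so out.mul(y).mean() over a (BS, num_classes) logsoftmax output divides by BS*num_classes while the -num_classes scale cancels the num_classes factor, leaving exactly the mean negative log-likelihood. A quick numpy-only check of that identity:

    import numpy as np

    BS, num_classes = 4, 10
    logits = np.random.randn(BS, num_classes)
    logsm = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # logsoftmax, like the model output
    Y = np.random.randint(0, num_classes, size=BS)

    y = np.zeros((BS, num_classes), np.float32)
    y[range(BS), Y] = -1.0 * num_classes    # same encoding as train()
    loss_via_mul_mean = (logsm * y).mean()  # what out.mul(y).mean() computes

    nll = -logsm[range(BS), Y].mean()       # mean negative log-likelihood
    assert np.allclose(loss_via_mul_mean, nll)
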
diff --git a/test/test_mnist.py b/test/test_mnist.py
index 35e4cd4549..a123938d1f 100644
--- a/test/test_mnist.py
+++ b/test/test_mnist.py
@@ -4,8 +4,8 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, GPU
 import tinygrad.optim as optim
+from extra.training import train, evaluate
 from extra.utils import fetch, get_parameters
-from tqdm import trange
 
 # mnist loader
 def fetch_mnist():
@@ -54,47 +54,6 @@ class TinyConvNet:
     x = x.reshape(shape=[x.shape[0], -1])
     return x.dot(self.l1).logsoftmax()
 
-def train(model, optim, steps, BS=128, gpu=False):
-  if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
-  losses, accuracies = [], []
-  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
-    samp = np.random.randint(0, X_train.shape[0], size=(BS))
-
-    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
-    Y = Y_train[samp]
-    y = np.zeros((len(samp),10), np.float32)
-    # correct loss for NLL, torch NLL loss returns one per row
-    y[range(y.shape[0]),Y] = -10.0
-    y = Tensor(y, gpu=gpu)
-
-    # network
-    out = model.forward(x)
-
-    # NLL loss function
-    loss = out.mul(y).mean()
-    optim.zero_grad()
-    loss.backward()
-    optim.step()
-
-    cat = np.argmax(out.cpu().data, axis=1)
-    accuracy = (cat == Y).mean()
-
-    # printing
-    loss = loss.cpu().data
-    losses.append(loss)
-    accuracies.append(accuracy)
-    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
-
-def evaluate(model, gpu=False):
-  def numpy_eval():
-    Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu()
-    Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
-    return (Y_test == Y_test_preds).mean()
-
-  accuracy = numpy_eval()
-  print("test set accuracy is %f" % accuracy)
-  assert accuracy > 0.95
-
 class TestMNIST(unittest.TestCase):
   gpu=False
 
@@ -102,22 +61,22 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, optimizer, steps=200, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
 
   def test_sgd(self):
     np.random.seed(1337)
     model = TinyBobNet()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, optimizer, steps=1000, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
 
   def test_rmsprop(self):
    np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
-    train(model, optimizer, steps=1000, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
 
 @unittest.skipUnless(GPU, "Requires GPU")
 class TestMNISTGPU(TestMNIST):
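
For reference, the refactored helpers can be exercised end to end outside the unittest harness roughly the way the tests above do (illustrative only; assumes the repository root is the working directory so that test/ can be put on sys.path, and that MNIST is downloaded on first use):

    import os, sys
    sys.path.append(os.getcwd())
    sys.path.append(os.path.join(os.getcwd(), 'test'))
    import numpy as np
    import tinygrad.optim as optim
    from extra.training import train, evaluate
    from test_mnist import fetch_mnist, TinyConvNet

    np.random.seed(1337)
    X_train, Y_train, X_test, Y_test = fetch_mnist()
    model = TinyConvNet()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train(model, X_train, Y_train, optimizer, steps=200)  # pass gpu=True for the GPU path
    print(evaluate(model, X_test, Y_test))
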