Mirror of https://github.com/tinygrad/tinygrad.git
Big MNIST model with PIL augmentation and load/save (#160)
* 2serious
* load/save
* fixing GPU
* added DEBUG
* needs BatchNorm or doesn't learn anything
* old file not needed
* added conv biases
* added extra/training.py and checkpoint
* assert in test only
* save
* padding
* num_classes
* checkpoint
* checkpoints for padding
* training was broken
* merge
* rotation augmentation
* more aug
* needs testing
* streamline augment, augment is fast thus bicubic
* tidying up
@@ -1,55 +1,138 @@
 #!/usr/bin/env python
-# see https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
+#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
 import os
 import sys
 sys.path.append(os.getcwd())
 sys.path.append(os.path.join(os.getcwd(), 'test'))

-from tinygrad.tensor import Tensor
+import numpy as np
+from tinygrad.tensor import Tensor, GPU
 from tinygrad.nn import BatchNorm2D
-import tinygrad.optim as optim
 from extra.utils import get_parameters
+from test_mnist import fetch_mnist
+from extra.training import train, evaluate
+import tinygrad.optim as optim
+from extra.augment import augment_img
+GPU = os.getenv("GPU", None) is not None
+QUICK = os.getenv("QUICK", None) is not None
+DEBUG = os.getenv("DEBUG", None) is not None

-# TODO: abstract this generic trainer out of the test
-from test_mnist import train as train_on_mnist
+class SqueezeExciteBlock2D:
+  def __init__(self, filters):
+    self.filters = filters
+    self.weight1 = Tensor.uniform(self.filters, self.filters//32)
+    self.bias1 = Tensor.uniform(1,self.filters//32)
+    self.weight2 = Tensor.uniform(self.filters//32, self.filters)
+    self.bias2 = Tensor.uniform(1, self.filters)

-GPU = os.getenv("GPU") is not None
+  def __call__(self, input):
+    se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
+    se = se.reshape(shape=(-1, self.filters))
+    se = se.dot(self.weight1) + self.bias1
+    se = se.relu()
+    se = se.dot(self.weight2) + self.bias2
+    se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
+    se = input.mul(se)
+    return se
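# Illustrative sketch (not part of this commit): the block above is squeeze-and-excite,
# i.e. global-average-pool each channel to a single value, squeeze it through a
# filters//32 bottleneck MLP, and rescale the input's channels by the resulting
# sigmoid gates. Rough usage, assuming NCHW input and a channel count divisible by 32:
import numpy as np
from tinygrad.tensor import Tensor
seb = SqueezeExciteBlock2D(128)
fmap = Tensor(np.random.randn(4, 128, 14, 14).astype(np.float32))
out = seb(fmap)   # same shape as fmap (4, 128, 14, 14), channels rescaled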

-class SeriousModel:
+class ConvBlock:
+  def __init__(self, h, w, inp, filters=128, conv=3):
+    self.h, self.w = h, w
+    self.inp = inp
+    #init weights
+    self.cweights = [Tensor.uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
+    self.cbiases = [Tensor.uniform(1, filters, 1, 1) for i in range(3)]
+    #init layers
+    self._bn = BatchNorm2D(128, training=True)
+    self._seb = SqueezeExciteBlock2D(filters)

+  def __call__(self, input):
+    x = input.reshape(shape=(-1, self.inp, self.w, self.h))
+    for cweight, cbias in zip(self.cweights, self.cbiases):
+      x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
+    x = self._bn(x)
+    x = self._seb(x)
+    return x
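# Illustrative note (not part of this commit): pad2d with 1 pixel on every side
# followed by a 3x3 valid conv keeps the spatial size (out = in + 2*pad - kernel + 1),
# so the three stacked convs in ConvBlock leave a 28x28 or 14x14 input unchanged in
# height and width. Tiny check with a hypothetical helper:
def _same_conv_out(size, kernel=3, pad=1):
  return size + 2*pad - kernel + 1
assert _same_conv_out(28) == 28 and _same_conv_out(14) == 14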

+class BigConvNet:
   def __init__(self):
-    self.blocks = 3
-    self.block_convs = 3
+    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
+    self.weight1 = Tensor.uniform(128,10)
+    self.weight2 = Tensor.uniform(128,10)

-    # TODO: raise back to 128 when it's fast
-    self.chans = 32
+  def parameters(self):
+    if DEBUG: #keeping this for a moment
+      pars = [par for par in get_parameters(self) if par.requires_grad]
+      no_pars = 0
+      for par in pars:
+        print(par.shape)
+        no_pars += np.prod(par.shape)
+      print('no of parameters', no_pars)
+      return pars
+    else:
+      return get_parameters(self)

-    self.convs = [Tensor.uniform(self.chans, self.chans if i > 0 else 1, 3, 3) for i in range(self.blocks * self.block_convs)]
-    self.cbias = [Tensor.uniform(1, self.chans, 1, 1) for i in range(self.blocks * self.block_convs)]
-    self.bn = [BatchNorm2D(self.chans, training=True) for i in range(3)]
-    self.fc1 = Tensor.uniform(self.chans, 10)
-    self.fc2 = Tensor.uniform(self.chans, 10)
+  def save(self, filename):
+    with open(filename+'.npy', 'wb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        np.save(f, par.cpu().data)

+  def load(self, filename):
+    with open(filename+'.npy', 'rb') as f:
+      for par in get_parameters(self):
+        #if par.requires_grad:
+        try:
+          par.cpu().data[:] = np.load(f)
+          if GPU:
+            par.cuda()
+        except:
+          print('Could not load parameter')

   def forward(self, x):
-    x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
-    for i in range(self.blocks):
-      for j in range(self.block_convs):
-        #print(i, j, x.shape, x.sum().cpu())
-        # TODO: should padding be used?
-        x = x.conv2d(self.convs[i*3+j]).add(self.cbias[i*3+j]).relu()
-      x = self.bn[i](x)
-      if i > 0:
-        x = x.avg_pool2d(kernel_size=(2,2))
-    # TODO: Add concat support to concat with max_pool2d
-    x1 = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, x.shape[1]))
-    x2 = x.max_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, x.shape[1]))
-    x = x1.dot(self.fc1) + x2.dot(self.fc2)
-    return x.logsoftmax()
+    x = self.conv[0](x)
+    x = self.conv[1](x)
+    x = x.avg_pool2d(kernel_size=(2,2))
+    x = self.conv[2](x)
+    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
+    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
+    return xo.logsoftmax()
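# Illustrative sketch (not part of this commit): save() streams every parameter into a
# single .npy file with repeated np.save calls, so load() must walk get_parameters(model)
# in exactly the same order, which holds as long as the architecture (and thus the
# parameter list) is unchanged. Hypothetical round trip, path chosen only for illustration:
model_a = BigConvNet()
model_a.save('/tmp/bigconvnet_ckpt')   # writes /tmp/bigconvnet_ckpt.npy
model_b = BigConvNet()
model_b.load('/tmp/bigconvnet_ckpt')   # same class => same parameter order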


 if __name__ == "__main__":
-  model = SeriousModel()
-  params = get_parameters(model)
-  if GPU:
-    [x.cuda_() for x in params]
-  optimizer = optim.Adam(params, lr=0.001)
-  train_on_mnist(model, optimizer, steps=1875, BS=32, gpu=GPU)
+  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
+  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
+  BS = 32

+  lmbd = 0.00025
+  lossfn = lambda out,y: out.mul(y).mean() + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
+  X_train, Y_train, X_test, Y_test = fetch_mnist()
+  steps = len(X_train)//BS
+  np.random.seed(1337)
+  if QUICK:
+    steps = 1
+    X_test, Y_test = X_test[:BS], Y_test[:BS]

+  model = BigConvNet()

+  if len(sys.argv) > 1:
+    try:
+      model.load(sys.argv[1])
+      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
+      evaluate(model, X_test, Y_test, BS=BS)
+    except:
+      print('could not load weights "'+sys.argv[1]+'".')

+  if GPU:
+    params = get_parameters(model)
+    [x.cuda_() for x in params]

+  for lr, epochs in zip(lrs, epochss):
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    for epoch in range(1,epochs+1):
+      #first epoch without augmentation
+      X_aug = X_train if epoch == 1 else augment_img(X_train)
+      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, gpu=GPU, BS=BS)
+      accuracy = evaluate(model, X_test, Y_test, BS=BS)
+      model.save('examples/checkpoint'+str("%.0f" % (accuracy*1.0e6)))
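The lossfn above leans on how extra/training.py (below) builds its targets: y is a one-hot matrix whose hot entries are set to -num_classes, so out.mul(y).mean() over the logsoftmax output reduces to the mean negative log-likelihood of the correct classes, and the lmbd term adds L1 regularization on the two head matrices. A quick NumPy check of that equivalence (illustrative sketch, not part of the commit):

import numpy as np
num_classes, BS = 10, 4
logits = np.random.randn(BS, num_classes).astype(np.float32)
logsm = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))   # log-softmax rows
Y = np.array([3, 1, 4, 1])                                           # example labels
y = np.zeros((BS, num_classes), np.float32)
y[range(BS), Y] = -1.0 * num_classes                                  # target built as in train()
loss_mul = (logsm * y).mean()                                         # what lossfn's first term computes
loss_nll = -logsm[range(BS), Y].mean()                                # textbook mean NLL
assert np.allclose(loss_mul, loss_nll)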
extra/augment.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import numpy as np
from PIL import Image
import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
from test_mnist import fetch_mnist
from tqdm import trange

def augment_img(X, rotate=10, px=3):
  Xaug = np.zeros_like(X)
  for i in trange(len(X)):
    im = Image.fromarray(X[i])
    im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
    w, h = X.shape[1:]
    #upper left, lower left, lower right, upper right
    quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
    im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
    Xaug[i] = im
  return Xaug
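# Illustrative note (not part of this commit): Image.QUAD maps four source corners
# given in (upper-left, lower-left, lower-right, upper-right) order onto the output
# rectangle, so jittering each corner by up to +/-px pixels adds a small random
# perspective warp on top of the rotation. One example quad for a 28x28 image, px=3:
import numpy as np
example_quad = np.random.randint(-3, 3, size=(8)) + np.array([0, 0, 0, 28, 28, 28, 28, 0])
print(example_quad.reshape(4, 2))   # four jittered (x, y) corners fed to im.transform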

if __name__ == "__main__":
  from test_mnist import fetch_mnist
  import matplotlib.pyplot as plt
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
  fig, a = plt.subplots(2,len(X))
  Xaug = augment_img(X)
  for i in range(len(X)):
    a[0][i].imshow(X[i], cmap='gray')
    a[1][i].imshow(Xaug[i],cmap='gray')
    a[0][i].axis('off')
    a[1][i].axis('off')
  plt.show()

  #create some nice gifs for doc?!
  for i in range(10):
    im = Image.fromarray(X_train[7353+i])
    im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
    im.save("aug"+str(i)+".gif", save_all=True, append_images=im_aug, duration=100, loop=0)
extra/training.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import os
import numpy as np
from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, GPU

def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
  if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
  if num_classes is None: num_classes = Y_train.max().astype(int)+1
  losses, accuracies = [], []
  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
    samp = np.random.randint(0, X_train.shape[0], size=(BS))

    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
    Y = Y_train[samp]
    y = np.zeros((len(samp),num_classes), np.float32)
    # correct loss for NLL, torch NLL loss returns one per row
    y[range(y.shape[0]),Y] = -1.0*num_classes
    y = Tensor(y, gpu=gpu)

    # network
    out = model.forward(x)

    # NLL loss function
    loss = lossfn(out, y)
    optim.zero_grad()
    loss.backward()
    optim.step()

    cat = np.argmax(out.cpu().data, axis=1)
    accuracy = (cat == Y).mean()

    # printing
    loss = loss.cpu().data
    losses.append(loss)
    accuracies.append(accuracy)
    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
  def numpy_eval(num_classes):
    Y_test_preds_out = np.zeros((len(Y_test),num_classes))
    for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
    Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
    return (Y_test == Y_test_preds).mean()

  if num_classes is None: num_classes = Y_test.max().astype(int)+1
  accuracy = numpy_eval(num_classes)
  print("test set accuracy is %f" % accuracy)
  return accuracy

@@ -4,8 +4,8 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, GPU
 import tinygrad.optim as optim
+from extra.training import train, evaluate
 from extra.utils import fetch, get_parameters
-from tqdm import trange

 # mnist loader
 def fetch_mnist():
@@ -54,47 +54,6 @@ class TinyConvNet:
     x = x.reshape(shape=[x.shape[0], -1])
     return x.dot(self.l1).logsoftmax()

-def train(model, optim, steps, BS=128, gpu=False):
-  if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
-  losses, accuracies = [], []
-  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
-    samp = np.random.randint(0, X_train.shape[0], size=(BS))
-
-    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
-    Y = Y_train[samp]
-    y = np.zeros((len(samp),10), np.float32)
-    # correct loss for NLL, torch NLL loss returns one per row
-    y[range(y.shape[0]),Y] = -10.0
-    y = Tensor(y, gpu=gpu)
-
-    # network
-    out = model.forward(x)
-
-    # NLL loss function
-    loss = out.mul(y).mean()
-    optim.zero_grad()
-    loss.backward()
-    optim.step()
-
-    cat = np.argmax(out.cpu().data, axis=1)
-    accuracy = (cat == Y).mean()
-
-    # printing
-    loss = loss.cpu().data
-    losses.append(loss)
-    accuracies.append(accuracy)
-    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
-
-def evaluate(model, gpu=False):
-  def numpy_eval():
-    Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu()
-    Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
-    return (Y_test == Y_test_preds).mean()
-
-  accuracy = numpy_eval()
-  print("test set accuracy is %f" % accuracy)
-  assert accuracy > 0.95
-
 class TestMNIST(unittest.TestCase):
   gpu=False

@@ -102,22 +61,22 @@ class TestMNIST(unittest.TestCase):
     np.random.seed(1337)
     model = TinyConvNet()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, optimizer, steps=200, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95

   def test_sgd(self):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.SGD(model.parameters(), lr=0.001)
-    train(model, optimizer, steps=1000, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95

   def test_rmsprop(self):
     np.random.seed(1337)
     model = TinyBobNet()
     optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
-    train(model, optimizer, steps=1000, gpu=self.gpu)
-    evaluate(model, gpu=self.gpu)
+    train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
+    assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95

 @unittest.skipUnless(GPU, "Requires GPU")
 class TestMNISTGPU(TestMNIST):