diff --git a/README.md b/README.md
index 6fd542c0ed..64e8b8f762 100644
--- a/README.md
+++ b/README.md
@@ -39,9 +39,9 @@ print(y.grad)  # dz/dy
 
 ### Neural networks?
 
-It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD and Adam implemented) from tinygrad.optim, write some boilerplate minibatching code, and you have all you need.
+It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, RMSprop and Adam implemented) from tinygrad.optim, write some boilerplate minibatching code, and you have all you need.
 
-### Neural network example (from test/mnist.py)
+### Neural network example (from test/test_mnist.py)
 
 ```python
 from tinygrad.tensor import Tensor
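The README change above advertises three optimizers from tinygrad.optim. All of them share the same minimal interface exercised by the refactored test below: construct the optimizer with the list of parameter Tensors it should update, compute a loss, call loss.backward(), then optimizer.step(). A rough sketch of that flow is shown here for orientation (model, x and y stand in for the TinyBobNet, inputs and encoded targets defined in test/test_mnist.py; this is an illustration, not part of the patch):

```python
import tinygrad.optim as optim

# model, x, y are assumed to be set up as in test/test_mnist.py
optimizer = optim.RMSprop([model.l1, model.l2], lr=0.001)  # or optim.SGD / optim.Adam

out = model.forward(x)        # forward pass
loss = out.mul(y).mean()      # NLL-style loss used in the test
loss.backward()               # populate .grad on l1 and l2
optimizer.step()              # apply the update to the parameters in place
```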
diff --git a/test/test_mnist.py b/test/test_mnist.py
index d464a39ce4..9103d1d132 100644
--- a/test/test_mnist.py
+++ b/test/test_mnist.py
@@ -4,7 +4,7 @@ import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.utils import layer_init_uniform, fetch_mnist
-import tinygrad.optim as tinygrad_optim
+import tinygrad.optim as optim
 from tqdm import trange
 
 np.random.seed(1337)
@@ -39,54 +39,62 @@ class TinyConvNet:
 
 class TestMNIST(unittest.TestCase):
   def test_mnist(self):
-    if os.getenv("CONV") == "1":
-      model = TinyConvNet()
-      optim = tinygrad_optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
-      steps = 400
-    else:
-      model = TinyBobNet()
-      optim = tinygrad_optim.SGD([model.l1, model.l2], lr=0.001)
-      steps = 1000
+    def train(model, optim, steps, BS=128):
+      losses, accuracies = [], []
+      for i in (t := trange(steps)):
+        samp = np.random.randint(0, X_train.shape[0], size=(BS))
+
+        x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32))
+        Y = Y_train[samp]
+        y = np.zeros((len(samp),10), np.float32)
+        # correct loss for NLL, torch NLL loss returns one per row
+        y[range(y.shape[0]),Y] = -10.0
+        y = Tensor(y)
+
+        # network
+        out = model.forward(x)
 
-    BS = 128
-    losses, accuracies = [], []
-    for i in (t := trange(steps)):
-      samp = np.random.randint(0, X_train.shape[0], size=(BS))
-
-      x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32))
-      Y = Y_train[samp]
-      y = np.zeros((len(samp),10), np.float32)
-      # correct loss for NLL, torch NLL loss returns one per row
-      y[range(y.shape[0]),Y] = -10.0
-      y = Tensor(y)
-
-      # network
-      out = model.forward(x)
+        # NLL loss function
+        loss = out.mul(y).mean()
+        loss.backward()
+        optim.step()
+
+        cat = np.argmax(out.data, axis=1)
+        accuracy = (cat == Y).mean()
+
+        # printing
+        loss = loss.data
+        losses.append(loss)
+        accuracies.append(accuracy)
+        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
 
-      # NLL loss function
-      loss = out.mul(y).mean()
-      loss.backward()
-      optim.step()
-
-      cat = np.argmax(out.data, axis=1)
-      accuracy = (cat == Y).mean()
-
-      # printing
-      loss = loss.data
-      losses.append(loss)
-      accuracies.append(accuracy)
-      t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
+    def evaluate(model):
+      def numpy_eval():
+        Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
+        Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
+        return (Y_test == Y_test_preds).mean()
 
-    # evaluate
-    def numpy_eval():
-      Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
-      Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
-      return (Y_test == Y_test_preds).mean()
-
-    accuracy = numpy_eval()
-    print("test set accuracy is %f" % accuracy)
-    assert accuracy > 0.95
+      accuracy = numpy_eval()
+      print("test set accuracy is %f" % accuracy)
+      assert accuracy > 0.95
 
+    # models
+    model = TinyConvNet()
+    optimizer = optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
+    steps = 400
+    train(model, optimizer, steps)
+    evaluate(model)
+
+    model = TinyBobNet()
+    steps = 1000
+    optimizer = optim.SGD([model.l1, model.l2], lr=0.001)
+    train(model, optimizer, steps)
+    evaluate(model)
+
+    model = TinyBobNet()
+    optimizer = optim.RMSprop([model.l1, model.l2], lr=0.001)
+    train(model, optimizer, steps)
+    evaluate(model)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/optim.py b/tinygrad/optim.py
index 8496060394..d2465c4015 100644
--- a/tinygrad/optim.py
+++ b/tinygrad/optim.py
@@ -35,3 +35,19 @@ class Adam(Optimizer):
       vhat = self.v[i] / (1. - self.b2**self.t)
       t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
 
+# fill the 20% uncertainty of the above optim
+class RMSprop(Optimizer):
+  def __init__(self, params, lr=0.001, decay=0.9, eps=1e-8):
+    super(RMSprop, self).__init__(params)
+    self.lr = lr
+    self.decay = decay
+    self.eps = eps
+    self.t = 0
+
+    self.v = [np.zeros_like(t.data) for t in self.params]
+
+  def step(self):
+    self.t += 1
+    for i, t in enumerate(self.params):
+      self.v[i] = self.decay * self.v[i] + (1 - self.decay) * np.square(t.grad)
+      t.data -= (self.lr / np.sqrt(self.v[i] + self.eps)) * t.grad
\ No newline at end of file
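For reference, the RMSprop update the new class implements keeps, per parameter, a running average of squared gradients, v = decay * v + (1 - decay) * grad**2, and scales the step by lr / sqrt(v + eps); unlike the Adam class above it, it tracks no first-moment estimate and does no bias correction. A self-contained numpy sketch of a single step (param, grad and v are plain arrays standing in for t.data, t.grad and self.v[i]; not part of the patch):

```python
import numpy as np

def rmsprop_step(param, grad, v, lr=0.001, decay=0.9, eps=1e-8):
  # running average of the squared gradient, as in RMSprop.step above
  v = decay * v + (1 - decay) * np.square(grad)
  # per-parameter adaptive step: larger recent gradients give a smaller step
  param = param - (lr / np.sqrt(v + eps)) * grad
  return param, v
```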