diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py
new file mode 100644
index 0000000000..4fbee4f79e
--- /dev/null
+++ b/examples/serious_mnist.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# see https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
+import os
+import sys
+sys.path.append(os.getcwd())
+sys.path.append(os.path.join(os.getcwd(), 'test'))
+
+from tinygrad.tensor import Tensor
+from tinygrad.nn import BatchNorm2D
+from tinygrad.utils import get_parameters
+import tinygrad.optim as optim
+
+# TODO: abstract this generic trainer out of the test
+from test_mnist import train as train_on_mnist
+
+GPU = os.getenv("GPU") is not None
+
+class SeriousModel:
+  def __init__(self):
+    self.blocks = 3
+    self.block_convs = 3
+    self.chans = 128
+
+    self.convs = [Tensor.uniform(self.chans, self.chans if i > 0 else 1, 3, 3) for i in range(self.blocks * self.block_convs)]
+    # TODO: Make batchnorm work at train time
+    #self.bn = [BatchNorm2D(self.chans) for i in range(3)]
+    self.fc = Tensor.uniform(self.chans, 10)
+
+  def forward(self, x):
+    x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
+    for i in range(self.blocks):
+      for j in range(self.block_convs):
+        # TODO: should padding be used?
+        x = x.conv2d(self.convs[i*3+j]).relu()
+        #x = self.bn[i](x)
+      if i > 0:
+        x = x.avg_pool2d(kernel_size=(2,2))
+    # TODO: Add concat support to concat with max_pool2d
+    x = x.avg_pool2d(kernel_size=x.shape[2:4])
+    x = x.reshape(shape=(-1, x.shape[1]))
+    x = x.dot(self.fc).logsoftmax()
+    return x
+
+if __name__ == "__main__":
+  model = SeriousModel()
+  params = get_parameters(model)
+  print(len(params))
+  if GPU:
+    [x.cuda_() for x in params]
+  optimizer = optim.Adam(params, lr=0.001)
+  train_on_mnist(model, optimizer, steps=1875, BS=32, gpu=GPU)
+
diff --git a/tinygrad/utils.py b/tinygrad/utils.py
index 73de62f033..a1cc500157 100644
--- a/tinygrad/utils.py
+++ b/tinygrad/utils.py
@@ -16,6 +16,8 @@ def fetch(url):
   return dat
 
 def get_parameters(model):
+  if isinstance(model, Tensor):
+    return [model]
  parameters = []
  if hasattr(model, '__dict__'):
    for k,v in model.__dict__.items():
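
Note on the shapes in SeriousModel.forward: assuming conv2d here is unpadded (consistent with the "should padding be used?" TODO), each 3x3 conv costs 2 pixels per side and each avg_pool2d(2,2) halves the feature map, so the 28x28 input reaches block 2 at exactly 1x1 and the final global average pool sees a (1,1) window. A hypothetical trace, not part of the diff:

# hypothetical spatial-size trace for SeriousModel.forward, assuming unpadded convs
h = 28
for i in range(3):    # blocks
  h -= 2 * 3          # three unpadded 3x3 convs, each shrinks h by 2
  if i > 0:
    h //= 2           # avg_pool2d(kernel_size=(2,2))
  print(f"block {i}: {h}x{h}")
# prints: block 0: 22x22, block 1: 8x8, block 2: 1x1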
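
The tinygrad/utils.py hunk adds a Tensor base case to get_parameters. The rest of the function sits outside this hunk, so exactly how it walks model.__dict__ and list attributes is an assumption here, but the example script relies on that walk finding both the fc weight and the convs list. A minimal usage sketch under those assumptions:

# minimal sketch; assumes get_parameters traverses __dict__ entries and lists,
# as the context lines and the example's usage suggest
from tinygrad.tensor import Tensor
from tinygrad.utils import get_parameters

w = Tensor.uniform(3, 3)
assert get_parameters(w) == [w]  # new base case: a bare Tensor is its own parameter list

class Tiny:
  def __init__(self):
    self.fc = Tensor.uniform(128, 10)                          # single Tensor attribute
    self.convs = [Tensor.uniform(8, 1, 3, 3) for _ in range(2)] # list of Tensors

print(len(get_parameters(Tiny())))  # expected: 3 (fc plus two conv kernels)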