diff --git a/test/mnist.py b/test/mnist.py
index ffcead0a24..686b6fcfc9 100644
--- a/test/mnist.py
+++ b/test/mnist.py
@@ -1,32 +1,17 @@
 #!/usr/bin/env python
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.nn import layer_init, SGD
+from tinygrad.utils import fetch_mnist
+
 from tqdm import trange
 
 # load the mnist dataset
 
-def fetch(url):
-  import requests, gzip, os, hashlib, numpy
-  fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
-  if not os.path.isfile(fp):
-    with open(fp, "rb") as f:
-      dat = f.read()
-  else:
-    with open(fp, "wb") as f:
-      dat = requests.get(url).content
-      f.write(dat)
-  return numpy.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
-X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
-Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
-X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
-Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
+X_train, Y_train, X_test, Y_test = fetch_mnist()
 
 # train a model
 
-def layer_init(m, h):
-  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
-  return ret.astype(np.float32)
-
 class TinyBobNet:
   def __init__(self):
     self.l1 = Tensor(layer_init(784, 128))
@@ -37,14 +22,6 @@ class TinyBobNet:
 
 # optimizer
 
-class SGD:
-  def __init__(self, tensors, lr):
-    self.tensors = tensors
-    self.lr = lr
-
-  def step(self):
-    for t in self.tensors:
-      t.data -= self.lr * t.grad
 
 model = TinyBobNet()
 optim = SGD([model.l1, model.l2], lr=0.01)
diff --git a/tinygrad/nn.py b/tinygrad/nn.py
new file mode 100644
index 0000000000..e3f6e95f62
--- /dev/null
+++ b/tinygrad/nn.py
@@ -0,0 +1,15 @@
+import numpy as np
+
+def layer_init(m, h):
+  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
+  return ret.astype(np.float32)
+
+class SGD:
+  def __init__(self, tensors, lr):
+    self.tensors = tensors
+    self.lr = lr
+
+  def step(self):
+    for t in self.tensors:
+      t.data -= self.lr * t.grad
+
diff --git a/tinygrad/utils.py b/tinygrad/utils.py
new file mode 100644
index 0000000000..bbef87f6ee
--- /dev/null
+++ b/tinygrad/utils.py
@@ -0,0 +1,27 @@
+def fetch_mnist():
+  """Download (or load from the /tmp cache) the four MNIST archive files.
+
+  Returns (X_train, Y_train, X_test, Y_test) as uint8 numpy arrays; the two
+  image arrays are reshaped to (-1, 28, 28).
+  """
+  def fetch(url):
+    """Return the gunzipped bytes served at url as a writable uint8 array."""
+    import requests, gzip, os, hashlib, numpy
+    # cache key is the md5 of the url, so every file gets its own /tmp entry
+    fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
+    if os.path.isfile(fp):
+      # cache hit: reuse the archive downloaded by an earlier run
+      with open(fp, "rb") as f:
+        dat = f.read()
+    else:
+      # cache miss: download first so a failed request leaves no empty cache file
+      dat = requests.get(url).content
+      with open(fp, "wb") as f:
+        f.write(dat)
+    return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
+  # [0x10:] / [8:] drop the first 16 / 8 bytes (the image / label file headers)
+  X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
+  Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
+  X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
+  Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
+  return X_train, Y_train, X_test, Y_test