refactor into a few files

George Hotz
2020-10-18 13:30:25 -07:00
parent 118c2eebe3
commit 92fd23df66
3 changed files with 36 additions and 27 deletions


@@ -1,32 +1,17 @@
#!/usr/bin/env python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import layer_init, SGD
from tinygrad.utils import fetch_mnist
from tqdm import trange
# load the mnist dataset
def fetch(url):
  import requests, gzip, os, hashlib, numpy
  fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
  if os.path.isfile(fp):
    with open(fp, "rb") as f:
      dat = f.read()
  else:
    with open(fp, "wb") as f:
      dat = requests.get(url).content
      f.write(dat)
  return numpy.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
X_train, Y_train, X_test, Y_test = fetch_mnist()
# train a model
def layer_init(m, h):
  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
  return ret.astype(np.float32)

class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor(layer_init(784, 128))
@@ -37,14 +22,6 @@ class TinyBobNet:
# optimizer
class SGD:
  def __init__(self, tensors, lr):
    self.tensors = tensors
    self.lr = lr

  def step(self):
    for t in self.tensors:
      t.data -= self.lr * t.grad
model = TinyBobNet()
optim = SGD([model.l1, model.l2], lr=0.01)
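
Side note: a minimal sketch (not from this commit) of what one optimizer step amounts to with these pieces, assuming tinygrad's Tensor exposes its numpy array as .data and its gradient as .grad, which is what the t.data -= self.lr * t.grad update above relies on. The gradient is filled in by hand here instead of by a backward pass.

import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import layer_init, SGD

# build one layer and pretend a backward pass already produced a gradient of all ones
l1 = Tensor(layer_init(784, 128))
l1.grad = np.ones_like(l1.data)
before = l1.data.copy()

# step() subtracts lr * grad from each tensor's data in place
SGD([l1], lr=0.01).step()
assert np.allclose(l1.data, before - 0.01)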

15  tinygrad/nn.py  Normal file

@@ -0,0 +1,15 @@
import numpy as np

def layer_init(m, h):
  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
  return ret.astype(np.float32)

class SGD:
  def __init__(self, tensors, lr):
    self.tensors = tensors
    self.lr = lr

  def step(self):
    for t in self.tensors:
      t.data -= self.lr * t.grad
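
As a quick sanity check (not part of the commit), layer_init returns a float32 matrix of shape (m, h) whose entries are uniform in (-1, 1) scaled by 1/sqrt(m*h):

import numpy as np
from tinygrad.nn import layer_init

w = layer_init(784, 128)
print(w.shape, w.dtype)           # (784, 128) float32
# every entry is bounded by the 1/sqrt(m*h) scale factor (small slack for float32 rounding)
assert np.abs(w).max() <= 1.0 / np.sqrt(784 * 128) + 1e-6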

17  tinygrad/utils.py  Normal file

@@ -0,0 +1,17 @@
def fetch_mnist():
  def fetch(url):
    import requests, gzip, os, hashlib, numpy
    fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp):
      with open(fp, "rb") as f:
        dat = f.read()
    else:
      with open(fp, "wb") as f:
        dat = requests.get(url).content
        f.write(dat)
    return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
  X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
  Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
  X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
  Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
  return X_train, Y_train, X_test, Y_test
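
A quick usage sketch (not part of the diff); the images are reshaped to 28x28 above and MNIST ships 60k training and 10k test examples, so the shapes come out as follows:

from tinygrad.utils import fetch_mnist

X_train, Y_train, X_test, Y_test = fetch_mnist()
print(X_train.shape, Y_train.shape)   # (60000, 28, 28) (60000,)
print(X_test.shape, Y_test.shape)     # (10000, 28, 28) (10000,)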