mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
refactor into a few files
This commit is contained in:
@@ -1,32 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import layer_init, SGD
|
||||
from tinygrad.utils import fetch_mnist
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
# load the mnist dataset
|
||||
|
||||
def fetch(url, cache_dir="/tmp"):
    """Return the decompressed bytes of a gzip file at ``url``, cached on disk.

    The raw (still compressed) download is cached in ``cache_dir`` under the
    MD5 hex digest of ``url``. A cache hit is read from disk; a miss is
    downloaded with ``requests`` and then stored.

    Args:
        url: Address of the gzip-compressed resource.
        cache_dir: Directory holding cached downloads (defaults to ``/tmp``).

    Returns:
        A writable one-dimensional ``uint8`` array of the decompressed bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    import gzip
    import hashlib
    import os
    import numpy

    fp = os.path.join(cache_dir, hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp):
        # Cache hit: reuse the bytes saved by an earlier download.
        with open(fp, "rb") as f:
            dat = f.read()
    else:
        # Imported lazily so a cache hit needs neither the package nor network.
        import requests

        response = requests.get(url)
        # Refuse to cache an HTTP error page as if it were the dataset.
        response.raise_for_status()
        dat = response.content
        # Write only after a successful download so a failed request never
        # leaves an empty cache file behind.
        with open(fp, "wb") as f:
            f.write(dat)
    # frombuffer yields a read-only view; copy() makes the result writable.
    return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
|
||||
X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
|
||||
Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
|
||||
X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
|
||||
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
# train a model
|
||||
|
||||
def layer_init(m, h):
    """Return random initial weights of shape ``(m, h)`` as float32.

    Values are drawn uniformly from [-1, 1) and then divided by
    ``sqrt(m*h)``.
    """
    scale = np.sqrt(m * h)
    samples = np.random.uniform(-1., 1., size=(m, h))
    return (samples / scale).astype(np.float32)
|
||||
|
||||
class TinyBobNet:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor(layer_init(784, 128))
|
||||
@@ -37,14 +22,6 @@ class TinyBobNet:
|
||||
|
||||
# optimizer
|
||||
|
||||
class SGD:
    """Gradient-descent optimizer over a fixed collection of tensors.

    Each tensor is expected to expose ``data`` and ``grad`` attributes.
    """

    def __init__(self, tensors, lr):
        self.tensors, self.lr = tensors, lr

    def step(self):
        """Subtract ``lr * grad`` from every tensor's ``data`` in place."""
        for param in self.tensors:
            delta = self.lr * param.grad
            param.data -= delta
|
||||
|
||||
model = TinyBobNet()
|
||||
optim = SGD([model.l1, model.l2], lr=0.01)
|
||||
|
||||
15
tinygrad/nn.py
Normal file
15
tinygrad/nn.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
|
||||
def layer_init(m, h):
    """Build an ``(m, h)`` float32 weight matrix.

    Entries are sampled uniformly from [-1, 1) and divided by the square
    root of ``m*h``.
    """
    shape = (m, h)
    raw = np.random.uniform(-1., 1., size=shape)
    scaled = np.divide(raw, np.sqrt(m * h))
    return scaled.astype(np.float32)
|
||||
|
||||
class SGD:
    """Optimizer that moves each tensor against its gradient.

    The tensors passed in must provide ``data`` and ``grad`` attributes.
    """

    def __init__(self, tensors, lr):
        # Keep the tensor collection and the step size for later updates.
        self.tensors = tensors
        self.lr = lr

    def step(self):
        """Apply one update: ``data -= lr * grad`` for every tensor."""
        for tensor in self.tensors:
            update = self.lr * tensor.grad
            tensor.data -= update
|
||||
|
||||
17
tinygrad/utils.py
Normal file
17
tinygrad/utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
def fetch_mnist(cache_dir="/tmp"):
    """Load the MNIST train and test sets, downloading them on first use.

    Each of the four gzip files is cached in ``cache_dir`` under the MD5 hex
    digest of its URL, so later calls read from disk instead of the network.

    Args:
        cache_dir: Directory holding cached downloads (defaults to ``/tmp``).

    Returns:
        A tuple ``(X_train, Y_train, X_test, Y_test)`` of ``uint8`` arrays.
        The image arrays are reshaped to ``(-1, 28, 28)``; the label arrays
        are one-dimensional.

    Raises:
        requests.HTTPError: If a download answers with an error status.
    """
    import gzip
    import hashlib
    import os
    import numpy

    def fetch(url):
        """Return the decompressed bytes behind ``url`` as a uint8 array."""
        fp = os.path.join(cache_dir, hashlib.md5(url.encode('utf-8')).hexdigest())
        if os.path.isfile(fp):
            # Cache hit: reuse the bytes saved by an earlier download.
            with open(fp, "rb") as f:
                dat = f.read()
        else:
            # Imported lazily so a cache hit needs neither the package nor
            # network access.
            import requests

            response = requests.get(url)
            # Refuse to cache an HTTP error page as if it were the dataset.
            response.raise_for_status()
            dat = response.content
            # Write only after a successful download so a failed request
            # never leaves an empty cache file behind.
            with open(fp, "wb") as f:
                f.write(dat)
        # frombuffer yields a read-only view; copy() makes it writable.
        return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()

    # Image files: skip the first 0x10 bytes, then view as 28x28 images.
    # Label files: skip the first 8 bytes.
    X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
    Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
    X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
    Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
    return X_train, Y_train, X_test, Y_test
|
||||
Reference in New Issue
Block a user