refactor: inline layer_init into the mnist example and switch to optim.SGD

This commit is contained in:
George Hotz
2020-10-18 13:33:02 -07:00
parent 92fd23df66
commit 6532233d24
3 changed files with 6 additions and 9 deletions

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import layer_init, SGD
from tinygrad.utils import fetch_mnist
import tinygrad.optim as optim
from tqdm import trange
@@ -12,6 +12,10 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
# train a model
def layer_init(m, h):
  """Return an (m, h) float32 weight matrix, uniform in [-1, 1) scaled by 1/sqrt(m*h)."""
  weights = np.random.uniform(-1., 1., size=(m, h))
  scale = np.sqrt(m * h)
  return (weights / scale).astype(np.float32)
class TinyBobNet:
def __init__(self):
self.l1 = Tensor(layer_init(784, 128))
@@ -22,9 +26,8 @@ class TinyBobNet:
# optimizer
model = TinyBobNet()
optim = SGD([model.l1, model.l2], lr=0.01)
optim = optim.SGD([model.l1, model.l2], lr=0.01)
BS = 128
losses, accuracies = [], []

tinygrad/__init__.py — new (empty) file, 0 lines
View File

View File

@@ -1,9 +1,3 @@
import numpy as np
def layer_init(m, h):
  # Kaiming-style scaled uniform init: draw in [-1, 1), shrink by 1/sqrt(m*h),
  # and store as float32 so downstream tensors stay single-precision.
  return (np.random.uniform(-1., 1., size=(m, h)) / np.sqrt(m * h)).astype(np.float32)
class SGD:
def __init__(self, tensors, lr):
self.tensors = tensors