refactor better
@@ -1,8 +1,8 @@
 #!/usr/bin/env python
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.nn import layer_init, SGD
 from tinygrad.utils import fetch_mnist
+import tinygrad.optim as optim

 from tqdm import trange
@@ -12,6 +12,10 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()

 # train a model

+def layer_init(m, h):
+  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
+  return ret.astype(np.float32)
+
 class TinyBobNet:
   def __init__(self):
     self.l1 = Tensor(layer_init(784, 128))
@@ -22,9 +26,8 @@ class TinyBobNet:

 # optimizer

 model = TinyBobNet()
-optim = SGD([model.l1, model.l2], lr=0.01)
+optim = optim.SGD([model.l1, model.l2], lr=0.01)

 BS = 128
 losses, accuracies = [], []
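Taken together, the example script now defines layer_init locally and gets its optimizer from tinygrad.optim instead of tinygrad.nn. A minimal sketch of the resulting usage follows; the second layer shape, the forward() method, and the name sgd are assumptions for illustration, since the diff does not show them:

import numpy as np
from tinygrad.tensor import Tensor
import tinygrad.optim as optim

def layer_init(m, h):
  # uniform init scaled by sqrt(m*h), as added to the script in the diff above
  ret = np.random.uniform(-1., 1., size=(m, h)) / np.sqrt(m * h)
  return ret.astype(np.float32)

class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor(layer_init(784, 128))
    self.l2 = Tensor(layer_init(128, 10))   # assumption: 128 -> 10 MNIST classes
  def forward(self, x):
    # assumption: two-layer MLP; the diff does not show forward()
    return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

model = TinyBobNet()
sgd = optim.SGD([model.l1, model.l2], lr=0.01)  # the new optim.SGD call from the diff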
tinygrad/__init__.py (new empty file, 0 lines)
@@ -1,9 +1,3 @@
-import numpy as np
-
-def layer_init(m, h):
-  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
-  return ret.astype(np.float32)
-
 class SGD:
   def __init__(self, tensors, lr):
     self.tensors = tensors
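The SGD class that remains in the optimizer module is only partially visible in this hunk; it is cut off after storing the tensors. A minimal sketch of how a plain SGD step could be completed, where the lr attribute, the step() method, and the t.data / t.grad fields are assumptions beyond what the diff shows:

class SGD:
  def __init__(self, tensors, lr):
    self.tensors = tensors   # tensors to update, e.g. [model.l1, model.l2]
    self.lr = lr             # assumption: learning rate kept for step()

  def step(self):
    # assumption: vanilla gradient descent, data -= lr * grad for each tensor
    for t in self.tensors:
      t.data -= self.lr * t.grad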