mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
refactor better
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn import layer_init, SGD
|
|
||||||
from tinygrad.utils import fetch_mnist
|
from tinygrad.utils import fetch_mnist
|
||||||
|
import tinygrad.optim as optim
|
||||||
|
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
@@ -12,6 +12,10 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
|
|||||||
|
|
||||||
# train a model
|
# train a model
|
||||||
|
|
||||||
|
def layer_init(m, h):
|
||||||
|
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
||||||
|
return ret.astype(np.float32)
|
||||||
|
|
||||||
class TinyBobNet:
|
class TinyBobNet:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.l1 = Tensor(layer_init(784, 128))
|
self.l1 = Tensor(layer_init(784, 128))
|
||||||
@@ -22,9 +26,8 @@ class TinyBobNet:
|
|||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
|
|
||||||
|
|
||||||
model = TinyBobNet()
|
model = TinyBobNet()
|
||||||
optim = SGD([model.l1, model.l2], lr=0.01)
|
optim = optim.SGD([model.l1, model.l2], lr=0.01)
|
||||||
|
|
||||||
BS = 128
|
BS = 128
|
||||||
losses, accuracies = [], []
|
losses, accuracies = [], []
|
||||||
|
|||||||
0
tinygrad/__init__.py
Normal file
0
tinygrad/__init__.py
Normal file
@@ -1,9 +1,3 @@
|
|||||||
import numpy as np
|
|
||||||
|
|
||||||
def layer_init(m, h):
|
|
||||||
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
|
||||||
return ret.astype(np.float32)
|
|
||||||
|
|
||||||
class SGD:
|
class SGD:
|
||||||
def __init__(self, tensors, lr):
|
def __init__(self, tensors, lr):
|
||||||
self.tensors = tensors
|
self.tensors = tensors
|
||||||
Reference in New Issue
Block a user