simpler graph demo

README.md

@@ -171,9 +171,8 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un
 * Purple edge is intermediates created in the forward
 
 ```bash
-GRAPH=1 python3 test/test_mnist.py TestMNIST.test_conv_onestep
-dot -Tsvg /tmp/net.dot -o /tmp/net.svg
-open /tmp/net.svg
+GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
+dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg
 ```
 
 ### Running tests

test/test_mnist.py

@@ -13,7 +13,6 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
 
 # create a model
 class TinyBobNet:
-
   def __init__(self):
     self.l1 = Tensor.uniform(784, 128)
     self.l2 = Tensor.uniform(128, 10)
@@ -46,11 +45,11 @@ class TinyConvNet:
     return x.dot(self.l1).logsoftmax()
 
 class TestMNIST(unittest.TestCase):
-  def test_conv_onestep(self):
+  def test_sgd_onestep(self):
     np.random.seed(1337)
-    model = TinyConvNet()
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=1)
+    model = TinyBobNet()
+    optimizer = optim.SGD(model.parameters(), lr=0.001)
+    train(model, X_train, Y_train, optimizer, BS=69, steps=1)
 
   def test_conv(self):
     np.random.seed(1337)
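
The test the README now points at builds the simpler fully-connected model and runs a single SGD step, which is what keeps the rendered graph small. Below is a hypothetical standalone sketch of roughly that path; the imports, the forward() body, and the stand-in loss are assumptions based on the tinygrad README and train() helper of the same era, not code from this commit.

```python
# Hypothetical sketch of what TestMNIST.test_sgd_onestep exercises; only the
# model construction, optim.SGD(lr=0.001), and BS=69 come from the diff above.
import numpy as np
from tinygrad.tensor import Tensor
import tinygrad.optim as optim  # assumed import path for this era of tinygrad

class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor.uniform(784, 128)
    self.l2 = Tensor.uniform(128, 10)

  def parameters(self):
    return [self.l1, self.l2]

  def forward(self, x):
    # two-layer MLP ending in logsoftmax, as in the tinygrad README example
    return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)

# one step on a single random batch of 69 flattened 28x28 "images"
x = Tensor(np.random.randn(69, 784).astype(np.float32))
loss = model.forward(x).mean()  # stand-in for the NLL loss computed in train()
loss.backward()
optimizer.step()
```

Running something like this with GRAPH=1 in the environment is what produces the /tmp/net.dot file the README commands render with dot.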