simpler graph demo

This commit is contained in:
George Hotz
2022-06-05 12:40:12 -07:00
parent 89acf6742d
commit f0fe37bd34
2 changed files with 6 additions and 8 deletions

View File

@@ -171,9 +171,8 @@ tinygrad will always be below 1000 lines. If it isn't, we will revert commits un
* Purple edge is intermediates created in the forward
```bash
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_conv_onestep
dot -Tsvg /tmp/net.dot -o /tmp/net.svg
open /tmp/net.svg
GRAPH=1 python3 test/test_mnist.py TestMNIST.test_sgd_onestep
dot -Tsvg /tmp/net.dot -o /tmp/net.svg && open /tmp/net.svg
```
### Running tests

View File

@@ -13,7 +13,6 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
# create a model
class TinyBobNet:
def __init__(self):
self.l1 = Tensor.uniform(784, 128)
self.l2 = Tensor.uniform(128, 10)
@@ -46,11 +45,11 @@ class TinyConvNet:
return x.dot(self.l1).logsoftmax()
class TestMNIST(unittest.TestCase):
def test_conv_onestep(self):
def test_sgd_onestep(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, BS=69, steps=1)
def test_conv(self):
np.random.seed(1337)