more readme

This commit is contained in:
George Hotz
2020-10-18 14:38:20 -07:00
parent cc9054e3ec
commit 4019c38942
2 changed files with 9 additions and 4 deletions

View File

@@ -55,7 +55,12 @@ class TinyBobNet:
model = TinyBobNet()
optim = optim.SGD([model.l1, model.l2], lr=0.001)
# ... and complete like pytorch
# ... and complete like pytorch, with (x,y) data
out = model.forward(x)
loss = out.mul(y).mean()
loss.backward()
optim.step()
```
### TODO (to make real neural network library)

View File

@@ -41,14 +41,14 @@ for i in (t := trange(1000)):
y = Tensor(y)
# network
outs = model.forward(x)
out = model.forward(x)
# NLL loss function
loss = outs.mul(y).mean()
loss = out.mul(y).mean()
loss.backward()
optim.step()
cat = np.argmax(outs.data, axis=1)
cat = np.argmax(out.data, axis=1)
accuracy = (cat == Y).mean()
# printing