mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
more readme
This commit is contained in:
@@ -55,7 +55,12 @@ class TinyBobNet:
|
||||
model = TinyBobNet()
|
||||
optim = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
|
||||
# ... and complete like pytorch
|
||||
# ... and complete like pytorch, with (x,y) data
|
||||
|
||||
out = model.forward(x)
|
||||
loss = out.mul(y).mean()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
```
|
||||
|
||||
### TODO (to make real neural network library)
|
||||
|
||||
@@ -41,14 +41,14 @@ for i in (t := trange(1000)):
|
||||
y = Tensor(y)
|
||||
|
||||
# network
|
||||
outs = model.forward(x)
|
||||
out = model.forward(x)
|
||||
|
||||
# NLL loss function
|
||||
loss = outs.mul(y).mean()
|
||||
loss = out.mul(y).mean()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
cat = np.argmax(outs.data, axis=1)
|
||||
cat = np.argmax(out.data, axis=1)
|
||||
accuracy = (cat == Y).mean()
|
||||
|
||||
# printing
|
||||
|
||||
Reference in New Issue
Block a user