George Hotz
2020-12-13 21:32:20 -08:00
parent da72a0eed4
commit b86bbd2e72
4 changed files with 19 additions and 11 deletions

View File

@@ -48,7 +48,7 @@ print(x.grad) # dz/dx
print(y.grad) # dz/dy
```
### Neural networks?
## Neural networks?
It turns out that a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, RMSprop, and Adam are implemented) from tinygrad.optim, write some boilerplate minibatching code, and you have all you need.
@@ -78,7 +78,7 @@ loss.backward()
optim.step()
```
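Filled in around the `loss.backward()` / `optim.step()` tail shown above, a minimal end-to-end sketch looks like this (the two-layer shapes, uniform init, and stand-in batch are illustrative assumptions, not the repo's exact example):

```python
import numpy as np
from tinygrad.tensor import Tensor
import tinygrad.optim as optim

# hypothetical two-layer net: 784 -> 128 -> 10
l1 = Tensor(np.random.uniform(-1., 1., size=(784, 128)).astype(np.float32))
l2 = Tensor(np.random.uniform(-1., 1., size=(128, 10)).astype(np.float32))
opt = optim.SGD([l1, l2], lr=0.001)

x = Tensor(np.random.randn(32, 784).astype(np.float32))  # stand-in minibatch
y = Tensor(np.zeros((32, 10), dtype=np.float32))          # stand-in targets

out = x.dot(l1).relu().dot(l2).logsoftmax()
loss = out.mul(y).mean()  # NLL-style loss if y holds a negative value at each true class
loss.backward()
opt.step()
```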
### GPU Support?!
## GPU Support
tinygrad supports GPUs through PyOpenCL.
@@ -87,7 +87,7 @@ from tinygrad.tensor import Tensor
(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu()
```
### ANE Support?!?!
### ANE Support?!
If all you want to do is ReLU, you are in luck! You can do very fast ReLU (at least 30 MEGAReLUs/sec confirmed).
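The diff context below points at the tail of the README's ANE snippet; reconstructed, it looks roughly like this (the `.ane()` device method is an assumption about this era of the repo):

```python
import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(np.array([-2, -1, 0, 1, 2], dtype=np.float32)).ane()  # move to the ANE (assumed API)
b = a.relu()
print(b.cpu())
```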
@@ -103,7 +103,18 @@ print(b.cpu())
Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doing something important with tinygrad and wanted to use the ANE, you might have a bad time.
### ImageNet inference
### Adding an accelerator
You need to support 14 basic ops:
```
Add, Sub, Mul, Pow, Sum, Dot
Pad2D, Reshape
Relu, Sigmoid, LogSoftmax
Conv2D, MaxPool2D, AvgPool2D
```
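Each op follows the `Function` + `register` pattern visible in the ops hunks below. As a sketch, a CPU-style `Mul` looks like this; a new accelerator would swap the numpy arithmetic for its own kernels and pass its `device=` tag to `register`:

```python
from tinygrad.tensor import Function, register

class Mul(Function):
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)  # keep inputs for the backward pass
    return x * y

  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    # d(x*y)/dx = y, d(x*y)/dy = x
    return y * grad_output, x * grad_output

register('mul', Mul)
```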
## ImageNet inference
Despite being tiny, tinygrad supports the full EfficientNet. Pass in a picture to discover what it is.
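One way to run it is something like this (the examples/efficientnet.py script exists in the repo at this point; the exact argument handling is an assumption):

```
python3 examples/efficientnet.py <path or URL of an image>
```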
@@ -129,7 +140,7 @@ See `examples/mnist_gan.py`
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/mnist_by_tinygrad.jpg">
</p>
### The promise of small
## The promise of small
tinygrad will always be below 1000 lines. If it isn't, we will revert commits until tinygrad becomes smaller.
@@ -142,9 +153,6 @@ python3 -m pytest
### TODO
* Train an EfficientNet on ImageNet
* Make broadcasting work on the backward pass (simple please)
* EfficientNet backward pass
* Tensors on GPU (a few more backward)
* Add a language model. BERT?
* Add a detection model. EfficientDet?
* Reduce code

View File

@@ -86,8 +86,6 @@ class Dot(Function):
    grad_weight = input.T.dot(grad_output)  # dL/dW = x^T . dL/dy
    return grad_input, grad_weight
register('dot', Dot)
register('matmul', Dot)
# ************* simple ops *************

View File

@@ -290,7 +290,6 @@ class Dot(Function):
    return grad_input, grad_weight
register('dot', Dot, device=Tensor.GPU)
register('matmul', Dot, device=Tensor.GPU)
# ************* simple ops *************

View File

@@ -212,6 +212,9 @@ class Tensor:
  # ***** non first class ops *****
  def matmul(self, w):
    return self.dot(w)
  def mean(self, axis=None):
    out = self.sum(axis=axis)
    # ratio of element counts after/before the reduction; multiplying the sum by it yields the mean
    coeff = np.prod(out.shape)/np.prod(self.shape)
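A quick usage sketch of the aliases this hunk adds (shapes are illustrative):

```python
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(3, 4).astype(np.float32))
w = Tensor(np.random.randn(4, 5).astype(np.float32))

out = x.matmul(w)  # identical to x.dot(w)
avg = x.mean()     # sum scaled by the element-count ratio, here 1/12
print(out.shape, avg.shape)
```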