update docs (#4356)
* update docs
* nn.md
* mnist cleanups
* rhip test is very slow
@@ -1,7 +1,7 @@
-<!-- TODO: remove the imported members -->
+<!-- TODO: move Function from tensor to function -->
::: tinygrad.function
    options:
        members: true
        inherited_members: false
        show_source: false
        members_order: source

@@ -30,15 +30,15 @@ If you are migrating from PyTorch, welcome. Most of the API is the same. We hope

### tinygrad doesn't have nn.Module

-There's nothing special about a "Module" class in tinygrad, it's just a normal class. [`nn.state.get_parameters`](nn/#tinygrad.nn.state.get_parameters) can be used to recursively search normal classes for valid tensors. Instead of the `forward` method in PyTorch, tinygrad just uses `__call__`
+There's nothing special about a "Module" class in tinygrad; it's just a normal class. [`nn.state.get_parameters`](nn.md/#tinygrad.nn.state.get_parameters) can be used to recursively search normal classes for valid tensors. Instead of the `forward` method in PyTorch, tinygrad just uses `__call__`.
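To make that concrete, here is a minimal sketch of a plain class used as a model (the layer sizes are arbitrary, not taken from the docs):

```python
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_parameters

class TinyNet:
  def __init__(self):
    self.l1 = nn.Linear(784, 128)
    self.l2 = nn.Linear(128, 10)

  def __call__(self, x: Tensor) -> Tensor:
    return self.l2(self.l1(x).relu())

net = TinyNet()
# recursively finds the 4 tensors held by the class: 2 weights and 2 biases
print(len(get_parameters(net)))  # 4
```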

### tinygrad is functional

-In tinygrad, you can do [`x.conv2d(w, b)`](tensor/#tinygrad.Tensor.conv2d) or [`x.sparse_categorical_cross_entropy(y)`](tensor/#tinygrad.Tensor.sparse_categorical_crossentropy). We do also have a [`Conv2D`](nn/#tinygrad.nn.Conv2d) class like PyTorch if you want a place to keep the state, but all stateless operations don't have classes.
+In tinygrad, you can do [`x.conv2d(w, b)`](tensor.md/#tinygrad.Tensor.conv2d) or [`x.sparse_categorical_crossentropy(y)`](tensor.md/#tinygrad.Tensor.sparse_categorical_crossentropy). We also have a [`Conv2d`](nn.md/#tinygrad.nn.Conv2d) class like PyTorch if you want a place to keep the state, but stateless operations don't have classes.
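A small sketch of the two styles side by side (the shapes here are chosen arbitrarily):

```python
from tinygrad import Tensor, nn

x = Tensor.rand(1, 3, 32, 32)

# functional: the state (weight, bias) is passed in explicitly
w, b = Tensor.rand(8, 3, 3, 3), Tensor.rand(8)
out_functional = x.conv2d(w, b)

# class-based: nn.Conv2d holds the weight and bias for you
conv = nn.Conv2d(3, 8, 3)
out_class = conv(x)
```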
### tinygrad is lazy

-When you do `a+b` in tinygrad, nothing happens. It's not until you [`realize`](tensor/#tinygrad.Tensor.realize) the Tensor that the computation actually runs.
+When you do `a+b` in tinygrad, nothing happens. It's not until you [`realize`](tensor.md/#tinygrad.Tensor.realize) the Tensor that the computation actually runs.
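A tiny example of what that looks like:

```python
from tinygrad import Tensor

a, b = Tensor([1.0, 2.0]), Tensor([3.0, 4.0])
c = a + b         # nothing runs here; c is just a lazy graph
c.realize()       # the addition kernel actually executes now
print(c.numpy())  # [4. 6.]
```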
### tinygrad requires @TinyJit to be fast

@@ -4,7 +4,7 @@ After you have installed tinygrad, this is a great first tutorial.

Start up a notebook locally, or use [colab](https://colab.research.google.com/). tinygrad is very lightweight, so it's easy to install anywhere and doesn't need a special colab image, but for speed we recommend a T4 GPU image.

-### One-liner to install in colab
+### One-liner to install tinygrad in colab

```python
!pip install git+https://github.com/tinygrad/tinygrad.git
```

@@ -64,12 +64,12 @@ So creating the model and evaluating it is a matter of:

```python
model = Model()
acc = (model(X_test).argmax(axis=1) == Y_test).mean()
# NOTE: tinygrad is lazy, and hasn't actually run anything by this point
-print(acc.item())  # ~10% accuracy, as expected
+print(acc.item())  # ~10% accuracy, as expected from a random model
```
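The `Model` class and the `X_test`/`Y_test` tensors are defined earlier in the tutorial and aren't part of this hunk. For context, one consistent possibility (not necessarily the tutorial's exact definition) is a small convolutional classifier along these lines:

```python
from tinygrad import Tensor, nn

class Model:
  def __init__(self):
    self.l1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
    self.l2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
    self.l3 = nn.Linear(1600, 10)

  def __call__(self, x: Tensor) -> Tensor:
    # for 28x28 MNIST input: 26 -> 13 after the first conv+pool, 11 -> 5 after the second, so 64*5*5 = 1600
    x = self.l1(x).relu().max_pool2d((2, 2))
    x = self.l2(x).relu().max_pool2d((2, 2))
    return self.l3(x.flatten(1).dropout(0.5))
```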
### Training the model

-We need an optimizer, and we'll use Adam. The `nn.state.get_parameters` will walk the class and pull out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.
+We'll use the Adam optimizer. `nn.state.get_parameters` walks the model class and pulls out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.

```python
optim = nn.optim.Adam(nn.state.get_parameters(model))
```
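The `step` function itself is defined between this hunk and the next. A rough sketch of such a training step (the batch size of 128 and the `X_train`/`Y_train` names are assumptions here):

```python
from tinygrad import Tensor

def step():
  Tensor.training = True  # dropout behaves differently in training mode
  samples = Tensor.randint(128, high=X_train.shape[0])  # random minibatch indices
  X, Y = X_train[samples], Y_train[samples]
  optim.zero_grad()
  loss = model(X).sparse_categorical_crossentropy(Y).backward()
  optim.step()
  return loss
```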
@@ -107,7 +107,7 @@ from tinygrad import TinyJit

```python
jit_step = TinyJit(step)
```

-NOTE: it can also be used as a decorator `@TinyJit`
+NOTE: It can also be used as a decorator, `@TinyJit`.
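A quick sketch of the decorator form on a toy function (the names here are made up for illustration):

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def jitted_add(a: Tensor, b: Tensor) -> Tensor:
  return (a + b).realize()

for _ in range(4):  # early calls capture the kernels, later calls just replay them
  out = jitted_add(Tensor.rand(4, 4).realize(), Tensor.rand(4, 4).realize())
```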
Now when we time it:

@@ -123,7 +123,7 @@ timeit.repeat(jit_step, repeat=5, number=1)

1.0 ms is 75x faster! Note that we aren't syncing the GPU, so GPU time may be slower.

-The slowness the first two times is the JIT capturing the kernels. And this JIT will not run any Python in the function, it will just replay the tinygrad kernels that were run, so be aware. Randomness functions as expected.
+The slowness the first two times is the JIT capturing the kernels. This JIT will not run any Python in the function; it will just replay the tinygrad kernels that were run, so be aware that non-tinygrad Python operations won't work. Randomness functions as expected.
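A rough sketch of how those numbers are read (absolute times are hardware-dependent; the ~1.0 ms figure above is the tutorial's measurement):

```python
import timeit

times = timeit.repeat(jit_step, repeat=5, number=1)
# the first entry or two are slow while the JIT captures kernels; the rest show the replayed speed
print([f"{t * 1e3:.2f} ms" for t in times])
```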
Unlike other JITs, we JIT everything, including the optimizer. Think of it as a dumb replay on different data.

docs/nn.md
@@ -1,17 +1,29 @@

## Neural Network classes

-::: tinygrad.nn
-    options:
-        members: true
+::: tinygrad.nn.BatchNorm2d
+::: tinygrad.nn.Conv1d
+::: tinygrad.nn.Conv2d
+::: tinygrad.nn.ConvTranspose2d
+::: tinygrad.nn.Linear
+::: tinygrad.nn.GroupNorm
+::: tinygrad.nn.InstanceNorm
+::: tinygrad.nn.LayerNorm
+::: tinygrad.nn.LayerNorm2d
+::: tinygrad.nn.Embedding

## Optimizers

-::: tinygrad.nn.optim
-    options:
-        members: true
+::: tinygrad.nn.optim.SGD
+::: tinygrad.nn.optim.LARS
+::: tinygrad.nn.optim.AdamW
+::: tinygrad.nn.optim.Adam
+::: tinygrad.nn.optim.LAMB

## Load/Save

-::: tinygrad.nn.state
-    options:
-        members: true
+::: tinygrad.nn.state.safe_load
+::: tinygrad.nn.state.safe_save
+::: tinygrad.nn.state.get_state_dict
+::: tinygrad.nn.state.get_parameters
+::: tinygrad.nn.state.load_state_dict
+::: tinygrad.nn.state.torch_load
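To give a feel for the Load/Save API listed above, a minimal save/load round trip (the filename is an arbitrary example):

```python
from tinygrad import Tensor
from tinygrad.nn.state import get_state_dict, safe_save, safe_load, load_state_dict

class TinyNet:
  def __init__(self): self.w = Tensor.rand(4, 4)
  def __call__(self, x: Tensor) -> Tensor: return x @ self.w

net = TinyNet()
safe_save(get_state_dict(net), "model.safetensors")   # write the weights as safetensors
load_state_dict(net, safe_load("model.safetensors"))  # load them back into the model
```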