update docs (#4356)

* update docs

* nn.md

* mnist cleanups

* rhip test is very slow
George Hotz
2024-04-30 15:51:42 +08:00
committed by GitHub
parent a2d81514fd
commit d325be2540
9 changed files with 45 additions and 23 deletions

View File

@@ -1,7 +1,7 @@
<!-- TODO: remove the imported members -->
<!-- TODO: move Function from tensor to function -->
::: tinygrad.function
    options:
      members: true
      inherited_members: false
      show_source: false
      members_order: source

View File

@@ -30,15 +30,15 @@ If you are migrating from PyTorch, welcome. Most of the API is the same. We hope
### tinygrad doesn't have nn.Module
There's nothing special about a "Module" class in tinygrad, it's just a normal class. [`nn.state.get_parameters`](nn/#tinygrad.nn.state.get_parameters) can be used to recursively search normal classes for valid tensors. Instead of the `forward` method in PyTorch, tinygrad just uses `__call__`
There's nothing special about a "Module" class in tinygrad, it's just a normal class. [`nn.state.get_parameters`](nn.md/#tinygrad.nn.state.get_parameters) can be used to recursively search normal classes for valid tensors. Instead of the `forward` method in PyTorch, tinygrad just uses `__call__`
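For example (a minimal sketch, not taken from the docs; the class and shapes are arbitrary):
```python
from tinygrad import Tensor
from tinygrad.nn.state import get_parameters

class TinyNet:                 # a plain Python class, no special base class
  def __init__(self): self.w, self.b = Tensor.rand(784, 10), Tensor.zeros(10)
  def __call__(self, x: Tensor) -> Tensor: return x @ self.w + self.b   # plays the role of forward

print(len(get_parameters(TinyNet())))  # 2: the recursive search finds self.w and self.b
```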
### tinygrad is functional
In tinygrad, you can do [`x.conv2d(w, b)`](tensor/#tinygrad.Tensor.conv2d) or [`x.sparse_categorical_cross_entropy(y)`](tensor/#tinygrad.Tensor.sparse_categorical_crossentropy). We do also have a [`Conv2D`](nn/#tinygrad.nn.Conv2d) class like PyTorch if you want a place to keep the state, but all stateless operations don't have classes.
In tinygrad, you can do [`x.conv2d(w, b)`](tensor.md/#tinygrad.Tensor.conv2d) or [`x.sparse_categorical_crossentropy(y)`](tensor.md/#tinygrad.Tensor.sparse_categorical_crossentropy). We also have a [`Conv2d`](nn.md/#tinygrad.nn.Conv2d) class like PyTorch if you want a place to keep the state, but all stateless operations don't have classes.
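A small sketch of both styles (shapes are arbitrary, not from the docs):
```python
from tinygrad import Tensor, nn

x = Tensor.rand(1, 3, 32, 32)
w, b = Tensor.rand(8, 3, 3, 3), Tensor.rand(8)
y1 = x.conv2d(w, b)          # stateless: you hold the weight and bias yourself

conv = nn.Conv2d(3, 8, 3)    # stateful: the class holds conv.weight and conv.bias
y2 = conv(x)
```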
### tinygrad is lazy
When you do `a+b` in tinygrad, nothing happens. It's not until you [`realize`](tensor/#tinygrad.Tensor.realize) the Tensor that the computation actually runs.
When you do `a+b` in tinygrad, nothing happens. It's not until you [`realize`](tensor.md/#tinygrad.Tensor.realize) the Tensor that the computation actually runs.
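For instance (a tiny sketch):
```python
from tinygrad import Tensor

a, b = Tensor([1, 2, 3]), Tensor([4, 5, 6])
c = a + b        # builds the graph; no kernel has run yet
c.realize()      # the addition actually executes here
print(c.numpy()) # 5, 7, 9
```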
### tinygrad requires @TinyJit to be fast

View File

@@ -4,7 +4,7 @@ After you have installed tinygrad, this is a great first tutorial.
Start up a notebook locally, or use [colab](https://colab.research.google.com/). tinygrad is very lightweight, so it's easy to install anywhere and doesn't need a special colab image, but for speed we recommend a T4 GPU image.
### One-liner to install in colab
### One-liner to install tinygrad in colab
```python
!pip install git+https://github.com/tinygrad/tinygrad.git
@@ -64,12 +64,12 @@ So creating the model and evaluating it is a matter of:
model = Model()
acc = (model(X_test).argmax(axis=1) == Y_test).mean()
# NOTE: tinygrad is lazy, and hasn't actually run anything by this point
print(acc.item()) # ~10% accuracy, as expected
print(acc.item()) # ~10% accuracy, as expected from a random model
```
### Training the model
We need an optimizer, and we'll use Adam. The `nn.state.get_parameters` will walk the class and pull out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.
We'll use the Adam optimizer. `nn.state.get_parameters` will walk the model class and pull out the parameters for the optimizer. Also, in tinygrad, it's typical to write a function to do the training step so it can be jitted.
```python
optim = nn.optim.Adam(nn.state.get_parameters(model))
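# A rough sketch of such a step function (illustrative, not the exact tutorial code;
# `batch_size` and the random batch sampling are assumptions, and X_train, Y_train,
# model, and optim come from earlier in the tutorial).
batch_size = 128
def step():
  Tensor.training = True                                        # enable training-time behaviour like dropout
  samples = Tensor.randint(batch_size, high=X_train.shape[0])   # pick a random batch of indices
  X, Y = X_train[samples], Y_train[samples]
  optim.zero_grad()
  loss = model(X).sparse_categorical_crossentropy(Y).backward()
  optim.step()
  return loss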
@@ -107,7 +107,7 @@ from tinygrad import TinyJit
jit_step = TinyJit(step)
```
NOTE: it can also be used as a decorator `@TinyJit`
NOTE: It can also be used as a decorator `@TinyJit`
Now when we time it:
@@ -123,7 +123,7 @@ timeit.repeat(jit_step, repeat=5, number=1)
1.0 ms is 75x faster! Note that we aren't syncing the GPU, so GPU time may be slower.
The slowness the first two times is the JIT capturing the kernels. And this JIT will not run any Python in the function, it will just replay the tinygrad kernels that were run, so be aware. Randomness functions as expected.
The slowness the first two times is the JIT capturing the kernels. This JIT will not run any Python in the function; it will just replay the tinygrad kernels that were run, so be aware that non-tinygrad Python operations won't run on later calls. Randomness functions as expected.
Unlike other JITs, we JIT everything, including the optimizer. Think of it as a dumb replay on different data.
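A minimal sketch of what that replay means in practice (made up, not from the tutorial), using the `@TinyJit` decorator form mentioned above; a plain Python side effect inside the jitted function only fires while the JIT is still capturing:
```python
from tinygrad import Tensor, TinyJit

@TinyJit
def double(x: Tensor) -> Tensor:
  print("python ran")       # appears for the first two calls only, while the JIT captures
  return (x * 2).realize()

x = Tensor.ones(16).contiguous().realize()
for _ in range(4): double(x)  # later calls just replay the recorded kernels on x
```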

View File

@@ -1,17 +1,29 @@
## Neural Network classes
::: tinygrad.nn
    options:
      members: true
::: tinygrad.nn.BatchNorm2d
::: tinygrad.nn.Conv1d
::: tinygrad.nn.Conv2d
::: tinygrad.nn.ConvTranspose2d
::: tinygrad.nn.Linear
::: tinygrad.nn.GroupNorm
::: tinygrad.nn.InstanceNorm
::: tinygrad.nn.LayerNorm
::: tinygrad.nn.LayerNorm2d
::: tinygrad.nn.Embedding
## Optimizers
::: tinygrad.nn.optim
    options:
      members: true
::: tinygrad.nn.optim.SGD
::: tinygrad.nn.optim.LARS
::: tinygrad.nn.optim.AdamW
::: tinygrad.nn.optim.Adam
::: tinygrad.nn.optim.LAMB
## Load/Save
::: tinygrad.nn.state
    options:
      members: true
::: tinygrad.nn.state.safe_load
::: tinygrad.nn.state.safe_save
::: tinygrad.nn.state.get_state_dict
::: tinygrad.nn.state.get_parameters
::: tinygrad.nn.state.load_state_dict
::: tinygrad.nn.state.torch_load
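A rough sketch of how these pieces fit together (the network, learning rate, and file path below are arbitrary, not from the docs):
```python
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict

class Net:
  def __init__(self): self.l1, self.l2 = Linear(784, 128), Linear(128, 10)
  def __call__(self, x: Tensor) -> Tensor: return self.l2(self.l1(x).relu())

net = Net()
opt = Adam(get_parameters(net), lr=1e-3)                # optimizer over every tensor found in net
safe_save(get_state_dict(net), "/tmp/net.safetensors")  # write the weights as safetensors
load_state_dict(net, safe_load("/tmp/net.safetensors")) # and load them back
```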

View File

@@ -6,6 +6,7 @@ nav:
- Tensor: tensor.md
- dtypes: dtypes.md
- Neural Networks: nn.md
- MNIST Tutorial: mnist.md
- Quickstart: quickstart.md
- Showcase: showcase.md
- Developer: developer.md

View File

@@ -249,6 +249,7 @@ class TestLinearizer(unittest.TestCase):
      # check correctness
      helper_tc_allclose(tc.dims[1]+pad, tc.dims[2]+pad, tc.dims[0]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
  @unittest.skipIf(Device.DEFAULT == "RHIP", "RHIP is really slow here")
  def test_tensor_cores_multi_reduce(self):
    if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
      self.skipTest("device doesn't have tensor cores")

View File

@@ -1,6 +1,6 @@
import unittest
from PIL import Image
from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction
from tinygrad.helpers import Context, ContextVar, merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, CI
from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)
@@ -145,6 +145,7 @@ class TestFetch(unittest.TestCase):
  def test_fetch_bad_http(self):
    self.assertRaises(Exception, fetch, 'http://www.google.com/404')
  @unittest.skipIf(not CI, "pre commit tests should run offline")
  def test_fetch_small(self):
    assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)

View File

@@ -56,7 +56,7 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  return state_dict
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False):
def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"): # noqa: E501
    model_state_dict = get_state_dict(model)

View File

@@ -13,11 +13,18 @@ from tinygrad.buffer import Buffer
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
class UnaryOps(Enum):
  """A -> A (elementwise)"""
  EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum):
  """A + A -> A (elementwise)"""
  ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
class TernaryOps(Enum): WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class TernaryOps(Enum):
  """A + A + A -> A (elementwise)"""
  WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum):
  """A -> B (reduce)"""
  SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702
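A quick illustration of the arity categories these docstrings describe (a hedged sketch; assumes a tinygrad checkout where these enum names exist):
```python
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps

# elementwise ops take one, two, or three same-shape inputs and return one output of that
# shape; reduce ops collapse the reduced axes instead of preserving the shape
for kind in (UnaryOps, BinaryOps, TernaryOps, ReduceOps):
  print(kind.__name__, [op.name for op in kind])
```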