docs cleanup and move (#4593)

* cleanup and move

* docs-legacy is gone

* don't update setup.py
George Hotz
2024-05-14 20:44:59 -07:00
committed by GitHub
parent fd02ab1e8b
commit 9425973bc7
18 changed files with 8 additions and 215 deletions

117
docs/abstractions2.py Normal file

@@ -0,0 +1,117 @@
# tinygrad is a tensor library, and as a tensor library it has multiple parts
# 1. a "runtime". this allows buffer management, compilation, and running programs
# 2. a "Device" that uses the runtime but specifies compute in an abstract way for all
# 3. a "LazyBuffer" that fuses the compute into kernels, using memory only when needed
# 4. a "Tensor" that provides an easy to use frontend with autograd ".backward()"
print("******** first, the runtime ***********")
from tinygrad.runtime.ops_clang import ClangProgram, ClangCompiler, MallocAllocator
# allocate some buffers
out = MallocAllocator.alloc(4)
a = MallocAllocator.alloc(4)
b = MallocAllocator.alloc(4)
# load in some values (little endian)
MallocAllocator.copyin(a, bytearray([2,0,0,0]))
MallocAllocator.copyin(b, bytearray([3,0,0,0]))
# compile a program to a binary
lib = ClangCompiler().compile("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
# create a runtime for the program (ctypes.CDLL)
fxn = ClangProgram("add", lib)
# run the program
fxn(out, a, b)
# check the data out
print(val := MallocAllocator.as_buffer(out).cast("I").tolist()[0])
assert val == 5
print("******** second, the Device ***********")
DEVICE = "CLANG" # NOTE: you can change this!
import struct
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps
from tinygrad.shape.shapetracker import ShapeTracker
# allocate some buffers + load in values
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
# describe the computation
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
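# NOTE: each MemBuffer index matches a position in the buffer list passed to exec below (0=out, 1=a, 2=b)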
# convert the computation to a "linearized" format (print the format)
from tinygrad.engine.realize import get_linearizer, CompiledRunner
lin = get_linearizer(Device[DEVICE].renderer, (st_0,)).linearize()
for u in lin.uops: print(u)
# compile a program (and print the source)
fxn = CompiledRunner(lin.to_program())
print(fxn.p.src)
# NOTE: fxn.clprg is the ClangProgram
# run the program
fxn.exec([out, a, b])
# check the data out
assert out.as_buffer().cast('I')[0] == 5
print("******** third, the LazyBuffer ***********")
from tinygrad.lazy import LazyBuffer, LoadOps
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule
# allocate some values + load in values
a = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE)
b = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE)
a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
del a.srcs
del b.srcs
# describe the computation
out = a.e(BinaryOps.ADD, b)
# schedule the computation as a list of kernels
sched = create_schedule([out])
for si in sched: print(si.ast[0].op) # NOTE: the first two convert it to CLANG
# DEBUGGING: print the compute ast as a tree
from tinygrad.features.graph import print_tree
print_tree(sched[-1].ast[0])
# NOTE: sched[-1].ast is the same as st_0 above
# run that schedule
run_schedule(sched)
# check the data out
assert out.realized.as_buffer().cast('I')[0] == 5
print("******** fourth, the Tensor ***********")
from tinygrad import Tensor
a = Tensor([2], dtype=dtypes.int32, device=DEVICE)
b = Tensor([3], dtype=dtypes.int32, device=DEVICE)
out = a + b
# check the data out
print(val:=out.item())
assert val == 5

62
docs/abstractions3.py Normal file

@@ -0,0 +1,62 @@
# abstractions2 goes from back to front; here we will go from front to back
from typing import List
from tqdm import tqdm
from tinygrad.helpers import DEBUG
# *****
# 0. Load mnist on the device
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
X_train = X_train.float()
X_train -= X_train.mean()
# *****
# 1. Define an MNIST model.
from tinygrad import Tensor
l1 = Tensor.kaiming_uniform(128, 784)
l2 = Tensor.kaiming_uniform(10, 128)
def model(x): return x.flatten(1).dot(l1.T).relu().dot(l2.T)
l1n, l2n = l1.numpy(), l2.numpy()
# *****
# 2. Choose a batch for training and do the backward pass.
from tinygrad.nn.optim import SGD
optim = SGD([l1, l2])
X, Y = X_train[samples:=Tensor.randint(128, high=X_train.shape[0])], Y_train[samples]
optim.zero_grad()
model(X).sparse_categorical_crossentropy(Y).backward()
optim._step() # this will step the optimizer without running realize
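# NOTE: the weight update is only recorded here; nothing is computed until the schedule below runs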
# *****
# 3. Create a schedule.
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
# l1.lazydata and l2.lazydata define a computation graph
from tinygrad.engine.schedule import ScheduleItem
schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
print(f"The schedule contains {len(schedule)} items.")
for si in schedule: print(str(si)[:80])
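# each ScheduleItem bundles a compute ast with the Buffers it reads/writes (si.bufs, used when lowering below)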
# *****
# 4. Lower a schedule.
from tinygrad.engine.realize import lower_schedule_item, ExecItem
lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si), list(si.bufs)) for si in tqdm(schedule)]
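# an ExecItem pairs the runner returned by lower_schedule_item with the concrete Buffers it will run on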
# *****
# 5. Run the schedule
for ei in tqdm(lowered): ei.run()
# *****
# 6. Print the weight change
print("first weight change\n", l1.numpy()-l1n)
print("second weight change\n", l2.numpy()-l2n)

178
docs/env_vars.md Normal file

@@ -0,0 +1,178 @@
# List of environment variables that control tinygrad behavior.
This is a list of the environment variables that control the runtime behavior of tinygrad and its examples.
Most of these are self-explanatory, and are usually used to set an option at runtime.
Example: `GPU=1 DEBUG=4 python3 -m pytest`
However, you can also decorate a function to set a value only inside that function.
```python
# in tensor.py (probably only useful if you are a tinygrad developer)
@Context(DEBUG=4)
def numpy(self) -> np.ndarray: ...
```
Or use it as a context manager to temporarily set a value inside some scope:
```python
with Context(DEBUG=0):
  a = Tensor.ones(10, 10)
  a *= 2
```
## Global Variables
The columns of this list are: Variable, Possible Value(s), and Description.
- A `#` means that the variable can take any integer value.
These control the behavior of core tinygrad even when it is used as a library; a short sketch of reading them from Python follows the table.
Variable | Possible Value(s) | Description
---|---|---
DEBUG | [1-6] | enable debugging output; at 4 you get operations, timings, speed, generated code and more
GPU | [1] | enable the GPU backend
CUDA | [1] | enable CUDA backend
HSA | [1] | enable HSA backend
METAL | [1] | enable Metal backend (for Mac M1 and after)
METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
CLANG | [1] | enable Clang backend
LLVM | [1] | enable LLVM backend
BEAM | [#] | number of beams in kernel beam search
GRAPH | [1] | create a graph of all operations (requires graphviz)
GRAPHUOPS | [1] | create a graph of uops (requires graphviz and saves at /tmp/uops.{svg,dot})
GRAPHPATH | [/path/to] | where to put the generated graph
DEFAULT_FLOAT | [HALF, ...] | specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), defaults to FLOAT32
IMAGE | [1-2] | enable 2d specific optimizations
FLOAT16 | [1] | use float16 for images instead of float32
PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend.
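These variables are read inside tinygrad through helpers such as `Context` (shown above). As a minimal sketch of inspecting and temporarily overriding them from Python, assuming `getenv`, `Context`, and the `DEBUG` context variable are importable from `tinygrad.helpers` in your version:
```python
from tinygrad.helpers import getenv, Context, DEBUG

# getenv reads an environment variable with a default, e.g. BEAM from the table above
print("BEAM =", getenv("BEAM", 0))

# Context temporarily overrides a context variable such as DEBUG inside a scope
with Context(DEBUG=2):
  print("DEBUG inside the block:", DEBUG.value)
print("DEBUG outside the block:", DEBUG.value)
```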
## File Specific Variables
These variables control the behavior of a specific file and usually don't affect the library itself.
Most of the time these will never be used, but they are here for completeness.
### accel/ane/2_compile/hwx_parse.py
Variable | Possible Value(s) | Description
---|---|---
PRINTALL | [1] | print all ANE registers
### extra/onnx.py
Variable | Possible Value(s) | Description
---|---|---
ONNXLIMIT | [#] | set a limit for ONNX
DEBUGONNX | [1] | enable ONNX debugging
### extra/thneed.py
Variable | Possible Value(s) | Description
---|---|---
DEBUGCL | [1-4] | enable debugging for OpenCL
PRINT_KERNEL | [1] | print OpenCL kernels
### examples/vit.py
Variable | Possible Value(s) | Description
---|---|---
LARGE | [1] | enable larger dimension model
### examples/llama.py
Variable | Possible Value(s) | Description
---|---|---
WEIGHTS | [1] | enable loading weights
### examples/mlperf
Variable | Possible Value(s) | Description
---|---|---
MODEL | [resnet,retinanet,unet3d,rnnt,bert,maskrcnn] | which model to use
### examples/benchmark_train_efficientnet.py
Variable | Possible Value(s) | Description
---|---|---
CNT | [10] | the number of times to loop the benchmark
BACKWARD | [1] | enable backward pass
TRAINING | [1] | set Tensor.training
CLCACHE | [1] | enable cache for OpenCL
### examples/hlb_cifar10.py
Variable | Possible Value(s) | Description
---|---|---
TORCHWEIGHTS | [1] | use torch to initialize weights
DISABLE_BACKWARD | [1] | don't do backward pass
DIST | [1] | enable distributed training
STEPS | [#] | number of steps
### examples/benchmark_train_efficientnet.py & examples/hlb_cifar10.py
Variable | Possible Value(s) | Description
---|---|---
ADAM | [1] | use the Adam optimizer
### examples/train_efficientnet.py
Variable | Possible Value(s) | Description
---|---|---
STEPS | [# % 1024] | number of steps
TINY | [1] | use a tiny convolution network
IMAGENET | [1] | use imagenet for training
### examples/train_efficientnet.py & examples/train_resnet.py
Variable | Possible Value(s) | Description
---|---|---
TRANSFER | [1] | enable using pretrained data
### examples & test/external/external_test_opt.py
Variable | Possible Value(s) | Description
---|---|---
NUM | [18, 2] | what ResNet[18] / EfficientNet[2] to train
### test/test_ops.py
Variable | Possible Value(s) | Description
---|---|---
PRINT_TENSORS | [1] | print tensors
FORWARD_ONLY | [1] | use forward operations only
### test/test_speed_v_torch.py
Variable | Possible Value(s) | Description
---|---|---
TORCHCUDA | [1] | enable the torch cuda backend
### test/external/external_test_gpu_ast.py
Variable | Possible Value(s) | Description
---|---|---
KCACHE | [1] | enable kernel cache
### test/external/external_test_opt.py
Variable | Possible Value(s) | Description
---|---|---
ENET_NUM | [-2,-1] | what EfficientNet to use
### test/test_dtype.py & test/extra/test_utils.py & extra/training.py
Variable | Possible Value(s) | Description
---|---|---
CI | [1] | disables some tests for CI
### examples & extra & test
Variable | Possible Value(s) | Description
---|---|---
BS | [8, 16, 32, 64, 128] | batch size to use
### extra/datasets/imagenet_download.py
Variable | Possible Value(s) | Description
---|---|---
IMGNET_TRAIN | [1] | also download the imagenet training data

11
docs/logo_tiny_dark.svg Normal file

@@ -0,0 +1,11 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges" fill="#fff">
<!-- t -->
<polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
<!-- i -->
<polygon points="40,40 40,20 50,20 50,40" />
<polygon points="40,10 40,0 50,0 50,10" />
<!-- n -->
<polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
<!-- y -->
<polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
</svg>


11
docs/logo_tiny_light.svg Normal file

@@ -0,0 +1,11 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
<!-- t -->
<polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
<!-- i -->
<polygon points="40,40 40,20 50,20 50,40" />
<polygon points="40,10 40,0 50,0 50,10" />
<!-- n -->
<polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
<!-- y -->
<polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
</svg>



@@ -76,7 +76,7 @@ print(t6.numpy())
```
There are a lot more operations that can be performed on tensors; you can find them in the [Tensor](tensor.md) file.
-Additionally reading through [abstractions2.py](https://github.com/tinygrad/tinygrad/blob/master/docs-legacy/abstractions2.py) will help you understand how operations on these tensors make their way down to your hardware.
+Additionally reading through [abstractions2.py](https://github.com/tinygrad/tinygrad/blob/master/docs/abstractions2.py) will help you understand how operations on these tensors make their way down to your hardware.
## Models
@@ -299,7 +299,7 @@ Many of the models in the [models/](https://github.com/tinygrad/tinygrad/tree/ma
There exist a bunch of environment variables that control the runtime behavior of tinygrad.
Some of the common ones are `DEBUG` and the different backend enablement variables.
-You can find a full list and their descriptions in [env_vars.md](https://github.com/tinygrad/tinygrad/blob/master/docs-legacy/env_vars.md).
+You can find a full list and their descriptions in [env_vars.md](https://github.com/tinygrad/tinygrad/blob/master/docs/env_vars.md).
### Visualizing the Computation Graph