diff --git a/docs/developer/developer.md b/docs/developer/developer.md
index 39e9e0901b..f932f0a935 100644
--- a/docs/developer/developer.md
+++ b/docs/developer/developer.md
@@ -7,6 +7,8 @@ The tinygrad framework has four pieces
 
 There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-notes/) by Di Zhu that go over tinygrad internals.
 
+There's also a [doc describing speed](../developer/speed.md).
+
 ## Frontend
 
 Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).
diff --git a/docs/developer/speed.md b/docs/developer/speed.md
new file mode 100644
index 0000000000..e4801e6418
--- /dev/null
+++ b/docs/developer/speed.md
@@ -0,0 +1,71 @@
+# speed in tinygrad
+
+## Overview
+
+Speed refers to many different things. Breaking it down into four pieces, there's:
+
+- Compile Speed (Python)
+- Execution Speed (driver)
+- Model Speed (scheduler)
+- Kernel Speed (codegen)
+
+## Compile Speed (Python)
+
+This is how long the first run of your model takes. It's limited largely by the runtime of the Python code doing UOp rewrites. Currently it's a bit slow, but on par with torch.compile. It gets even slower if you are using BEAM, since that compiles many variants of each kernel.
+
+This will be improved by making graph_rewrite faster, calling graph_rewrite less, and parallelizing better.
+
+## Execution Speed (driver)
+
+After your model is compiled, you are often using the `TinyJIT`. tinygrad has the best execution speed of any framework because it usually bypasses the GPU driver and prebuilds the command queue. It's tons faster than normal CUDA, and often even faster than CUDA Graph.
+
+There's very little to improve here, as this is almost never the bottleneck.
+
+## Model Speed (scheduler)
+
+The scheduler determines how operations are grouped into kernels and which Tensors are written to memory. This is currently a big bottleneck for training speed.
+
+The decisions are often not obvious. For example, when is it worth recomputing an arithmetic operation instead of storing the result and loading it back from memory? Example:
+
+```python
+from tinygrad import Tensor
+a = Tensor.rand(100)
+b = Tensor.rand(100)
+c = Tensor.rand(100)
+d = Tensor.rand(100)
+out1 = a+b+c
+out2 = a+b+d
+Tensor.realize(out1, out2)
+```
+
+The real answer is obvious: compute both `out1` and `out2` in the same kernel. But you can't always do that. If you can't, should `a+b` first be saved to a subbuffer? Or should both the `out1` and `out2` kernels recompute `a+b`?
+
+In this case, recomputing costs 6 reads + 2 writes while not recomputing costs 6 reads + 3 writes, so we should probably recompute. However, once you add movement ops and casts this becomes even harder to figure out. tinygrad doesn't yet have a systematic way to do it.
+
+## Kernel Speed (codegen)
+
+Given that you have decided how the model ops will be grouped and what will be written to memory, kernel speed determines how fast that operation is done. This is what BEAM changes: it searches over a set of equivalent kernels which all perform the same operation and finds the one which completes the task the fastest.
+
+In `kernel.py` we have a set of `OptOps`; these control the parameters of the speed optimizations applied to the kernel.
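+
+As a rough illustration of what these look like, here is a sketch of applying one of the `OptOps` by hand and inspecting the result. It mirrors the pattern used in `test_linearizer.py`, and assumes `Kernel`, `Opt`, and `OptOps` are importable from `tinygrad.codegen.kernel`:
+
+```python
+from tinygrad import Tensor
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+
+# schedule a small matmul and build a Kernel from the AST of its last scheduled op
+out = Tensor.ones(64, 64).contiguous() @ Tensor.ones(64, 64).contiguous()
+k = Kernel(out.schedule()[-1].ast)
+
+# hand-apply a single optimization; BEAM search tries many combinations of these automatically
+k.apply_opt(Opt(OptOps.UPCAST, axis=0, arg=4))
+print(k.to_program().src)  # inspect the generated kernel source
+```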
+
+### Memory
+
+The main bottleneck in most kernels is accessing memory. In a freshman algorithms class, you'll learn about cache-aware matrix multiplication, and this is all forms of that. While the same math is run, the order in which you run it can have a large impact on speed, depending on whether the data you are loading is already in cache. OptOps will change this order.
+
+Memory, even cache, is often much slower than accessing the register file. The number of times a loaded value is used in math is called the "arithmetic intensity". For operations like a BS=1 GEMV, the arithmetic intensity is 1, but for GEMMs and convs it can be much higher. OptOps like UPCAST and UNROLL can increase this, but be careful not to make them too large: if there's too much register pressure on the GPU, the warp scheduler may not be able to fit many warps, or even worse, registers could spill to local memory.
+
+4090s have 1 TB/s of RAM bandwidth and ~160 TFLOPS of compute, so you need to use each loaded value ~100 times. The L1 cache has around 40 TB/s of bandwidth, so in order to get full compute utilization you need to use each value ~4 times.
+
+A lot of work can still be done here. For example, we never copy the inputs to on-chip SRAM, but this is often quite helpful for kernel speed. Also, we aren't doing a good job with L2 cache awareness (the locals handle L1 quite well).
+
+### Tensor Cores
+
+Many accelerators have Tensor Cores / MAC arrays / systolic arrays. The main value of these is that, since they are 2-D, they create an n^2 ratio between the compute and the input data.
+
+GPUs use Tensor Cores instead of MAC arrays to fit better in the GPU warp paradigm. This is because the output of a Tensor Core is O(n) wrt the input, while the output of a MAC array like the AMX is O(n^2).
+
+We have a simple framework in tinygrad for adding these ALU blocks and achieving good performance from them.
+
+### Indexing
+
+Indexing determines the address of the memory we need to load. GPUs often have fewer integer math resources than floating point resources, so this can sometimes be the bottleneck. We have a symbolic math engine in our rewrite rules to simplify indexing before it's emitted to the kernel. Newer NVIDIA GPUs have a "Tensor Memory Accelerator" to assist with fast indexing; however, this is not supported in tinygrad yet.
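+
+To make the arithmetic intensity numbers from the Memory section concrete, here is a back-of-the-envelope estimate for a square fp16 GEMM (plain Python with illustrative sizes, not a tinygrad API):
+
+```python
+# rough arithmetic-intensity estimate for an NxN fp16 matmul
+N, bytes_per_el = 4096, 2
+flops = 2 * N**3                        # each of the N*N outputs needs N multiplies and N adds
+bytes_moved = 3 * N * N * bytes_per_el  # read A and B, write C (ignoring cache re-reads)
+print(f"{flops / bytes_moved:.0f} FLOPs per byte of DRAM traffic")  # ~1365 for N=4096
+# a 4090 does ~160 TFLOPS on ~1 TB/s, so it needs well over 100 FLOPs per byte of DRAM
+# traffic to stay compute bound; a big GEMM clears that, a BS=1 GEMV (intensity ~1) cannot
+```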
diff --git a/extra/gemm/torch_gemm.py b/extra/gemm/torch_gemm.py
index a87f2757ec..6dde871980 100644
--- a/extra/gemm/torch_gemm.py
+++ b/extra/gemm/torch_gemm.py
@@ -1,17 +1,26 @@
+import os
+os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
 import time
 import torch
+torch.set_num_threads(1)
+from tinygrad.helpers import getenv
+CUDA = getenv("CUDA", 1)
 
-for dtype in [torch.float16, torch.float32]:
+for dtype in [torch.float32, torch.float16]:
   for N in [256, 512, 1024, 2048, 4096]:
     FLOPS = N*N*N*2
-    b = torch.rand((N,N), dtype=dtype).cuda()
-    c = torch.rand((N,N), dtype=dtype).cuda()
+    b = torch.rand((N,N), dtype=dtype)
+    c = torch.rand((N,N), dtype=dtype)
+    if CUDA: b,c = b.cuda(),c.cuda()
 
     def torch_prog(b, c):
       st = time.perf_counter()
       a = b@c
-      torch.cuda.synchronize()
+      if CUDA: torch.cuda.synchronize()
       return time.perf_counter() - st
     tm = min([torch_prog(b, c) for _ in range(20)])
     print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
 
diff --git a/mkdocs.yml b/mkdocs.yml
index 38419a5708..a09a4b47fc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -22,6 +22,7 @@ nav:
   - Runtime: runtime.md
   - Developer:
     - Intro: developer/developer.md
+    - Speed: developer/speed.md
     - UOp: developer/uop.md
     - Runtime:
       - developer/runtime.md
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index a5b5519c8a..596f502238 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -981,6 +981,16 @@ class TestLinearizer(unittest.TestCase):
     assert len(stores) == 1
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
 
+  # NOTE: can reenable, it does work. it just makes BEAM slow
+  @unittest.expectedFailure
+  @unittest.skipUnless(Device.DEFAULT == "CLANG", "test only for CLANG")
+  def test_upcast_with_locals_clang(self):
+    out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
+    k = Kernel(out.schedule()[-1].ast)
+    k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))
+    prg = k.to_program()
+    self.assertEqual(len(prg.src.split("for")), 5)
+
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index b137bc387c..a9df66b810 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -385,6 +385,8 @@ class Kernel:
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
 
     if opt.op is OptOps.LOCAL: # cyan
+      # NOTE: LLVM/CLANG can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
       self.shift_to(axis, amt, insert_before=self.first_reduce)
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 2fd2607c2c..68ffa9fdc0 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -207,7 +207,8 @@ class LRUAllocator(Allocator):
 
 class _MallocAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferSpec):
-    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 16)
+    # must be aligned to 0x20 for 256-bit ymm registers
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 0x20)
   def _alloc_aligned(self, size:int, alignment:int):
     buffer = (ctypes.c_uint8 * (size + alignment))()
     offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 6a59245875..8ba14a7c37 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -133,7 +133,8 @@ class LLVMRenderer(Renderer):
 
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+        # NOTE: MallocAllocator promises 0x20 alignment
+        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
       elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
       elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
       elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
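
The 0x20 (32-byte) alignment threaded through `device.py` and `llvmir.py` above relies on the over-allocate-and-round-up trick in `_alloc_aligned`. A minimal standalone sketch of that trick, assuming only `round_up` from `tinygrad.helpers` (the `alloc_aligned` name here is just for illustration):

```python
import ctypes
from tinygrad.helpers import round_up

def alloc_aligned(size: int, alignment: int = 0x20):
  # over-allocate, then offset the base address up to the next multiple of `alignment`,
  # so 256-bit ymm loads/stores see a 32-byte aligned pointer
  buffer = (ctypes.c_uint8 * (size + alignment))()
  offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
  # also return the backing buffer so it stays alive alongside the aligned view
  return (ctypes.c_uint8 * size).from_address(ctypes.addressof(buffer) + offset), buffer

view, _backing = alloc_aligned(4096)
assert ctypes.addressof(view) % 0x20 == 0
```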