diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index ab604c6646..1d1156c847 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -30,6 +30,10 @@ jobs: ln -s ~/tinygrad/weights/LLaMA weights/LLaMA python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt + - name: Run GPT2 + run: | + python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt + JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt - name: Run 10 CIFAR training steps run: | ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz @@ -43,6 +47,8 @@ jobs: train_cifar.txt llama_unjitted.txt llama_jitted.txt + gpt2_unjitted.txt + gpt2_jitted.txt testamdbenchmark: name: AMD Benchmark @@ -67,6 +73,10 @@ jobs: ln -s ~/tinygrad/weights/LLaMA weights/LLaMA python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt + - name: Run GPT2 + run: | + python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt + JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt - name: Run 10 CIFAR training steps run: | ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz @@ -80,3 +90,5 @@ jobs: train_cifar.txt llama_unjitted.txt llama_jitted.txt + gpt2_unjitted.txt + gpt2_jitted.txt diff --git a/examples/gpt2.py b/examples/gpt2.py new file mode 100644 index 0000000000..d6b763da7c --- /dev/null +++ b/examples/gpt2.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# pip3 install tiktoken + +import functools, argparse +import numpy as np +from tqdm import trange +np.set_printoptions(linewidth=200) +from typing import Optional, Tuple + +from tinygrad.helpers import Timing, getenv, dtypes, DEBUG +from tinygrad.ops import GlobalCounters +from tinygrad.lazy import Device +from tinygrad.tensor import Tensor +from tinygrad.nn import Embedding, Linear +from tinygrad.jit import TinyJit + +from examples.llama import sample + +class LayerNorm: + def __init__(self, dim, eps=1e-5): + self.eps = eps + self.weight = Tensor.ones(dim) + self.bias = Tensor.zeros(dim) + + def __call__(self, x:Tensor): + return (x.layernorm(eps=self.eps)) * self.weight + self.bias + +class Attention: + def __init__(self, dim, n_heads, linear=Linear): + self.c_attn = linear(dim, 3*dim, bias=True) + self.c_proj = linear(dim, dim, bias=True) + self.n_heads = n_heads + self.dim = dim + self.head_dim = dim // n_heads + + def prepare_attention(self, x:Tensor) -> Tuple[Tensor, Tensor, Tensor]: + xqkv = self.c_attn(x) + xq, xk, xv = [xqkv.slice([None, None, (i*self.dim, (i+1)*self.dim)]) for i in range(3)] + xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], self.n_heads, self.head_dim) for x in (xq, xk, xv)] + return xq, xk, xv + + def inner_attention(self, xq:Tensor, xk:Tensor, xv:Tensor, start_pos:int, mask:Optional[Tensor]) -> Tensor: + bsz, seqlen, _, _ = xq.shape + # kv caching! + if start_pos == 0: + keys, values = xk, xv + else: + assert hasattr(self, 'cache_k'), "no cache" + assert start_pos == self.cache_k.shape[1] and start_pos == self.cache_v.shape[1], "cache is wrong shape" + assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?" + keys, values = self.cache_k.cat(xk, dim=1), self.cache_v.cat(xv, dim=1) + + # save the cache + self.cache_k, self.cache_v = keys.realize(), values.realize() + xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) + return xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1) + + # NOTE: this is not called + def __call__(self, x:Tensor, start_pos:int, mask:Optional[Tensor]) -> Tensor: + xq, xk, xv = self.prepare_attention(x) + output = self.inner_attention(xq, xk, xv, start_pos, mask) + return self.c_proj(output) + +class FeedForward: + def __init__(self, dim, hidden_dim, linear=Linear): + self.c_fc = linear(dim, hidden_dim, bias=True) + self.c_proj = linear(hidden_dim, dim, bias=True) + + def __call__(self, x:Tensor) -> Tensor: + return self.c_proj(self.c_fc(x).gelu()) + +class TransformerBlock: + def __init__(self, dim, n_heads, norm_eps, linear=Linear): + self.attn = Attention(dim, n_heads, linear) + self.mlp = FeedForward(dim, 4*dim, linear) + self.ln_1 = LayerNorm(dim, norm_eps) + self.ln_2 = LayerNorm(dim, norm_eps) + if getenv("JIT"): + self._pre = TinyJit(self.pre) + self._post = TinyJit(self.post) + else: + self._pre, self._post = self.pre, self.post + + def pre(self, x:Tensor) -> Tuple[Tensor, Tensor, Tensor]: + xq, xk, xv = self.attn.prepare_attention(self.ln_1(x)) + return xq.realize(), xk.realize(), xv.realize() + + def post(self, x:Tensor, output:Tensor) -> Tensor: + h = x + self.attn.c_proj(output) + return (h + self.mlp(self.ln_2(h))).realize() + + def __call__(self, x:Tensor, start_pos:int, mask:Optional[Tensor]): + xq, xk, xv = self._pre(x) if mask is None else self.pre(x) + # inner_attention can't be jitted because it's dynamic based on start_pos + output = self.attn.inner_attention(xq, xk, xv, start_pos, mask) + return self._post(x, output) if mask is None else self.post(x, output) + +class Transformer: + def __init__(self, dim, n_heads, n_layers, norm_eps=1e-5, vocab_size=50257, linear=Linear, max_seq_len=1024): + self.wte = Embedding(vocab_size, dim) + self.wpe = Embedding(max_seq_len, dim) + self.h = [TransformerBlock(dim, n_heads, norm_eps, linear) for _ in range(n_layers)] + self.ln_f = LayerNorm(dim, norm_eps) + self.lm_head = linear(dim, vocab_size, bias=False) + + def __call__(self, tokens:Tensor, start_pos:int): + _bsz, seqlen = tokens.shape + tok_emb = self.wte(tokens) + pos = Tensor.arange(start_pos, start_pos + seqlen).reshape(1, -1) + pos_emb = self.wpe(pos) + h = tok_emb + pos_emb + + # get only the part we are using. making it contiguous avoids more kernel calls + mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None + h = h.sequential([functools.partial(layer, start_pos=start_pos, mask=mask) for layer in self.h]) + h = self.ln_f(h) + return self.lm_head(h) + +# **** files and arguments **** + +MODEL_PARAMS = { + 'gpt2': dict(n_layers=12, n_heads=12, dim=768), # 124M params + 'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024), # 350M params + 'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280), # 774M params + 'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600), # 1558M params +} + +def get_url(model_size): return f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin' + +class GPT2: + @staticmethod + def build(model_size="gpt2"): + import tiktoken + from tinygrad.state import torch_load, load_state_dict + from extra.utils import fetch_as_file + tokenizer = tiktoken.get_encoding("gpt2") + + params = MODEL_PARAMS[model_size] + model = Transformer(**params) + weights = torch_load(fetch_as_file(get_url(model_size))) + # special treatment for the Conv1D weights we need to transpose + transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] + for k in weights.keys(): + if any(k.endswith(w) for w in transposed): + weights[k] = Tensor(weights[k].numpy().T) + # lm head and wte are tied + weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy()) + + load_state_dict(model, weights) + return GPT2(model, tokenizer) + + def __init__(self, model, tokenizer): + self.model = model + self.tokenizer = tokenizer + + def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False): + toks = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"}) + start_pos = 0 + for _ in trange(max_length, disable=(timing==True)): + if args.timing: print("") + st = GlobalCounters.time_sum_s + with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=timing): + logits = self.model(Tensor([toks[start_pos:]]), start_pos).realize()[:, -1, :] + with Timing("sync in ", enabled=timing): + tok = sample(logits, temperature) + start_pos = len(toks) + toks.append(tok) + output = self.tokenizer.decode(toks) + return output + +# **** main code **** + +if __name__ == "__main__": + Tensor.no_grad = True + print(f"using {Device.DEFAULT} backend") + + parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with") + parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate") + parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax") + parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]") + parser.add_argument('--timing', action='store_true', help="Print timing per token") + args = parser.parse_args() + + print(f"using {args.model_size}") + gpt2 = GPT2.build(args.model_size) + print('Generating text...') + y = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing) + print(y)