Files
tinygrad/examples/gpt2.py
George Hotz 67ed4c4eb3 move gguf stuff from nn/state.py to llm/gguf.py (#15783)
* move gguf stuff from nn/state.py to llm/gguf.py

* docs
2026-04-20 09:41:43 +08:00

256 lines
12 KiB
Python

#!/usr/bin/env python3
import os, argparse, contextlib
from typing import Optional, Union
with contextlib.suppress(ImportError): import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.helpers import Timing, DEBUG, JIT, getenv, fetch, colored, trange
from tinygrad.llm.gguf import gguf_load
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.bench_log import BenchEvent, WallTimeEvent
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")
class Attention:
def __init__(self, dim, n_heads):
self.c_attn = Linear(dim, 3*dim, bias=True)
self.c_proj = Linear(dim, dim, bias=True)
self.n_heads = n_heads
self.dim = dim
self.head_dim = dim // n_heads
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
if mask is not None or start_pos.val == 0:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
if HALF: x = x.half()
xqkv = self.c_attn(x).reshape(None, None, 3, self.n_heads, self.head_dim)
xq, xk, xv = [xqkv[:, :, i, :, :] for i in range(3)]
bsz, seqlen, _, _ = xq.shape
# create kv cache
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
# update the cache
self.cache_kv[:, :, start_pos:start_pos+seqlen, :, :].assign(Tensor.stack(xk, xv)).realize()
if start_pos > 0:
keys = self.cache_kv[0][:, :start_pos+seqlen, :, :]
values = self.cache_kv[1][:, :start_pos+seqlen, :, :]
else:
keys = xk
values = xv
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))
class FeedForward:
def __init__(self, dim, hidden_dim):
self.c_fc = Linear(dim, hidden_dim, bias=True)
self.c_proj = Linear(hidden_dim, dim, bias=True)
def __call__(self, x:Tensor) -> Tensor:
return self.c_proj(self.c_fc(x).gelu())
class TransformerBlock:
def __init__(self, dim, n_heads, norm_eps):
self.attn = Attention(dim, n_heads)
self.mlp = FeedForward(dim, 4*dim)
self.ln_1 = LayerNorm(dim, norm_eps)
self.ln_2 = LayerNorm(dim, norm_eps)
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
h = x + self.attn(self.ln_1(x), start_pos, mask).float()
return (h + self.mlp(self.ln_2(h))).contiguous()
class Transformer:
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
self.vocab_size = vocab_size
self.wte = Embedding(vocab_size, dim)
self.wpe = Embedding(max_seq_len, dim)
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
self.ln_f = LayerNorm(dim, norm_eps)
self.lm_head = Linear(dim, vocab_size, bias=False)
self.forward_jit = TinyJit(self.forward)
def forward(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
if isinstance(tokens, UOp):
seqlen = 1
tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
else:
seqlen = tokens.shape[1]
tok_emb = self.wte(tokens)
# not symbolic when consuming the prompt
selected_pos = (0, seqlen) if start_pos.val == 0 else (start_pos, start_pos+1)
pos_emb = self.wpe(self.allpos.shrink((None, selected_pos)))
h = tok_emb + pos_emb
if HALF: h = h.half()
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None
for hi in self.h: h = hi(h, start_pos, mask)
logits = self.lm_head(self.ln_f(h))
if logits.shape[1] == 0:
# special case for empty prompt
logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
else:
logits = logits[:, -1, :]
if temperature < 1e-6:
ret = logits.argmax(-1)
else:
ret = (logits / temperature).softmax().multinomial()
return ret.flatten().realize()
def __call__(self, tokens:Union[Tensor,UOp], start_pos:Variable, temperature:float=0.0) -> Tensor:
forward = (self.forward_jit if JIT and (isinstance(tokens, UOp) or tokens.shape[1] == 1) else self.forward)
return forward(tokens, start_pos, temperature)
VOCAB_SIZE = 50257
MODEL_PARAMS = {
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
}
class GPT2:
@staticmethod
def build(model_size="gpt2"):
tokenizer = tiktoken.get_encoding("gpt2")
model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
# special treatment for the Conv1D weights we need to transpose
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
for k in weights:
if k.endswith(transposed):
weights[k] = weights[k].T
# lm head and wte are tied
weights['lm_head.weight'] = weights['wte.weight']
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
load_state_dict(model, weights)
if HALF:
for l in get_state_dict(model).values():
l.replace(l.half().realize())
return GPT2(model, tokenizer)
@staticmethod
def build_gguf(model_size: str):
q_type = model_size[len("gpt2_gguf_"):].upper()
fn = fetch(f"https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.{q_type}.gguf?download=true")
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
kv_data, state_dict = gguf_load(gguf_tensor)
gpt2_params = {
"dim": kv_data["gpt2.embedding_length"], "n_heads": kv_data["gpt2.attention.head_count"],
"n_layers": kv_data["gpt2.block_count"], "norm_eps": kv_data["gpt2.attention.layer_norm_epsilon"],
"vocab_size": VOCAB_SIZE, "max_seq_len": kv_data["gpt2.context_length"],
}
def _remap_gguf_key(key: str):
replaces = [
("blk.", "h."), (".attn_qkv.bias", ".attn.c_attn.bias"), (".attn_qkv.weight", ".attn.c_attn.weight"),
(".ffn_norm.bias", ".ln_2.bias"), (".ffn_norm.weight", ".ln_2.weight"), (".attn_norm.bias", ".ln_1.bias"),
(".attn_norm.weight", ".ln_1.weight"), (".attn_output.bias", ".attn.c_proj.bias"), (".attn_output.weight", ".attn.c_proj.weight"),
(".ffn_up.bias", ".mlp.c_fc.bias"), (".ffn_up.weight", ".mlp.c_fc.weight"), (".ffn_down.bias", ".mlp.c_proj.bias"),
(".ffn_down.weight", ".mlp.c_proj.weight"), ("token_embd.weight", "wte.weight"), ("output.weight", "lm_head.weight"),
("output_norm.bias", "ln_f.bias"), ("output_norm.weight", "ln_f.weight"), ("position_embd.weight", "wpe.weight"),
]
for ostr, ns in replaces: key = key.replace(ostr, ns)
return key
state_dict = { _remap_gguf_key(k): v for k, v in state_dict.items() }
model = Transformer(**gpt2_params)
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
load_state_dict(model, state_dict)
return GPT2(model, tiktoken.get_encoding("gpt2"))
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
step_times = []
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
toks = [prompt_tokens[:] for _ in range(batch_size)]
start_pos = 0
for _ in trange(max_length, disable=(timing==True)):
GlobalCounters.reset()
if timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on {Device.DEFAULT}" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
with WallTimeEvent(BenchEvent.STEP):
if batch_size == 1 and len(toks[0][start_pos:]) == 1:
tokens = Variable("tokens", 0, VOCAB_SIZE-1).bind(toks[0][start_pos])
else:
tokens = Tensor([x[start_pos:] for x in toks])
tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT-1).bind(start_pos), temperature).tolist()
step_times.append((GlobalCounters.time_sum_s-st)*1e3)
start_pos = len(toks[0])
for i,t in enumerate(tok): toks[i].append(t)
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
return [self.tokenizer.decode(x) for x in toks]
# **** main code ****
if __name__ == "__main__":
print(f"using {Device.DEFAULT} backend")
default_prompt = "What is the answer to life, the universe, and everything?"
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to start with")
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--seed', type=int, help="Set the random seed")
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
args = parser.parse_args()
if args.seed is not None:
Tensor.manual_seed(args.seed)
print(f"using {args.model_size}")
gpt2 = GPT2.build_gguf(args.model_size) if args.model_size.startswith("gpt2_gguf_") else GPT2.build(args.model_size)
if args.benchmark != -1:
gpt2.model(Tensor.randint(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
else:
texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
if not args.noshow:
print('Generating text...')
if len(texts) == 1: print(texts[0])
else:
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
# validate output!
if args.temperature == 0 and args.model_size == "gpt2-medium" and args.count == 10:
expected = {
default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
"Hello.": "Hello. I'm a little late to the party, but",
}
try:
assert texts[0] == expected[args.prompt]
print(colored("output validated", "green"))
except KeyError:
pass