diff --git a/examples/qwq.py b/examples/qwq.py new file mode 100644 index 0000000000..7c71fec049 --- /dev/null +++ b/examples/qwq.py @@ -0,0 +1,117 @@ +import argparse +import os +import sys + +from transformers import AutoTokenizer +from pathlib import Path +from typing import Dict, Union + +from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16 +from examples.llama3 import load +from tinygrad import nn, Tensor +from tinygrad.helpers import fetch, colored, GlobalCounters, Timing, DEBUG +from tinygrad.nn.state import load_state_dict, get_parameters + +MODELS = { + "32B": { + "model_params": {"dim": 5120, "n_heads": 40, "n_kv_heads": 8, "n_layers": 64, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 152064, "hidden_dim": 27648}, + "total_num_weights": 17, + "tokenizer": "Qwen/QwQ-32B-Preview" + } +} + +def download_weights(total_num_weights:int) -> Path: + model = fetch("https://huggingface.co/Qwen/QwQ-32B-Preview/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=(subdir:="qwq_32b_preview")) + + for i in range(1, total_num_weights + 1): + filename = f"model-{i:05d}-of-{total_num_weights:05d}.safetensors" + fetch(f"https://huggingface.co/Qwen/QwQ-32B-Preview/resolve/main/{filename}?download=true", filename, subdir=subdir) + + return Path(os.path.dirname(model)) + +def load_model(model_path:Path, model_params:Dict[str, Union[int, float]]) -> Transformer: + # build model + model = Transformer(**model_params, linear=nn.Linear) + + # update layers to add bias + updated_layers = [] + for layer in model.layers: + head_dim = model_params["dim"] // model_params["n_heads"] + layer.attention.wq = nn.Linear(model_params["dim"], model_params["n_heads"] * head_dim, bias=True) + layer.attention.wk = nn.Linear(model_params["dim"], model_params["n_kv_heads"] * head_dim, bias=True) + layer.attention.wv = nn.Linear(model_params["dim"], model_params["n_kv_heads"] * head_dim, bias=True) + updated_layers.append(layer) + model.layers = updated_layers + + # load weights + weights = fix_bf16(convert_from_huggingface(load(str(model_path / "model.safetensors.index.json")), model, model_params["n_heads"], model_params["n_kv_heads"], permute_layers=False)) + + # replace weights in model + load_state_dict(model, weights, strict=False, consume=True) + return model + + +if __name__ == "__main__": + Tensor.no_grad = True + + parser = argparse.ArgumentParser(description="Run QwQ in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--size", choices=["32B"], default="32B", help="Model size") + parser.add_argument("--count", type=int, default=30, help="Max number of tokens to generate") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax") + parser.add_argument("--prompt", type=str, default="Hello.", help="Phrase to start with") + parser.add_argument("--weights", type=str, default=None, help="Path to the downloaded weights") + parser.add_argument("--timing", action="store_true", help="Print timing per token") + args = parser.parse_args() + + model_info = MODELS[args.size] + + model_path = Path(args.weights) if args.weights else download_weights(model_info["total_num_weights"]) + transformer = load_model(model_path, model_info["model_params"]) + tokenizer = AutoTokenizer.from_pretrained(model_info["tokenizer"]) + param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(transformer)) + + outputted = args.prompt + start_pos, toks = 0, tokenizer(outputted)["input_ids"] + print(outputted, end="", flush=True) + + tok_tensor = None + for i in range(args.count): + GlobalCounters.reset() + + if args.timing: print("") + st = GlobalCounters.time_sum_s + next_tok = Tensor([toks[start_pos:]]) if tok_tensor is None or (len(toks)-start_pos) > 1 else tok_tensor.reshape(1, 1) + with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"): + with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "") + + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB" + + (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing): + tok_tensor = transformer(next_tok, start_pos, args.temperature) + tok = tok_tensor.item() + + # use the kv cache + start_pos = len(toks) + + # add the new token + toks.append(tok) + + cur = tokenizer.decode(toks, skip_special_tokens=True) + sys.stdout.write(cur[len(outputted):]) + sys.stdout.flush() + outputted = cur + + if args.temperature == 0: + text = tokenizer.decode(toks) + key = (args.size, args.count, args.prompt) + expected = { + ("32B", 10, "Hello."): "Hello. I'm trying to make a program that will read", + ( + "32B", + 50, + "Can you tell me more about machine learning?" + ): "Can you tell me more about machine learning? Sure, I'd be happy to help! Machine learning is a subset of artificial intelligence that focuses on building systems that can learn from data and make predictions or decisions without being explicitly programmed to do so. It's a fascinating field with a lot of real" + } + try: + assert text == expected[key], f"invalid output: `{colored(text, 'red')}` != `{expected[key]}`" + print("\n" + colored("output validated", "green")) + except KeyError: + pass diff --git a/extra/models/llama.py b/extra/models/llama.py index deaf1ac3ec..cd5ca1f179 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -181,7 +181,7 @@ class Transformer: # *** helpers *** -def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int): +def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True): def permute(v: Tensor, n_heads: int): return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2]) @@ -189,6 +189,7 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he "model.embed_tokens.weight": "tok_embeddings.weight", **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))}, **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))}, + **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))}, **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))}, **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))}, "model.norm.weight": "norm.weight", @@ -199,9 +200,9 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he if ".rotary_emb." in k: continue v = v.to(Device.DEFAULT) if "model.layers" in k: - if "q_proj" in k: + if "q_proj" in k and permute_layers: v = permute(v, n_heads) - elif "k_proj" in k: + elif "k_proj" in k and permute_layers: v = permute(v, n_kv_heads) sd[keymap[k]] = v return sd