mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)

--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -1,10 +1,10 @@
 from pathlib import Path
 from typing import List
-import json, argparse, random, time
+import json, argparse, random, time, os
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
-from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
+from extra.models.llama import Transformer, convert_from_huggingface, convert_from_gguf, fix_bf16
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters, gguf_load
 from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
 from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm

@@ -57,6 +57,9 @@ def load(fn:str):
     with open(fn) as fp: weight_map = json.load(fp)['weight_map']
     parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
     return {k: parts[n][k] for k, n in weight_map.items()}
+  elif fn.endswith(".gguf"):
+    gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
+    return gguf_load(gguf_tensor)[1]
   elif fn.endswith(".safetensors"):
     return safe_load(fn)
   else:
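
Aside: the new .gguf branch never routes the file through Python I/O. It maps the whole file as a uint8 tensor on the disk: device, realizes it on the default device, and hands it to gguf_load, which returns a (kv_data, state_dict) pair, hence the [1]. A minimal standalone sketch of the same pattern, assuming a hypothetical local GGUF path:

    import os
    from tinygrad import Tensor, dtypes, Device
    from tinygrad.nn.state import gguf_load

    fn = "Llama-3.2-1B-Instruct-Q6_K.gguf"  # hypothetical local path
    # map the file's bytes as a disk tensor, then realize on the default device
    gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
    kv_data, state_dict = gguf_load(gguf_tensor)  # metadata dict, name -> Tensor dict
    print(f"loaded {len(state_dict)} tensors")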

@@ -128,6 +131,10 @@ def NF4Linear(block_size):
   return _NF4Linear

 MODEL_PARAMS = {
+  "1B": {
+    "args": {"dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192},
+    "files": 1
+  },
   "8B": {
     "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
     "files": 1
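
Sanity check on the new "1B" entry: a back-of-the-envelope parameter count from those args lands right on Llama 3.2 1B's advertised size. The sketch below ignores norm weights and counts the vocab matrix once, since the output head is tied to the embedding (the convert_from_gguf hunk further down makes that explicit):

    # rough parameter count from the "1B" args above
    dim, n_heads, n_kv_heads, n_layers, hidden_dim, vocab = 2048, 32, 8, 16, 8192, 128256
    head_dim = dim // n_heads                     # 64
    attn = 2*dim*dim + 2*n_kv_heads*head_dim*dim  # wq + wo, plus wk + wv (GQA)
    ffn = 3 * dim * hidden_dim                    # w1, w2, w3 (SwiGLU)
    print(f"~{(vocab*dim + n_layers*(attn + ffn)) / 1e9:.2f}B params")  # ~1.24B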

@@ -153,6 +160,8 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   weights = load(str(model_path))
   if "model.embed_tokens.weight" in weights:
     weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+  elif "token_embd.weight" in weights:
+    weights = convert_from_gguf(weights, model)
   weights = fix_bf16(weights)

   with Context(BEAM=0):
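
The dispatch above keys off a sentinel tensor name per checkpoint format: HuggingFace exports start embeddings at model.embed_tokens.weight, GGUF files at token_embd.weight, and the native scheme (tok_embeddings.weight) needs no conversion. A hypothetical helper, not part of the diff, just to spell the rule out:

    # hypothetical illustration of the sentinel-key dispatch above
    def detect_format(weights: dict) -> str:
      if "model.embed_tokens.weight" in weights: return "huggingface"
      if "token_embd.weight" in weights: return "gguf"
      return "native"  # already uses tok_embeddings.weight etc.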

@@ -210,7 +219,7 @@ if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument("--download_model", action="store_true", help="Download a 8B model")
   parser.add_argument("--model", type=Path, help="Model path")
-  parser.add_argument("--size", choices=["8B", "70B"], default="8B", help="Model size")
+  parser.add_argument("--size", choices=["1B", "8B", "70B"], default="8B", help="Model size")
   parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
   parser.add_argument("--quantize", choices=["int8", "nf4"], help="Quantization method")
   parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")

@@ -226,12 +235,16 @@ if __name__ == "__main__":

   assert not (args.download_model and args.model), "either download or provide model"
   if args.download_model:
-    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-8b-sfr")
-    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir="llama3-8b-sfr")
-    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir="llama3-8b-sfr")
-    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir="llama3-8b-sfr")
-    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
-    args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")
+    if args.size == "1B":
+      fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
+      args.model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf", "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")
+    elif args.size == "8B":
+      fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-8b-sfr")
+      fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir="llama3-8b-sfr")
+      fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir="llama3-8b-sfr")
+      fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir="llama3-8b-sfr")
+      fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
+      args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")

   assert args.model is not None, "please provide --model option"
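
Taken together, --size 1B turns the download step into one tokenizer file plus a single quantized GGUF (roughly 1 GB at Q6_K) instead of four multi-GB safetensors shards. An invocation sketch, with the script path assumed from context:

    python3 examples/llama3.py --download_model --size 1B

fetch() caches under the given subdir, so repeat runs should skip the download.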

--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -205,6 +205,21 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
   return sd

+def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
+  keymap = {
+    "token_embd.weight": "tok_embeddings.weight",
+    **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(len(model.layers))},
+    **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(len(model.layers))},
+    **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    "output_norm.weight": "norm.weight",
+    "rope_freqs.weight": "rope_freqs.weight",
+  }
+  sd = {keymap[k]: v for k,v in weights.items()}
+  sd["output.weight"] = weights["token_embd.weight"]
+  return sd
+
 def fix_bf16(weights:Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
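
Two details in convert_from_gguf worth flagging: every per-layer comprehension ranges over len(model.layers), so the keymap always matches the instantiated model, and the final assignment ties the output head to the token embedding, since Llama 3.2 1B uses a tied embedding (no separate output.weight in the GGUF). A self-contained toy run of the same renaming scheme for a one-layer model, with strings standing in for Tensors:

    # toy rename pass: same scheme as convert_from_gguf, one layer, no deps
    n_layers = 1
    keymap = {
      "token_embd.weight": "tok_embeddings.weight",
      **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(n_layers)},
      **{f"blk.{l}.attn_{x}.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v"] for l in range(n_layers)},
      **{f"blk.{l}.attn_output.weight": f"layers.{l}.attention.wo.weight" for l in range(n_layers)},
      **{f"blk.{l}.ffn_norm.weight": f"layers.{l}.ffn_norm.weight" for l in range(n_layers)},
      **{f"blk.{l}.ffn_{x}.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(n_layers)},
      "output_norm.weight": "norm.weight",
    }
    weights = {k: f"<{k}>" for k in keymap}              # stand-in tensors
    sd = {keymap[k]: v for k, v in weights.items()}
    sd["output.weight"] = weights["token_embd.weight"]   # tied head
    assert sd["output.weight"] == sd["tok_embeddings.weight"]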