mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Llama: load models in HuggingFace format (incl. indexed, safetensors) (#1583)
This commit is contained in:
@@ -4,9 +4,8 @@
|
||||
#typeguard.importhook.install_import_hook('tinygrad')
|
||||
|
||||
from pathlib import Path
|
||||
import functools, sys, argparse, math, platform
|
||||
import functools, sys, argparse, json, os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
np.set_printoptions(linewidth=200)
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -14,6 +13,7 @@ from tinygrad.helpers import Timing, getenv, DEBUG, dtypes
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Embedding, Linear
|
||||
from tinygrad.nn.state import safe_load, torch_load, load_state_dict
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
@@ -225,6 +225,28 @@ def concat_weights(models):
|
||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||
return {name: convert(name) for name in {name: None for model in models for name in model}}
|
||||
|
||||
def load(fn:str):
|
||||
if fn.endswith('.index.json'):
|
||||
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
|
||||
parts = {n: load(f'{os.path.dirname(fn)}/{os.path.basename(n)}') for n in set(weight_map.values())}
|
||||
return {k: parts[n][k] for k, n in weight_map.items()}
|
||||
elif fn.endswith('.safetensors'):
|
||||
return safe_load(fn)
|
||||
else:
|
||||
return torch_load(fn)
|
||||
|
||||
def convert_from_huggingface(weights, model):
|
||||
keymap = {
|
||||
'model.embed_tokens.weight': 'tok_embeddings.weight',
|
||||
**{f'model.layers.{l}.input_layernorm.weight': f'layers.{l}.attention_norm.weight' for l in range(len(model.layers))},
|
||||
**{f'model.layers.{l}.self_attn.{x}_proj.weight': f'layers.{l}.attention.w{x}.weight' for x in ['q', 'k', 'v', 'o'] for l in range(len(model.layers))},
|
||||
**{f'model.layers.{l}.post_attention_layernorm.weight': f'layers.{l}.ffn_norm.weight' for l in range(len(model.layers))},
|
||||
**{f'model.layers.{l}.mlp.{x}_proj.weight': f'layers.{l}.feed_forward.w{y}.weight' for x, y in {'gate': '1', 'down': '2', 'up': '3'}.items() for l in range(len(model.layers))},
|
||||
'model.norm.weight': 'norm.weight',
|
||||
'lm_head.weight': 'output.weight',
|
||||
}
|
||||
return {keymap[k]: v for k,v in weights.items() if '.rotary_emb.' not in k}
|
||||
|
||||
class AbsmaxQuantizedLinear:
|
||||
def __init__(self, in_features, out_features, bias=False):
|
||||
assert bias == False
|
||||
@@ -254,10 +276,16 @@ class LLaMa:
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == VOCAB_SIZE
|
||||
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
|
||||
weights = concat_weights([torch_load(filename) for filename in [f"{model_path}/{model_size}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
|
||||
if model_path.is_dir():
|
||||
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
else:
|
||||
weights = load(str(model_path))
|
||||
if 'model.embed_tokens.weight' in weights:
|
||||
weights = convert_from_huggingface(weights, model)
|
||||
|
||||
if quantize:
|
||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||
load_state_dict(model, weights, strict=False)
|
||||
@@ -304,6 +332,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B, 70B] for Gen 2")
|
||||
parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]")
|
||||
parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory")
|
||||
parser.add_argument('--model', type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
|
||||
args = parser.parse_args()
|
||||
chatbot = args.prompt == None
|
||||
@@ -399,10 +428,10 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||
|
||||
|
||||
LLAMA_SUFFIX = {1: "", 2: "-2"}[args.gen]
|
||||
WEIGHTS_DIR = Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/"
|
||||
TOKENIZER_FILENAME = WEIGHTS_DIR / "tokenizer.model"
|
||||
MODEL_PATH = args.model or Path(__file__).parent.parent / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
|
||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||
llama = LLaMa.build(WEIGHTS_DIR, TOKENIZER_FILENAME, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
||||
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
||||
|
||||
if chatbot:
|
||||
# encode pre prompt
|
||||
|
||||
Reference in New Issue
Block a user