diff --git a/examples/llama.py b/examples/llama.py
index 4139a91e88..f3a3deaf6a 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -133,13 +133,14 @@ class TransformerBlock:
     return (h + self.feed_forward(self.ffn_norm(h))).realize(), cache_k.realize(), cache_v.realize()

 class Transformer:
-  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
+  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
     self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
     self.kv_caches = [(None, None) for _ in range(n_layers)]
     self.norm = RMSNorm(dim, norm_eps)
     self.tok_embeddings = Embedding(vocab_size, dim)
     self.output = linear(dim, vocab_size, bias=False)
-    self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
+    self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta))
+    self.norm_output = lambda x: self.output(self.norm(x))

     self.tok_embeddings_jitted = TinyJit(lambda x: self.tok_embeddings(x).realize())
     self.postprocess_jitted = TinyJit(self.postprocess)
@@ -176,41 +177,77 @@ class Transformer:
     return self.postprocess(h, temperature)

 # **** files and arguments ****
-
-VOCAB_SIZE = 32000
 MODEL_PARAMS = {
-  1: {
+  "1": {
     "7B": {
-      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 1,
     },
     "13B": {
-      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 2,
     },
     "30B": {
-      "args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000},
       "files": 4,
     },
     "65B": {
-      "args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 8,
     },
   },
-  2: {
+  "2": {
     "7B": {
-      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 1,
     },
     "13B": {
-      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 2,
     },
     "70B": {
-      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": VOCAB_SIZE},
+      "args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
       "files": 8,
     },
   },
+  "code": {
+    "7B": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 1,
+    },
+    "7B-Python": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 1,
+    },
+    "7B-Instruct": {
+      "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 1,
+    },
+    "13B": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 2,
+    },
+    "13B-Python": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 2,
+    },
+    "13B-Instruct": {
+      "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 2,
+    },
+    "34B": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
+      "files": 4,
+    },
+    "34B-Python": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 4,
+    },
+    "34B-Instruct": {
+      "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
+      "files": 4,
+    },
+  }
 }

 # **** helper functions ****
@@ -219,7 +256,7 @@ def concat_weights(models):
     disk_tensors = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
       return disk_tensors[0].to(device=Device.DEFAULT)
-    axis = 1 if name.startswith('tok_embeddings.') or name.endswith('.attention.wo.weight') or name.endswith('.feed_forward.w2.weight') else 0
+    axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
   return {name: convert(name) for name in {name: None for model in models for name in model}}
@@ -229,22 +266,22 @@ def load(fn:str):
     with open(fn) as fp: weight_map = json.load(fp)['weight_map']
     parts = {n: load(Path(fn).parent / Path(n).name) for n in set(weight_map.values())}
     return {k: parts[n][k] for k, n in weight_map.items()}
-  elif fn.endswith('.safetensors'):
+  elif fn.endswith(".safetensors"):
     return safe_load(fn)
   else:
     return torch_load(fn)

 def convert_from_huggingface(weights, model):
   keymap = {
-    'model.embed_tokens.weight': 'tok_embeddings.weight',
-    **{f'model.layers.{l}.input_layernorm.weight': f'layers.{l}.attention_norm.weight' for l in range(len(model.layers))},
-    **{f'model.layers.{l}.self_attn.{x}_proj.weight': f'layers.{l}.attention.w{x}.weight' for x in ['q', 'k', 'v', 'o'] for l in range(len(model.layers))},
-    **{f'model.layers.{l}.post_attention_layernorm.weight': f'layers.{l}.ffn_norm.weight' for l in range(len(model.layers))},
-    **{f'model.layers.{l}.mlp.{x}_proj.weight': f'layers.{l}.feed_forward.w{y}.weight' for x, y in {'gate': '1', 'down': '2', 'up': '3'}.items() for l in range(len(model.layers))},
-    'model.norm.weight': 'norm.weight',
-    'lm_head.weight': 'output.weight',
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    "model.norm.weight": "norm.weight",
+    "lm_head.weight": "output.weight",
   }
-  return {keymap[k]: v for k,v in weights.items() if '.rotary_emb.' not in k}
+  return {keymap[k]: v for k,v in weights.items() if ".rotary_emb." not in k}

 class AbsmaxQuantizedLinear:
   def __init__(self, in_features, out_features, bias=False):
@@ -259,7 +296,7 @@ class AbsmaxQuantizedLinear:
   def quantize(tensors):
     new_tensors = {}
     for name,v in tensors.items():
-      if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
+      if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
         scale = v.abs().max(axis=1) / 127.0
         int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
         new_tensors[name] = int8_weight
@@ -270,10 +307,10 @@ class LLaMa:
   @staticmethod
-  def build(model_path, tokenizer_path, model_gen=1, model_size="7B", quantize=False):
+  def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
     from sentencepiece import SentencePieceProcessor
     sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
-    assert sp_model.vocab_size() == VOCAB_SIZE
+    assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"]

     params = MODEL_PARAMS[model_gen][model_size]
     model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
@@ -282,7 +319,7 @@ class LLaMa:
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
     else:
       weights = load(str(model_path))
-    if 'model.embed_tokens.weight' in weights:
+    if "model.embed_tokens.weight" in weights:
       weights = convert_from_huggingface(weights, model)

     if quantize:
@@ -313,27 +350,81 @@ class LLaMa:
     return output

 # **** main code ****
+"""
+test:
+python3 examples/llama.py --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old male. I'm a student at the University of Texas at Austin. I'm a sophomore majoring in Computer Science.
+test:
+python3 examples/llama.py --gen='2' --temperature=0 --count=50 --prompt="Hello."
+output:
+Hello. I'm a 20 year old girl who is looking for a good lay in Palm Coast. I don't care whether it's at your place or not, as long as it's clean.
+
+test:
+python3 examples/llama.py --gen="code" --temperature=0.2 --count=50 --prompt="\
+import argparse
+
+def main(string: str):
+  print(string)
+  print(string[::-1])
+
+if __name__ == "__main__":"
+output:
+  parser = argparse.ArgumentParser()
+  parser.add_argument('string', type=str, help='string to be reversed')
+  args = parser.parse_args()
+  main(args.string)
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Python" --temperature=0.2 --count=70 --prompt="def add_elements(arr,k):"
+output:
+  for i in range(len(arr)):
+    arr[i] += k
+  return arr
+
+
+arr = [1, 2, 3, 4, 5]
+k = 2
+print(add_elements(arr, k))
+
+test:
+python3 examples/llama.py --gen="code" --size="7B-Instruct" --temperature=0.2 --count=120 --prompt="write a function in c++ that adds three float numbers"
+output:
+\begin{code}
+#include <iostream>
+using namespace std;
+
+float add(float a, float b, float c)
+{
+    return a+b+c;
+}
+
+int main()
+{
+    float a, b, c;
+    cout<<"Enter three numbers: ";
+    cin>>a>>b>>c;
+    cout<<"The sum is: "<