coder.py can write and run code (#2439)

* wip mistral

* coder

* touchups

* cleanups

* mistral cleanups

* clean up cache create

* download the weights, fix tests

* fix llama loading

* global fixup

* clean up all

* move llama model

* cleanups

* Revert "cleanups"

This reverts commit a71c5d59eb.

* fine, leave it
Author: George Hotz
Date: 2023-11-25 12:27:54 -08:00 (committed by GitHub)
Parent: df41a57e09
Commit: 7170a9a057
10 changed files with 334 additions and 167 deletions

examples/coder.py (new file, 107 lines)

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
import os, sys, traceback
sys.path.append(os.getcwd())
from io import StringIO
from contextlib import redirect_stdout
from tinygrad import Tensor, nn
from tinygrad.helpers import Timing, colored, getenv, fetch
from extra.models.llama import Transformer, convert_from_huggingface
from sentencepiece import SentencePieceProcessor

def create_fixed_tokenizer(output_file):
  print("creating fixed tokenizer")
  import extra.junk.sentencepiece_model_pb2 as spb2
  mp = spb2.ModelProto()
  mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
  with open(output_file, "wb") as f:
    f.write(mp.SerializeToString())
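
# the two appended pieces land at ids 32000 (<|im_end|>) and 32001 (<|im_start|>): the stock Mistral
# tokenizer has 32000 pieces (ids 0-31999), which is why IM_END/IM_START below are hardcoded to those
# values and the model is built with vocab_size=32002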

# TODO: make loading bf16 fast so we can remove this
def create_model_cache(output_file, model):
  print(f"creating model cache at {output_file}")

  # TODO: add read only Tensors
  with Timing("download weights: "):
    part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
    part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))

  with Timing("weights -> model: "):
    nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
    nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)

  with Timing("saving float16 cache: "):
    nn.state.safe_save(nn.state.get_state_dict(model), output_file)

  print("cache created, rerun to use")
  exit(0)

if __name__ == "__main__":
  Tensor.no_grad = True

  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
  with Timing("create model: "):
    model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)

  cached_model = "/tmp/cached_openhermes.safetensors"
  if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
  with Timing("loading float16 cache: "):
    nn.state.load_state_dict(model, nn.state.safe_load(cached_model))

  if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
  spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")

  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
  #   "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
  IM_END = 32000
  IM_START = 32001
  def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
  def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")

  def output(outputted, toks, color):
    cur = spp.decode(toks)[len(outputted):]
    sys.stdout.write(colored(cur, color))
    sys.stdout.flush()
    outputted += cur
    return outputted

  # *** app below this line ***

  toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")

  PROMPT = getenv("PROMPT", 1)
  temperature = getenv("TEMP", 0.7)

  start_pos = 0
  outputted = output("", toks, "green")
  turn = True
  while 1:
    if PROMPT:
      toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
    else:
      toks += start_prompt("user" if turn else "assistant")
      turn = not turn

    old_output_len = len(outputted)
    while 1:
      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
      start_pos = len(toks)
      toks.append(tok)
      outputted = output(outputted, toks, "blue" if not turn else "cyan")
      if tok == IM_END: break
      if tok == spp.eos_id(): break
    new_output = outputted[old_output_len:]

    if new_output.endswith("```") and '```python\n' in new_output:
      python_code = new_output.split('```python\n')[1].split("```")[0]
      # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
      if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
        my_stdout = StringIO()
        try:
          with redirect_stdout(my_stdout): exec(python_code)
          result = my_stdout.getvalue()
        except Exception as e:
          result = ''.join(traceback.format_exception_only(e))
        toks += spp.encode(f"\nOutput:\n```\n{result}```")
        outputted = output(outputted, toks, "yellow")
        old_output_len = len(outputted)
    print("")

examples/gpt2.py

@@ -89,7 +89,7 @@ class Transformer:
   # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
-    return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT", int(not CI)) else self.forward)(tokens, start_pos, temperature)
+    return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)

 VOCAB_SIZE = 50257
 MODEL_PARAMS = {

examples/llama.py

@@ -7,145 +7,14 @@ from pathlib import Path
 import sys, argparse, json
 import numpy as np
 np.set_printoptions(linewidth=200)
 from typing import Optional, Tuple, Union
-from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes, CI
+from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad.helpers import GlobalCounters
 from tinygrad.jit import TinyJit
 from tinygrad.shape.symbolic import Variable
+from extra.models.llama import Transformer, convert_from_huggingface
-
-MAX_CONTEXT = 1024
-JIT = getenv("JIT", 0 if CI else 1)
-
-# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
-  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
-  return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
-
-# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
-def complex_mult(A, c, d):
-  a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
-  return ro.cat(co, dim=-1)
-
-def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
-  assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
-  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
-  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
-  assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
-  c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
-  xq_out = complex_mult(xq, c, d)
-  xk_out = complex_mult(xk, c, d)
-  return xq_out.flatten(3), xk_out.flatten(3)
-
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
-  bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1: return x
-  return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
-
-class RMSNorm:
-  def __init__(self, dim, eps=1e-6):
-    self.eps = eps
-    self.weight = Tensor.ones(dim)
-
-  def __call__(self, x:Tensor):
-    # TODO: convert to float?
-    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
-
-class Attention:
-  def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
-    self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
-    self.head_dim = dim // n_heads
-    self.n_rep = self.n_heads // self.n_kv_heads
-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
-
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
-    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
-    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
-    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
-    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-    bsz, seqlen, n_heads, head_dim = xq.shape
-
-    # create kv cache
-    if not hasattr(self, "cache_k"):
-      self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim)
-
-    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
-    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
-
-    # update the cache
-    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
-    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
-
-    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
-    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
-    return self.wo(attn)
-
-class FeedForward:
-  def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
-    # TODO: what is this?
-    hidden_dim = int(2 * hidden_dim / 3)
-    # custom dim factor multiplier
-    if ffn_dim_multiplier is not None:
-      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-    self.w1 = linear(dim, hidden_dim, bias=False)
-    self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)
-
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))
-
-class TransformerBlock:
-  def __init__(self, dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
-    self.attention = Attention(dim, n_heads, n_kv_heads, linear)
-    self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
-    self.attention_norm = RMSNorm(dim, norm_eps)
-    self.ffn_norm = RMSNorm(dim, norm_eps)
-
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h))).realize()
-
-class Transformer:
-  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
-    self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
-    self.norm = RMSNorm(dim, norm_eps)
-    self.tok_embeddings = Embedding(vocab_size, dim)
-    self.output = linear(dim, vocab_size, bias=False)
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
-    self.forward_jit = TinyJit(self.forward)
-
-  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
-    _bsz, seqlen = tokens.shape
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
-    h = self.tok_embeddings(tokens)
-    for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    logits = self.output(self.norm(h))
-    return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
-
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
-    # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and JIT:
-      assert start_pos > 0
-      return self.forward_jit(tokens, Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos), temperature)
-    return self.forward(tokens, start_pos, temperature)
+
+MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)

 # **** files and arguments ****
 MODEL_PARAMS = {
@@ -227,6 +96,23 @@ MODEL_PARAMS = {
   }
 }

+# fix up MODEL_PARAMS to have hidden_dim
+for model_gen in MODEL_PARAMS.values():
+  for model_type in model_gen.values():
+    model_args = model_type['args']
+    hidden_dim = model_args['dim'] * 4
+    multiple_of = model_args['multiple_of']
+    # TODO: what is this?
+    hidden_dim = int(2 * hidden_dim / 3)
+    # custom dim factor multiplier
+    ffn_dim_multiplier = getattr(model_args, 'ffn_dim_multiplier', None)
+    if ffn_dim_multiplier is not None:
+      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+      del model_args['ffn_dim_multiplier']
+    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+    model_args['hidden_dim'] = hidden_dim
+    del model_args['multiple_of']
+
 # **** helper functions ****
 def concat_weights(models):
   def convert(name) -> Tensor:
@@ -248,31 +134,6 @@ def load(fn:str):
   else:
     return torch_load(fn)

-def convert_from_huggingface(weights, model: Transformer, n_heads: int, n_kv_heads: int):
-  def permute(v: Tensor, n_heads: int):
-    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
-
-  keymap = {
-    "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
-    "model.norm.weight": "norm.weight",
-    "lm_head.weight": "output.weight",
-  }
-  sd = {}
-  for k, v in weights.items():
-    if ".rotary_emb." in k: continue
-    v = v.to(Device.DEFAULT)
-    if "model.layers" in k:
-      if "q_proj" in k:
-        v = permute(v, n_heads)
-      elif "k_proj" in k:
-        v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
-  return sd
-
 class AbsmaxQuantizedLinear:
   def __init__(self, in_features, out_features, bias=False):
     assert bias == False
@@ -303,8 +164,7 @@ class LLaMa:
     assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {MODEL_PARAMS[model_gen][model_size]['args']['vocab_size']}"

     params = MODEL_PARAMS[model_gen][model_size]
-    model_args = params["args"]
-    model = Transformer(**model_args, linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
+    model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)

     if model_path.is_dir():
       weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
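
A worked instance of the hidden_dim fixup above, using the stock LLaMA 7B config values (dim=4096,
multiple_of=256, no ffn_dim_multiplier):

  dim, multiple_of = 4096, 256
  hidden_dim = int(2 * (4 * dim) / 3)                                         # 16384 -> 10922
  hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # rounds up to 11008

so every entry in MODEL_PARAMS ends up with an explicit hidden_dim, and multiple_of disappears from the
Transformer constructor.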

extra/junk/sentencepiece_model_pb2.py (new file, 45 lines)

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  _globals['DESCRIPTOR']._options = None
  _globals['DESCRIPTOR']._serialized_options = b'H\003'
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC']._serialized_start=45
  _globals['_TRAINERSPEC']._serialized_end=1581
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
  _globals['_NORMALIZERSPEC']._serialized_start=1584
  _globals['_NORMALIZERSPEC']._serialized_end=1793
  _globals['_SELFTESTDATA']._serialized_start=1795
  _globals['_SELFTESTDATA']._serialized_end=1916
  _globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
  _globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
  _globals['_MODELPROTO']._serialized_start=1919
  _globals['_MODELPROTO']._serialized_end=2429
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
# @@protoc_insertion_point(module_scope)
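
This is checked-in protoc output for sentencepiece_model.proto, kept under extra/junk so that
create_fixed_tokenizer can parse and rewrite the tokenizer model. If it ever needed regenerating, the
command would presumably be something like protoc --python_out=extra/junk sentencepiece_model.proto
(with a protobuf release matching the "Protobuf Python Version: 4.25.1" header above; this command is
an assumption, not part of the commit).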

extra/models/llama.py (new file, 151 lines)

@@ -0,0 +1,151 @@
from typing import Tuple, Union, Optional, Dict
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
  return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
  ro = a*c - b*d
  co = a*d + b*c
  return ro.cat(co, dim=-1)

def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
  c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)
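
# repeat_kv is the grouped-query attention trick: each of the n_kv_heads key/value heads is tiled
# n_rep = n_heads // n_kv_heads times so the k/v tensors line up with the query heads (no-op when n_rep == 1)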
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
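
# RMSNorm normalizes by the root mean square only, x / sqrt(mean(x^2) + eps), then scales by a learned
# weight -- no mean subtraction and no bias, unlike LayerNorm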
class RMSNorm:
  def __init__(self, dim, eps=1e-6):
    self.eps = eps
    self.weight = Tensor.ones(dim)

  def __call__(self, x:Tensor):
    # TODO: convert to float?
    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    bsz, seqlen, n_heads, head_dim = xq.shape

    # create kv cache
    if not hasattr(self, "cache_k"):
      self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)

    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

    # update the cache
    self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
    self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
    return self.wo(attn)

class FeedForward:
  def __init__(self, dim, hidden_dim, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    return self.w2(self.w1(x).silu() * self.w3(x))

class TransformerBlock:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
    self.feed_forward = FeedForward(dim, hidden_dim, linear)
    self.attention_norm = RMSNorm(dim, norm_eps)
    self.ffn_norm = RMSNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).realize()

class Transformer:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
    self.norm = RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
    self.output = linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
    self.forward_jit = TinyJit(self.forward) if jit else None

  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
    _bsz, seqlen = tokens.shape
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None

    h = self.tok_embeddings(tokens)
    for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h))
    return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
    # TODO: better way to handle the first call v.s. the rest?
    if tokens.shape[0:2] == (1,1) and self.forward_jit:
      assert start_pos > 0
      return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
    return self.forward(tokens, start_pos, temperature)

# *** helpers ***

def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k:
        v = permute(v, n_heads)
      elif "k_proj" in k:
        v = permute(v, n_kv_heads)
    sd[keymap[k]] = v
  return sd
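
A minimal smoke test of this interface (a sketch, not from the commit: the tiny hyperparameters are
invented, only the argument names come from the file above):

  from tinygrad import Tensor
  from extra.models.llama import Transformer

  model = Transformer(dim=64, hidden_dim=128, n_heads=4, n_layers=2, norm_eps=1e-5, vocab_size=128,
                      n_kv_heads=2, max_context=64, jit=False)
  probs = model(Tensor([[1, 2, 3]]), start_pos=0)    # softmax over the vocab for the last position
  next_tok = probs.multinomial().item()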


@@ -93,7 +93,7 @@ class TestAllocators(unittest.TestCase):
     old_type = Tensor.default_type
     Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     def __test():
       model = Transformer(**args_tiny)
       derandomize_model(model)
@@ -105,7 +105,7 @@ class TestAllocators(unittest.TestCase):
   @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
   def test_lru_allocator_tiny_llama_alloc_counts(self):
-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     def test_alloc_count(t):
       model = Transformer(**args_tiny)
       for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))


@@ -19,7 +19,7 @@ class TestJittedModels(unittest.TestCase):
     old_type = Tensor.default_type
     Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     derandomize_model(model)
     def test(t): return model(t, 0).realize()


@@ -86,7 +86,7 @@ class TestInferenceMinKernels(unittest.TestCase):
   def test_llama(self):
     from examples.llama import Transformer
     from tinygrad.shape.symbolic import Variable

-    args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
     with CLCache(100):


@@ -51,7 +51,7 @@ class TestRealWorld(unittest.TestCase):
   def test_llama(self):
     Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
     model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
     derandomize_model(model)
     @TinyJit

tinygrad/__init__.py

@@ -1,4 +1,8 @@
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.jit import TinyJit # noqa: F401
 from tinygrad.shape.symbolic import Variable # noqa: F401
 from tinygrad.helpers import dtypes # noqa: F401
+
+# NOTE: these should not be relied on to be stable
+from tinygrad.ops import Device # noqa: F401
+from tinygrad.helpers import GlobalCounters # noqa: F401
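
A quick sketch of what the package-root re-exports enable (the NOTE above flags Device and
GlobalCounters as unstable API):

  from tinygrad import Tensor, TinyJit, Variable, dtypes, Device, GlobalCounters

  x = Tensor.ones(2, 2, dtype=dtypes.float32)
  print(x.sum().item(), Device.DEFAULT)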