Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
coder.py can write and run code (#2439)
* wip mistral
* coder
* touchups
* cleanups
* mistral cleanups
* clean up cache create
* download the weights, fix tests
* fix llama loading
* global fixup
* clean up all
* move llama model
* cleanups
* Revert "cleanups"
This reverts commit a71c5d59eb.
* fine, leave it
107 examples/coder.py (new file)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
import os, sys, traceback
sys.path.append(os.getcwd())

from io import StringIO
from contextlib import redirect_stdout
from tinygrad import Tensor, nn
from tinygrad.helpers import Timing, colored, getenv, fetch
from extra.models.llama import Transformer, convert_from_huggingface
from sentencepiece import SentencePieceProcessor

def create_fixed_tokenizer(output_file):
  print("creating fixed tokenizer")
  import extra.junk.sentencepiece_model_pb2 as spb2
  mp = spb2.ModelProto()
  mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
  mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
  with open(output_file, "wb") as f:
    f.write(mp.SerializeToString())

# TODO: make loading bf16 fast so we can remove this
def create_model_cache(output_file, model):
  print(f"creating model cache at {output_file}")
  # TODO: add read only Tensors
  with Timing("download weights: "):
    part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
    part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))

  with Timing("weights -> model: "):
    nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
    nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)

  with Timing("saving float16 cache: "):
    nn.state.safe_save(nn.state.get_state_dict(model), output_file)

  print("cache created, rerun to use")
  exit(0)

if __name__ == "__main__":
  Tensor.no_grad = True

  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
  with Timing("create model: "):
    model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)

  cached_model = "/tmp/cached_openhermes.safetensors"
  if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
  with Timing("loading float16 cache: "):
    nn.state.load_state_dict(model, nn.state.safe_load(cached_model))

  if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
  spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")

  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
  # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
  IM_END = 32000
  IM_START = 32001
  def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
  def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
  def output(outputted, toks, color):
    cur = spp.decode(toks)[len(outputted):]
    sys.stdout.write(colored(cur, color))
    sys.stdout.flush()
    outputted += cur
    return outputted

  # *** app below this line ***

  toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")

  PROMPT = getenv("PROMPT", 1)
  temperature = getenv("TEMP", 0.7)

  start_pos = 0
  outputted = output("", toks, "green")
  turn = True
  while 1:
    if PROMPT:
      toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
    else:
      toks += start_prompt("user" if turn else "assistant")
      turn = not turn
    old_output_len = len(outputted)
    while 1:
      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
      start_pos = len(toks)
      toks.append(tok)
      outputted = output(outputted, toks, "blue" if not turn else "cyan")
      if tok == IM_END: break
      if tok == spp.eos_id(): break
      new_output = outputted[old_output_len:]

      if new_output.endswith("```") and '```python\n' in new_output:
        python_code = new_output.split('```python\n')[1].split("```")[0]
        # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
        if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
          my_stdout = StringIO()
          try:
            with redirect_stdout(my_stdout): exec(python_code)
            result = my_stdout.getvalue()
          except Exception as e:
            result = ''.join(traceback.format_exception_only(e))
          toks += spp.encode(f"\nOutput:\n```\n{result}```")
          outputted = output(outputted, toks, "yellow")
        old_output_len = len(outputted)
    print("")
examples/gpt2.py (file header missing in the capture; inferred from VOCAB_SIZE = 50257)
@@ -89,7 +89,7 @@ class Transformer:

  # TODO: fix empty token
  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
-    return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT", int(not CI)) else self.forward)(tokens, start_pos, temperature)
+    return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)

VOCAB_SIZE = 50257
MODEL_PARAMS = {
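Since getenv falls back to 0 when no default is given, the change above makes the jitted decode path opt-in: set JIT=1 in the environment to re-enable it. A small illustration (editor's sketch; tinygrad's getenv caches its result per process):

from tinygrad.helpers import getenv

# with no JIT variable in the environment this prints 0 and gpt2 stays on the
# plain forward path; run with JIT=1 to use forward_jit for single-token decode
print(getenv("JIT"))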
examples/llama.py (the model definition moves out to the new extra/models/llama.py)
@@ -7,145 +7,14 @@ from pathlib import Path
import sys, argparse, json
import numpy as np
np.set_printoptions(linewidth=200)
from typing import Optional, Tuple, Union

-from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes, CI
+from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable
from extra.models.llama import Transformer, convert_from_huggingface

MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else 1)

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
  return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
  ro = a*c - b*d
  co = a*d + b*c
  return ro.cat(co, dim=-1)

def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
  c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)

def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

class RMSNorm:
  def __init__(self, dim, eps=1e-6):
    self.eps = eps
    self.weight = Tensor.ones(dim)

  def __call__(self, x:Tensor):
    # TODO: convert to float?
    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    bsz, seqlen, n_heads, head_dim = xq.shape

    # create kv cache
    if not hasattr(self, "cache_k"):
      self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_kv_heads, self.head_dim)

    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

    # update the cache
    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)

    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
    return self.wo(attn)

class FeedForward:
  def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
    # TODO: what is this?
    hidden_dim = int(2 * hidden_dim / 3)
    # custom dim factor multiplier
    if ffn_dim_multiplier is not None:
      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    return self.w2(self.w1(x).silu() * self.w3(x))

class TransformerBlock:
  def __init__(self, dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear=Linear, ffn_dim_multiplier=None):
    self.attention = Attention(dim, n_heads, n_kv_heads, linear)
    self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
    self.attention_norm = RMSNorm(dim, norm_eps)
    self.ffn_norm = RMSNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).realize()

class Transformer:
  def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None, rope_theta=10000):
    self.layers = [TransformerBlock(dim, multiple_of, n_heads, n_kv_heads, norm_eps, linear, ffn_dim_multiplier) for _ in range(n_layers)]
    self.norm = RMSNorm(dim, norm_eps)
    self.tok_embeddings = Embedding(vocab_size, dim)
    self.output = linear(dim, vocab_size, bias=False)
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
    _bsz, seqlen = tokens.shape
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None

    h = self.tok_embeddings(tokens)
    for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h))
    return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
    # TODO: better way to handle the first call v.s. the rest?
    if tokens.shape[0:2] == (1,1) and JIT:
      assert start_pos > 0
      return self.forward_jit(tokens, Variable("start_pos", 1, MAX_CONTEXT).bind(start_pos), temperature)
    return self.forward(tokens, start_pos, temperature)
MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)

# **** files and arguments ****
MODEL_PARAMS = {
@@ -227,6 +96,23 @@ MODEL_PARAMS = {
  }
}

# fix up MODEL_PARAMS to have hidden_dim
for model_gen in MODEL_PARAMS.values():
  for model_type in model_gen.values():
    model_args = model_type['args']
    hidden_dim = model_args['dim'] * 4
    multiple_of = model_args['multiple_of']
    # TODO: what is this?
    hidden_dim = int(2 * hidden_dim / 3)
    # custom dim factor multiplier
    ffn_dim_multiplier = getattr(model_args, 'ffn_dim_multiplier', None)
    if ffn_dim_multiplier is not None:
      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
      del model_args['ffn_dim_multiplier']
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    model_args['hidden_dim'] = hidden_dim
    del model_args['multiple_of']
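Editor's note: this fixup reproduces the SwiGLU sizing that previously lived inside FeedForward. Worked through for the 7B config (dim=4096, multiple_of=256) it lands on the familiar LLaMA-7B FFN width. Also worth flagging: getattr on a plain dict always returns the default (dict keys aren't attributes), so the ffn_dim_multiplier branch above can never fire as written; model_args.get('ffn_dim_multiplier') would be the usual spelling.

# worked example of the sizing above, using the LLaMA-7B values
dim, multiple_of = 4096, 256
hidden_dim = dim * 4                                                        # 16384
hidden_dim = int(2 * hidden_dim / 3)                                        # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up to a multiple of 256
assert hidden_dim == 11008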
# **** helper functions ****
def concat_weights(models):
  def convert(name) -> Tensor:
@@ -248,31 +134,6 @@ def load(fn:str):
  else:
    return torch_load(fn)

def convert_from_huggingface(weights, model: Transformer, n_heads: int, n_kv_heads: int):
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k:
        v = permute(v, n_heads)
      elif "k_proj" in k:
        v = permute(v, n_kv_heads)
    sd[keymap[k]] = v
  return sd

class AbsmaxQuantizedLinear:
  def __init__(self, in_features, out_features, bias=False):
    assert bias == False
@@ -303,8 +164,7 @@ class LLaMa:
    assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {MODEL_PARAMS[model_gen][model_size]['args']['vocab_size']}"

    params = MODEL_PARAMS[model_gen][model_size]
-    model_args = params["args"]
-    model = Transformer(**model_args, linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
+    model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)

    if model_path.is_dir():
      weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
45 extra/junk/sentencepiece_model_pb2.py (new file)
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  _globals['DESCRIPTOR']._options = None
  _globals['DESCRIPTOR']._serialized_options = b'H\003'
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
  _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
  _globals['_TRAINERSPEC']._serialized_start=45
  _globals['_TRAINERSPEC']._serialized_end=1581
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
  _globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
  _globals['_NORMALIZERSPEC']._serialized_start=1584
  _globals['_NORMALIZERSPEC']._serialized_end=1793
  _globals['_SELFTESTDATA']._serialized_start=1795
  _globals['_SELFTESTDATA']._serialized_end=1916
  _globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
  _globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
  _globals['_MODELPROTO']._serialized_start=1919
  _globals['_MODELPROTO']._serialized_end=2429
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
  _globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
  _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
# @@protoc_insertion_point(module_scope)
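This vendored module exists so create_fixed_tokenizer in examples/coder.py can parse and patch the SentencePiece model proto. A quick check of the result (editor's sketch; the /tmp path comes from coder.py):

from sentencepiece import SentencePieceProcessor

spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")  # written by create_fixed_tokenizer
assert spp.vocab_size() == 32002                                 # 32000 Mistral pieces + the 2 appended
assert spp.id_to_piece(32000) == "<|im_end|>"
assert spp.id_to_piece(32001) == "<|im_start|>"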
151 extra/models/llama.py (new file)
@@ -0,0 +1,151 @@
from typing import Tuple, Union, Optional, Dict
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
  return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
  ro = a*c - b*d
  co = a*d + b*c
  return ro.cat(co, dim=-1)
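A quick numeric check (editor's example) that the component form above agrees with ordinary complex multiplication, e.g. (1+2i)(3+4i) = -5+10i:

a, b = 1.0, 2.0   # first factor a+ib
c, d = 3.0, 4.0   # second factor c+id
assert (a*c - b*d, a*d + b*c) == (-5.0, 10.0)
assert complex(a, b) * complex(c, d) == complex(-5.0, 10.0)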
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
  assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
  assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
  c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
  xq_out = complex_mult(xq, c, d)
  xk_out = complex_mult(xk, c, d)
  return xq_out.flatten(3), xk_out.flatten(3)

def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  bs, seqlen, n_kv_heads, head_dim = x.shape
  if n_rep == 1: return x
  return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
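repeat_kv is the grouped-query attention expansion: with the OpenHermes config used in examples/coder.py (n_heads=32, n_kv_heads=8), each cached KV head serves n_rep = 32 // 8 = 4 query heads. A shape-only check (editor's sketch):

from tinygrad import Tensor

x = Tensor.zeros(1, 5, 8, 128)                   # (bs, seqlen, n_kv_heads, head_dim)
assert repeat_kv(x, 4).shape == (1, 5, 32, 128)  # expanded to match the 32 query heads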
class RMSNorm:
  def __init__(self, dim, eps=1e-6):
    self.eps = eps
    self.weight = Tensor.ones(dim)

  def __call__(self, x:Tensor):
    # TODO: convert to float?
    return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    self.head_dim = dim // n_heads
    self.n_rep = self.n_heads // self.n_kv_heads
    self.max_context = max_context

    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    bsz, seqlen, n_heads, head_dim = xq.shape

    # create kv cache
    if not hasattr(self, "cache_k"):
      self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)

    keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
    values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

    # update the cache
    self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
    self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)

    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
    return self.wo(attn)

class FeedForward:
  def __init__(self, dim, hidden_dim, linear=nn.Linear):
    self.w1 = linear(dim, hidden_dim, bias=False)
    self.w2 = linear(hidden_dim, dim, bias=False)
    self.w3 = linear(dim, hidden_dim, bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    return self.w2(self.w1(x).silu() * self.w3(x))

class TransformerBlock:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
    self.feed_forward = FeedForward(dim, hidden_dim, linear)
    self.attention_norm = RMSNorm(dim, norm_eps)
    self.ffn_norm = RMSNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    return (h + self.feed_forward(self.ffn_norm(h))).realize()

class Transformer:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
    self.norm = RMSNorm(dim, norm_eps)
    self.tok_embeddings = nn.Embedding(vocab_size, dim)
    self.output = linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
    self.forward_jit = TinyJit(self.forward) if jit else None

  def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
    _bsz, seqlen = tokens.shape
    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None

    h = self.tok_embeddings(tokens)
    for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h))
    return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
    # TODO: better way to handle the first call v.s. the rest?
    if tokens.shape[0:2] == (1,1) and self.forward_jit:
      assert start_pos > 0
      return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
    return self.forward(tokens, start_pos, temperature)
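Editor's note on the sampling head in forward: dividing the last-position logits by (temperature + 1e-10) before softmax means temperature near 0 collapses (numerically) to greedy argmax, while coder.py's default TEMP=0.7 keeps some diversity. In plain numpy:

import numpy as np

logits = np.array([1.0, 2.0, 3.0])
def probs(t):
  z = logits / (t + 1e-10)
  e = np.exp(z - z.max())  # numerically stable softmax
  return e / e.sum()

print(probs(0.7))  # ~[0.04, 0.18, 0.77]: still exploratory
print(probs(0.0))  # ~[0.0, 0.0, 1.0]: effectively argmax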

# *** helpers ***

def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k:
        v = permute(v, n_heads)
      elif "k_proj" in k:
        v = permute(v, n_kv_heads)
    sd[keymap[k]] = v
  return sd
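A tiny smoke test of the relocated module (editor's sketch; toy hyperparameters, randomly initialized, not a real checkpoint):

from tinygrad import Tensor
from extra.models.llama import Transformer

model = Transformer(dim=64, hidden_dim=128, n_heads=4, n_layers=2, norm_eps=1e-5,
                    vocab_size=100, n_kv_heads=2, max_context=64, jit=False)
probs = model(Tensor([[1, 2, 3]]), start_pos=0)  # 3-token prompt takes the unjitted path
assert probs.shape == (100,)                     # a softmax distribution over the toy vocab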
@@ -93,7 +93,7 @@ class TestAllocators(unittest.TestCase):
    old_type = Tensor.default_type
    Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    def __test():
      model = Transformer(**args_tiny)
      derandomize_model(model)
@@ -105,7 +105,7 @@ class TestAllocators(unittest.TestCase):

  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
  def test_lru_allocator_tiny_llama_alloc_counts(self):
-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    def test_alloc_count(t):
      model = Transformer(**args_tiny)
      for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
2 test/external/external_test_jit_on_models.py (vendored)
@@ -19,7 +19,7 @@ class TestJittedModels(unittest.TestCase):
    old_type = Tensor.default_type
    Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    model = Transformer(**args_tiny)
    derandomize_model(model)
    def test(t): return model(t, 0).realize()
2 test/external/external_test_opt.py (vendored)
@@ -86,7 +86,7 @@ class TestInferenceMinKernels(unittest.TestCase):
  def test_llama(self):
    from examples.llama import Transformer
    from tinygrad.shape.symbolic import Variable
-    args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
    model = Transformer(**args_tiny)
    for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
    with CLCache(100):
@@ -51,7 +51,7 @@ class TestRealWorld(unittest.TestCase):
  def test_llama(self):
    Tensor.default_type = dtypes.float16

-    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
+    args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
    derandomize_model(model)
    @TinyJit
tinygrad/__init__.py (file header missing in the capture; inferred from the package re-exports)
@@ -1,4 +1,8 @@
from tinygrad.tensor import Tensor  # noqa: F401
from tinygrad.jit import TinyJit  # noqa: F401
from tinygrad.shape.symbolic import Variable  # noqa: F401
from tinygrad.helpers import dtypes  # noqa: F401
+
+# NOTE: these should not be relied on to be stable
+from tinygrad.ops import Device  # noqa: F401
+from tinygrad.helpers import GlobalCounters  # noqa: F401
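Editor's note: these package-root re-exports are what let the new files import everything in one line, e.g. the header of extra/models/llama.py above:

from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device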