olmoe (from stream, wip) (#9390)
* olmoest working (but not)
* it's correct
* compare ropes
* old code wasn't wrong
* default device
* no metal
* fix permute
* working
* more minimal
examples/olmoe.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# https://arxiv.org/pdf/2409.02060
import numpy as np
np.set_printoptions(suppress=True, linewidth=1000)
import functools, collections, json
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import tqdm, CI, Profiling, Timing, fetch, getenv
from extra.models.llama import Transformer, Variable, convert_from_huggingface

class MixtureFeedForward:
  def __init__(self, num_experts:int, activated_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
    self.activated_experts = activated_experts
    self.gate = nn.Linear(dim, num_experts, bias=False)
    self.up_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
    self.down_proj = Tensor.zeros(num_experts, dim, hidden_dim, dtype='bfloat16')
    self.gate_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
  def __call__(self, x:Tensor) -> Tensor:
    assert x.shape[0] == 1, "only BS=1"
    assert x.shape[1] == 1, "only length=1"
    g = self.gate(x).float().softmax(-1)

    # TODO: don't go to CPU here
    choice = g.data().tolist()[0][0]
    top = sorted(enumerate(choice), key=lambda x: -x[1])[:self.activated_experts]
    sel, probs = Tensor([x[0] for x in top]), Tensor([x[1] for x in top])
    #print(sel.numpy(), probs.numpy())

    # run MoE
    x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
    x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
    return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)

# model is bf16, 1.3B active, 6.9B total
# M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s

def fetch_weights() -> dict[str, Tensor]:
  # TODO: make this lazy so the 3 fetches can happen in parallel
  m1 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00001-of-00003.safetensors").to(Device.DEFAULT)
  m2 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00002-of-00003.safetensors").to(Device.DEFAULT)
  m3 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00003-of-00003.safetensors").to(Device.DEFAULT)
  return {**nn.state.safe_load(m1), **nn.state.safe_load(m2), **nn.state.safe_load(m3)}

if __name__ == "__main__":
  if getenv("TORCH"):
    from transformers import OlmoeForCausalLM, AutoTokenizer
    model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
    inputs = tokenizer("Hello", return_tensors="pt")
    generate_ids = model.generate(inputs.input_ids, max_length=30)
    out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print(out)
    exit(0)

  with Timing("create model: "):
    model = Transformer(n_layers=16, dim=2048, hidden_dim=1024, n_heads=16, norm_eps=1e-5, qk_norm=1e-5, max_context=1024,
                        vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8), jit=False)
    model_state_dict = nn.state.get_state_dict(model)
    del model_state_dict['freqs_cis']

  with Timing("fetch and load weights: "):
    state = fetch_weights()
    nhf_state = convert_from_huggingface(state, model, 16, 16)
    # NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
    for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
    nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
    assert len(nhf_state) == 0

  count = 30
  temperature = 0

  with Timing("load tokenizer: "):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

  toks = [12092]
  start_pos = 0
  for i in range(count):
    tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), temperature).item()
    toks.append(tok)
    start_pos += 1
  print(toks)
  print(tokenizer.decode(toks))

  # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
  assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755,
                  247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"
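The routing in MixtureFeedForward above boils down to: softmax the gate logits over the 64 experts, keep the 8 largest probabilities, run only those experts' SwiGLU feed-forwards, and sum their outputs weighted by the (un-renormalized) gate probabilities. The bandwidth comment follows from the same numbers: ~1.3B active bf16 parameters are ~2.6 GB touched per token, so a 400 GB/s M3 Max tops out around 154 tok/s. A standalone numpy sketch of that routing with toy sizes (nothing below is from the commit):

# toy top-k MoE routing, mirroring MixtureFeedForward.__call__ with made-up sizes
import numpy as np
rng = np.random.default_rng(0)
num_experts, k, dim, hidden = 4, 2, 8, 16
x = rng.standard_normal(dim)
gate_w = rng.standard_normal((num_experts, dim))             # gate: dim -> num_experts
gate_proj = rng.standard_normal((num_experts, hidden, dim))  # per-expert SwiGLU weights
up_proj = rng.standard_normal((num_experts, hidden, dim))
down_proj = rng.standard_normal((num_experts, dim, hidden))
silu = lambda t: t / (1.0 + np.exp(-t))
g = np.exp(gate_w @ x); g /= g.sum()                         # softmax over experts
sel = np.argsort(-g)[:k]                                     # indices of the top-k experts
out = sum(g[e] * (down_proj[e] @ (silu(gate_proj[e] @ x) * (up_proj[e] @ x))) for e in sel)
print(out.shape)  # (dim,) -- weighted sum of the selected experts' outputs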
extra/models/llama.py
@@ -1,6 +1,7 @@
from typing import Union, Optional, Any
import collections
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, DEBUG

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -8,6 +9,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)

# matches meta, non hugging face weights
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a,b = A[..., 0:1], A[..., 1:2]
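The added comment spells out the rotation identity complex_mult implements; a quick numeric check of it (not part of the commit):

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc): rotating the pair (a,b) by angle t when c=cos(t), d=sin(t)
import numpy as np
a, b, t = 0.3, -1.2, 0.5
c, d = np.cos(t), np.sin(t)
lhs = (a + 1j*b) * (c + 1j*d)
rhs = complex(a*c - b*d, a*d + b*c)
assert np.isclose(lhs, rhs)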
@@ -32,7 +34,7 @@ def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear, qk_norm:float|None=None):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
@@ -44,6 +46,9 @@ class Attention:
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

    self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
    self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
    if getenv("WQKV"):
      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
@@ -52,6 +57,10 @@ class Attention:
    else:
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    if self.q_norm is not None and self.k_norm is not None:
      xq = self.q_norm(xq)
      xk = self.k_norm(xk)

    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
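The qk_norm path added here applies RMSNorm to the full q and k projections before they are split into heads and rotated. A minimal sketch of that ordering outside the class, using only tinygrad pieces already used in this file (toy sizes are made up):

# standalone sketch of QK-norm ordering: project, RMSNorm, then reshape into heads
from tinygrad import Tensor, nn

dim, n_heads, seqlen = 8, 2, 3
x = Tensor.randn(1, seqlen, dim)                       # (batch, seqlen, dim)
wq, wk = nn.Linear(dim, dim, bias=False), nn.Linear(dim, dim, bias=False)
q_norm, k_norm = nn.RMSNorm(dim, 1e-5), nn.RMSNorm(dim, 1e-5)
xq, xk = q_norm(wq(x)), k_norm(wk(x))                  # normalize before the head reshape / RoPE
xq = xq.reshape(1, seqlen, n_heads, dim // n_heads)
xk = xk.reshape(1, seqlen, n_heads, dim // n_heads)
print(xq.shape, xk.shape)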
@@ -89,8 +98,9 @@ class FeedForward:
    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]

class TransformerBlock:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear,
               feed_forward=FeedForward, qk_norm=None):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
@@ -151,8 +161,9 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  return output_token

class Transformer:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding,
               n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
@@ -181,30 +192,39 @@ class Transformer:

# *** helpers ***

# TODO: model shouldn't be an input here, and n_kv_heads should support None
def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
  # huggingface stores Q and K permuted! it is mostly correct without this, but without it makes RoPE different, so it will diverge after 10+ toks.
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  experts = collections.defaultdict(dict)
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k and permute_layers:
        v = permute(v, n_heads)
      elif "k_proj" in k and permute_layers:
        v = permute(v, n_kv_heads)
      if ("q_proj" in k or "q_norm" in k) and permute_layers: v = permute(v, n_heads)
      elif ("k_proj" in k or "k_norm" in k) and permute_layers: v = permute(v, n_kv_heads)
      if '.mlp.experts.' in k:
        # support MoE models
        _, _, layer, _, _, expert, name, _ = k.split('.')
        experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
        continue
    sd[keymap[k]] = v
  for k,v in experts.items(): sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])
  return sd

def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
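On the permute fix: as the comment in the hunk above says, HuggingFace stores the Q and K projections permuted relative to the meta layout, and skipping the un-permute makes RoPE subtly different so generation diverges after 10+ tokens; the `if len(v.shape) > 1 else 1` branch extends the same un-permute to the new 1-D q_norm/k_norm weights. A numpy sketch of what permute() does to a 2-D weight, with toy sizes (not from the commit):

import numpy as np

n_heads, head_dim, dim = 2, 4, 8
w = np.arange(n_heads * head_dim * dim).reshape(n_heads * head_dim, dim)
# same reshape/transpose as permute(): split each head's rows into two rotary halves,
# swap the "half" axis with the "position within half" axis, then flatten back
# (tinygrad's transpose(1, 2) swaps axes 1 and 2, i.e. numpy swapaxes)
unpermuted = w.reshape(n_heads, 2, head_dim // 2, dim).swapaxes(1, 2).reshape(n_heads * head_dim, dim)
print(unpermuted.shape)  # (8, 8): rows reordered within each head, values unchanged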
tinygrad/nn/state.py
@@ -124,7 +124,7 @@ def get_parameters(obj) -> list[Tensor]:
  """
  return list(get_state_dict(obj).values())

def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
  """
  Loads a state_dict into a model.

@@ -140,7 +140,8 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -152,9 +153,10 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
      else: v.replace(state_dict[k].to(v.device)).realize()
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]

@accept_filename
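The new realize flag lets a caller defer realization of the replaced weights instead of forcing one realize per tensor inside the loop; olmoe.py above passes realize=False, presumably so the weights realize later in one go or on first use. A minimal usage sketch (the one-shot Tensor.realize batching call is an assumed pattern, not something this diff adds):

from tinygrad import Tensor, nn

class Tiny:
  def __init__(self): self.w = Tensor.zeros(4, 4)

model = Tiny()
nn.state.load_state_dict(model, {"w": Tensor.ones(4, 4)}, verbose=False, realize=False)  # no copy forced here
Tensor.realize(*nn.state.get_parameters(model))  # realize all replaced weights together (assumed pattern)
print(model.w.numpy())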