olmoe (from stream, wip) (#9390)

* olmoest working (but not)

* it's correct

* compare ropes

* old code wasn't wrong

* default device

* no metal

* fix permute

* working

* more minimal
George Hotz
2025-03-10 13:46:33 +08:00
committed by GitHub
parent 1d64c12f2b
commit 25847080f0
3 changed files with 124 additions and 16 deletions

examples/olmoe.py Normal file

@@ -0,0 +1,86 @@
# https://arxiv.org/pdf/2409.02060
import numpy as np
np.set_printoptions(suppress=True, linewidth=1000)
import functools, collections, json
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import tqdm, CI, Profiling, Timing, fetch, getenv
from extra.models.llama import Transformer, Variable, convert_from_huggingface
class MixtureFeedForward:
  def __init__(self, num_experts:int, activated_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
    self.activated_experts = activated_experts
    self.gate = nn.Linear(dim, num_experts, bias=False)
    self.up_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
    self.down_proj = Tensor.zeros(num_experts, dim, hidden_dim, dtype='bfloat16')
    self.gate_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
  def __call__(self, x:Tensor) -> Tensor:
    assert x.shape[0] == 1, "only BS=1"
    assert x.shape[1] == 1, "only length=1"
    g = self.gate(x).float().softmax(-1)
    # TODO: don't go to CPU here
    choice = g.data().tolist()[0][0]
    top = sorted(enumerate(choice), key=lambda x: -x[1])[:self.activated_experts]
    sel, probs = Tensor([x[0] for x in top]), Tensor([x[1] for x in top])
    #print(sel.numpy(), probs.numpy())
    # run MoE
    x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
    x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
    return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
# model is bf16, 1.3B active, 6.9B total
# 1.3B active params * 2 bytes (bf16) = ~2.6 GB read per token; M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s
def fetch_weights() -> dict[str, Tensor]:
  # TODO: make this lazy so the 3 fetches can happen in parallel
  m1 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00001-of-00003.safetensors").to(Device.DEFAULT)
  m2 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00002-of-00003.safetensors").to(Device.DEFAULT)
  m3 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00003-of-00003.safetensors").to(Device.DEFAULT)
  return {**nn.state.safe_load(m1), **nn.state.safe_load(m2), **nn.state.safe_load(m3)}
if __name__ == "__main__":
  if getenv("TORCH"):
    from transformers import OlmoeForCausalLM, AutoTokenizer
    model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
    inputs = tokenizer("Hello", return_tensors="pt")
    generate_ids = model.generate(inputs.input_ids, max_length=30)
    out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print(out)
    exit(0)

  with Timing("create model: "):
    model = Transformer(n_layers=16, dim=2048, hidden_dim=1024, n_heads=16, norm_eps=1e-5, qk_norm=1e-5, max_context=1024,
                        vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8), jit=False)
    model_state_dict = nn.state.get_state_dict(model)
    del model_state_dict['freqs_cis']

  with Timing("fetch and load weights: "):
    state = fetch_weights()
    nhf_state = convert_from_huggingface(state, model, 16, 16)
    # NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
    for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
    nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
    assert len(nhf_state) == 0

  count = 30
  temperature = 0

  with Timing("load tokenizer: "):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

  toks = [12092]
  start_pos = 0
  for i in range(count):
    tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), temperature).item()
    toks.append(tok)
    start_pos += 1

  print(toks)
  print(tokenizer.decode(toks))
  # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
  assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755,
                  247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"
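For reference, the MoE block above is the standard top-k gated mixture: a softmax over the 64 expert gates, keep the 8 most probable experts, run each selected SwiGLU expert, and sum their outputs weighted by the raw (un-renormalized) gate probabilities, which is also why the code pulls the gate values to the CPU just to pick the top-k. A minimal numpy sketch of that math for a single token (illustration only, not part of the commit):

```python
import numpy as np

def silu(z): return z / (1.0 + np.exp(-z))

def moe_token(x, gate_w, gate_proj, up_proj, down_proj, k=8):
  # x: (dim,), gate_w: (num_experts, dim)
  # gate_proj/up_proj: (num_experts, hidden_dim, dim), down_proj: (num_experts, dim, hidden_dim)
  logits = gate_w @ x
  probs = np.exp(logits - logits.max()); probs /= probs.sum()  # softmax over all experts
  sel = np.argsort(-probs)[:k]                                 # top-k experts, like the sorted() above
  out = np.zeros_like(x)
  for e in sel:
    h = silu(gate_proj[e] @ x) * (up_proj[e] @ x)              # SwiGLU expert
    out += probs[e] * (down_proj[e] @ h)                       # weight by raw gate prob (no renormalization)
  return out

# tiny toy shapes just to show it runs; the real model uses dim=2048, hidden_dim=1024, 64 experts, k=8
rng = np.random.default_rng(0)
d, h, E = 8, 16, 4
y = moe_token(rng.standard_normal(d), rng.standard_normal((E, d)),
              rng.standard_normal((E, h, d)), rng.standard_normal((E, h, d)), rng.standard_normal((E, d, h)), k=2)
print(y.shape)  # (8,)
```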

extra/models/llama.py

@@ -1,6 +1,7 @@
from typing import Union, Optional, Any
import collections
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, DEBUG
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -8,6 +9,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
  return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)

# matches meta, non hugging face weights
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
  a,b = A[..., 0:1], A[..., 1:2]
@@ -32,7 +34,7 @@ def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

class Attention:
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear, qk_norm:float|None=None):
    self.n_heads = n_heads
    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
    self.head_dim = dim // n_heads
@@ -44,6 +46,9 @@ class Attention:
    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
    self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
    self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None

  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
    if getenv("WQKV"):
      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
@@ -52,6 +57,10 @@ class Attention:
    else:
      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    if self.q_norm is not None and self.k_norm is not None:
      xq = self.q_norm(xq)
      xk = self.k_norm(xk)

    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
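The q_norm/k_norm added in this hunk are the QK-normalization OLMoE trains with: a plain RMSNorm over the full query/key projection width (dim), applied before the heads are split out and RoPE is applied. For reference, `nn.RMSNorm(dim, qk_norm)` computes the following, with eps=qk_norm (numpy sketch, not from this diff):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
  # scale by the reciprocal root-mean-square of x along the last axis; no mean subtraction or bias, unlike LayerNorm
  return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

x, w = np.random.randn(4, 16), np.ones(16)
print(rms_norm(x, w).shape)  # (4, 16)
```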
@@ -89,8 +98,9 @@ class FeedForward:
    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]

class TransformerBlock:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear,
               feed_forward=FeedForward, qk_norm=None):
    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm)
    self.feed_forward = feed_forward(dim, hidden_dim, linear)
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
@@ -151,8 +161,9 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
  return output_token

class Transformer:
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding,
               n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None):
    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
    self.norm = nn.RMSNorm(dim, norm_eps)
    self.tok_embeddings = embedding(vocab_size, dim)
    self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
@@ -181,30 +192,39 @@ class Transformer:
# *** helpers ***

# TODO: model shouldn't be an input here, and n_kv_heads should support None
def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
  # huggingface stores Q and K permuted! it is mostly correct without this, but without it makes RoPE different, so it will diverge after 10+ toks.
  def permute(v: Tensor, n_heads: int):
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])

  keymap = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
    **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(len(model.layers))},
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
  }
  sd = {}
  experts = collections.defaultdict(dict)
  for k, v in weights.items():
    if ".rotary_emb." in k: continue
    v = v.to(Device.DEFAULT)
    if "model.layers" in k:
      if "q_proj" in k and permute_layers:
        v = permute(v, n_heads)
      elif "k_proj" in k and permute_layers:
        v = permute(v, n_kv_heads)
      if ("q_proj" in k or "q_norm" in k) and permute_layers: v = permute(v, n_heads)
      elif ("k_proj" in k or "k_norm" in k) and permute_layers: v = permute(v, n_kv_heads)
    if '.mlp.experts.' in k:
      # support MoE models
      _, _, layer, _, _, expert, name, _ = k.split('.')
      experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
      continue
    sd[keymap[k]] = v
  for k,v in experts.items(): sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])
  return sd

def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
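The permute change is the "fix permute" from the commit message: the helper previously indexed v.shape[1] and so assumed 2-D projection matrices, which fails on the new 1-D q_norm/k_norm weights that need the same reordering as the q/k rows they scale. What the reordering does, on a toy head (numpy sketch, not part of the commit): Hugging Face checkpoints store each head's rotary dimensions split as [first half | second half], while the Meta-style RoPE in this file expects them interleaved in adjacent (even, odd) pairs, and the reshape/transpose converts from the former to the latter.

```python
import numpy as np

def permute(v, n_heads):
  # same reshape/swap as the helper above; numpy's swapaxes(1, 2) plays the role of Tensor.transpose(1, 2)
  return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).swapaxes(1, 2).reshape(*v.shape[:2])

rows = np.arange(8).reshape(8, 1)        # one head, head_dim=8, one input column
print(permute(rows, n_heads=1).ravel())  # [0 4 1 5 2 6 3 7]: HF's (k, k+4) halves land next to each other as interleaved pairs
```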

tinygrad/nn/state.py

@@ -124,7 +124,7 @@ def get_parameters(obj) -> list[Tensor]:
"""
return list(get_state_dict(obj).values())
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
"""
Loads a state_dict into a model.
@@ -140,7 +140,8 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
  with Timing("loaded weights in ",
              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -152,9 +153,10 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
      if v.shape != state_dict[k].shape:
        raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
      if isinstance(v.device, tuple):
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
      else: v.replace(state_dict[k].to(v.device)).realize()
        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
      else: v.replace(state_dict[k].to(v.device))
      if realize: v.realize()
      if consume: del state_dict[k]

@accept_filename
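The new realize flag is what lets examples/olmoe.py above pass realize=False: v.replace(...) only swaps the lazy buffer in, nothing is copied to the device inside load_state_dict, and the caller realizes the weights later (or simply runs the model and lets the scheduler handle it). A minimal sketch of the pattern with a toy model (the class and key names below are illustrative, not from the commit):

```python
from tinygrad import Tensor, nn

class Tiny:
  def __init__(self): self.l1 = nn.Linear(4, 4, bias=False)

model, sd = Tiny(), {"l1.weight": Tensor.ones(4, 4)}

# realize=False: the model's weight tensors are replaced but not yet materialized on the device
nn.state.load_state_dict(model, sd, verbose=False, realize=False)

# realize everything in one batch afterwards (or skip this and let the first forward pass schedule it)
Tensor.realize(*nn.state.get_parameters(model))
```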