diff --git a/examples/olmoe.py b/examples/olmoe.py
new file mode 100644
index 0000000000..b7bd383669
--- /dev/null
+++ b/examples/olmoe.py
@@ -0,0 +1,86 @@
+# https://arxiv.org/pdf/2409.02060
+import numpy as np
+np.set_printoptions(suppress=True, linewidth=1000)
+import functools, collections, json
+from tinygrad import Tensor, nn, Device
+from tinygrad.helpers import tqdm, CI, Profiling, Timing, fetch, getenv
+from extra.models.llama import Transformer, Variable, convert_from_huggingface
+
+class MixtureFeedForward:
+  def __init__(self, num_experts:int, activated_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
+    self.activated_experts = activated_experts
+    self.gate = nn.Linear(dim, num_experts, bias=False)
+    self.up_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
+    self.down_proj = Tensor.zeros(num_experts, dim, hidden_dim, dtype='bfloat16')
+    self.gate_proj = Tensor.zeros(num_experts, hidden_dim, dim, dtype='bfloat16')
+  def __call__(self, x:Tensor) -> Tensor:
+    assert x.shape[0] == 1, "only BS=1"
+    assert x.shape[1] == 1, "only length=1"
+    g = self.gate(x).float().softmax(-1)
+
+    # TODO: don't go to CPU here
+    choice = g.data().tolist()[0][0]
+    top = sorted(enumerate(choice), key=lambda x: -x[1])[:self.activated_experts]
+    sel, probs = Tensor([x[0] for x in top]), Tensor([x[1] for x in top])
+    #print(sel.numpy(), probs.numpy())
+
+    # run MoE
+    x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
+    x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
+    return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
+
+# model is bf16, 1.3B active, 6.9B total
+# M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s
+
+def fetch_weights() -> dict[str, Tensor]:
+  # TODO: make this lazy so the 3 fetches can happen in parallel
+  m1 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00001-of-00003.safetensors").to(Device.DEFAULT)
+  m2 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00002-of-00003.safetensors").to(Device.DEFAULT)
+  m3 = Tensor.from_url("https://huggingface.co/allenai/OLMoE-1B-7B-0924/resolve/main/model-00003-of-00003.safetensors").to(Device.DEFAULT)
+  return {**nn.state.safe_load(m1), **nn.state.safe_load(m2), **nn.state.safe_load(m3)}
+
+if __name__ == "__main__":
+  if getenv("TORCH"):
+    from transformers import OlmoeForCausalLM, AutoTokenizer
+    model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
+    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
+    inputs = tokenizer("Hello", return_tensors="pt")
+    generate_ids = model.generate(inputs.input_ids, max_length=30)
+    out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    print(out)
+    exit(0)
+
+  with Timing("create model: "):
+    model = Transformer(n_layers=16, dim=2048, hidden_dim=1024, n_heads=16, norm_eps=1e-5, qk_norm=1e-5, max_context=1024,
+                        vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8), jit=False)
+    model_state_dict = nn.state.get_state_dict(model)
+    del model_state_dict['freqs_cis']
+
+  with Timing("fetch and load weights: "):
+    state = fetch_weights()
+    nhf_state = convert_from_huggingface(state, model, 16, 16)
+    # NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
+    for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
+    nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
+    assert len(nhf_state) == 0
+
+  count = 30
+  temperature = 0
+
+  with Timing("load tokenizer: "):
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
+
+  toks = [12092]
+  start_pos = 0
+  for i in range(count):
+    tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), temperature).item()
+    toks.append(tok)
+    start_pos += 1
+  print(toks)
+  print(tokenizer.decode(toks))
+
+  # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
+  assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755,
+                  247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"
+
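For context on the example above: MixtureFeedForward.__call__ takes a softmax over all experts, keeps the activated_experts most probable ones, and sums their SwiGLU outputs weighted by the (unrenormalized) gate probabilities. A minimal numpy sketch of that math, with made-up toy sizes rather than OLMoE's 64 experts / 8 active:

```python
# toy sketch of the top-k MoE gating done in MixtureFeedForward above; sizes here are made up
import numpy as np

def silu(z): return z / (1 + np.exp(-z))

def moe_forward(x, gate_w, gate_proj, up_proj, down_proj, k):
  # x: (dim,), gate_w: (experts, dim), gate_proj/up_proj: (experts, hidden, dim), down_proj: (experts, dim, hidden)
  logits = gate_w @ x
  probs = np.exp(logits - logits.max()); probs /= probs.sum()  # softmax over all experts
  sel = np.argsort(-probs)[:k]                                 # indices of the k most probable experts
  out = np.zeros_like(x)
  for i in sel:
    h = silu(gate_proj[i] @ x) * (up_proj[i] @ x)              # SwiGLU hidden activation of expert i
    out += probs[i] * (down_proj[i] @ h)                       # weight by the expert's gate probability
  return out

rng = np.random.default_rng(0)
dim, hidden, experts, k = 4, 6, 8, 2
y = moe_forward(rng.normal(size=dim), rng.normal(size=(experts, dim)),
                rng.normal(size=(experts, hidden, dim)), rng.normal(size=(experts, hidden, dim)),
                rng.normal(size=(experts, dim, hidden)), k)
print(y.shape)  # (4,)
```

The example above does the same thing without the Python loop: indexing the stacked weights with sel yields (k, hidden, dim) tensors, so the per-expert matmuls run as one batched dot.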
diff --git a/extra/models/llama.py b/extra/models/llama.py
index d7db5f63ad..0638888efa 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -1,6 +1,7 @@
 from typing import Union, Optional, Any
+import collections
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, DEBUG

 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -8,6 +9,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).reshape(1, end, 1, dim//2, 2)

+# matches meta, non hugging face weights
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
   a,b = A[..., 0:1], A[..., 1:2]
@@ -32,7 +34,7 @@ def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

 class Attention:
-  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
+  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear, qk_norm:float|None=None):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
@@ -44,6 +46,9 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)

+    self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
+    self.k_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
+
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
@@ -52,6 +57,10 @@ class Attention:
     else:
       xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

+    if self.q_norm is not None and self.k_norm is not None:
+      xq = self.q_norm(xq)
+      xk = self.k_norm(xk)
+
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
     xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
@@ -89,8 +98,9 @@ class FeedForward:
     return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]

 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
-    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear,
+               feed_forward=FeedForward, qk_norm=None):
+    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear, qk_norm)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
@@ -151,8 +161,9 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   return output_token

 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding,
+               n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward, qk_norm=None):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward, qk_norm=qk_norm) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
@@ -181,30 +192,39 @@ class Transformer:

 # *** helpers ***

+# TODO: model shouldn't be an input here, and n_kv_heads should support None
 def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
+  # huggingface stores Q and K permuted! it is mostly correct without this, but without it makes RoPE different, so it will diverge after 10+ toks.
   def permute(v: Tensor, n_heads: int):
-    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
+    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1] if len(v.shape) > 1 else 1).transpose(1, 2).reshape(*v.shape[:2])

   keymap = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
     **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_norm.weight": f"layers.{l}.attention.{x}_norm.weight" for x in ["q", "k"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.self_attn.{x}_proj.bias": f"layers.{l}.attention.w{x}.bias" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
     **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
     **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.gate.weight": f"layers.{l}.feed_forward.gate.weight" for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
   }
   sd = {}
+  experts = collections.defaultdict(dict)
   for k, v in weights.items():
     if ".rotary_emb." in k: continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      if "q_proj" in k and permute_layers:
-        v = permute(v, n_heads)
-      elif "k_proj" in k and permute_layers:
-        v = permute(v, n_kv_heads)
+      if ("q_proj" in k or "q_norm" in k) and permute_layers: v = permute(v, n_heads)
+      elif ("k_proj" in k or "k_norm" in k) and permute_layers: v = permute(v, n_kv_heads)
+      if '.mlp.experts.' in k:
+        # support MoE models
+        _, _, layer, _, _, expert, name, _ = k.split('.')
+        experts[f'layers.{layer}.feed_forward.{name}'][int(expert)] = v
+        continue
     sd[keymap[k]] = v
+  for k,v in experts.items(): sd[k] = Tensor.stack(*[v[i] for i in range(len(v))])
   return sd

 def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
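One note on the convert_from_huggingface change above: the MoE branch groups the per-expert model.layers.{l}.mlp.experts.{e}.*_proj.weight tensors and stacks them into a single (num_experts, ...) weight per layer, which is the layout MixtureFeedForward expects. A self-contained sketch of that grouping using numpy stand-ins (the key names mirror the HF layout handled above; the shapes are toy values):

```python
import collections
import numpy as np

# toy per-expert weights keyed like the HF checkpoint: model.layers.{l}.mlp.experts.{e}.up_proj.weight
weights = {f"model.layers.0.mlp.experts.{e}.up_proj.weight": np.full((3, 2), float(e)) for e in range(4)}

experts = collections.defaultdict(dict)
for k, v in weights.items():
  _, _, layer, _, _, expert, name, _ = k.split('.')   # same split as the MoE branch above
  experts[f"layers.{layer}.feed_forward.{name}"][int(expert)] = v

# stack experts 0..N-1 into one (num_experts, hidden, dim) tensor per target key
sd = {k: np.stack([v[i] for i in range(len(v))]) for k, v in experts.items()}
print(sd["layers.0.feed_forward.up_proj"].shape)  # (4, 3, 2)
```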
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 024ae6aba4..12c56214cc 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -124,7 +124,7 @@ def get_parameters(obj) -> list[Tensor]:
   """
   return list(get_state_dict(obj).values())

-def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
+def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
   """
   Loads a state_dict into a model.

@@ -140,7 +140,8 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
   ```
   """
   start_mem_used = GlobalCounters.mem_used
-  with Timing("loaded weights in ", lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s"):
+  with Timing("loaded weights in ",
+              lambda et_ns: f", {(B:=(GlobalCounters.mem_used-start_mem_used))/1e9:.2f} GB loaded at {B/et_ns:.2f} GB/s", enabled=verbose):
     model_state_dict = get_state_dict(model)
     if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
       print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
@@ -152,9 +153,10 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
       if v.shape != state_dict[k].shape:
         raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
       if isinstance(v.device, tuple):
-        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
-        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
-      else: v.replace(state_dict[k].to(v.device)).realize()
+        if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k])
+        else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis))
+      else: v.replace(state_dict[k].to(v.device))
+      if realize: v.realize()
       if consume: del state_dict[k]

 @accept_filename
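Finally, on the load_state_dict change: with realize=False the v.replace(...) assignments are left lazy instead of being realized one tensor at a time, and the caller (as in the OLMoE example above) decides when the copies actually run. A minimal usage sketch with a made-up toy model, not OLMoE:

```python
# sketch of the new realize flag on load_state_dict; Tiny is a made-up toy model
from tinygrad import Tensor, nn

class Tiny:
  def __init__(self): self.l1 = nn.Linear(4, 4, bias=False)

model = Tiny()
state = {"l1.weight": Tensor.ones(4, 4)}
nn.state.load_state_dict(model, state, verbose=False, realize=False)  # assignments stay lazy here
print(model.l1(Tensor.ones(1, 4)).numpy())  # first use triggers the pending weight copy
```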