add olmoe support to llm (#13792)

* add olmoe support to llm

* cleanups

* simpler

* clean

* fix mypy

* lil

* remove dumb assert
This commit is contained in:
George Hotz
2025-12-22 10:41:35 -04:00
committed by GitHub
parent 81d9053013
commit df0f9d6860
2 changed files with 48 additions and 19 deletions

View File

@@ -26,5 +26,28 @@ class TestMoEFeedForward(unittest.TestCase):
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
def test_moe_feed_forward_batched(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
# same setup as BS=1 test
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T
block.ffn_norm.weight = Tensor.ones(dim)
# test with BS=2, T=3
h = Tensor.ones(2, 3, dim)
out = block._feed_forward(h)
# all outputs should match the BS=1 expected value
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
if __name__ == '__main__':
unittest.main()

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler,
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
@@ -52,9 +52,13 @@ class SimpleTokenizer:
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
def role(self, role:str):
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
def end_turn(self, eos_id:int):
if self.preset == 'olmo': return self.encode("\n")
if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
return [eos_id]
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -67,8 +71,7 @@ class ExpertWeights:
def __init__(self, num_experts:int, in_features:int, out_features:int):
self.weight = Tensor.zeros(num_experts, out_features, in_features)
def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
# sel: (T, k), w: (T, k, in, out) -> output: (T, k, out)
# x is (T, 1, in) for gate/up (broadcast across experts) or (T, k, in) for down (per-expert)
# sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2)
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
@@ -79,12 +82,13 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.max_context = max_context
self.rope_theta = rope_theta
self.max_context = max_context
self.qk_norm = qk_norm
# --- attention projections (all linear, bias-free) ------------------
q_proj_out = self.head_dim * n_heads
@@ -97,7 +101,7 @@ class TransformerBlock:
# --- RMSNorms --------------------------------------------------------
self.attn_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
# --- feed-forward (MoE or dense) -------------------------------------
if num_experts > 0:
@@ -114,13 +118,13 @@ class TransformerBlock:
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
B, T, _ = x.shape
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
@@ -144,11 +148,10 @@ class TransformerBlock:
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):
assert h.shape[0] == 1, "only BS=1"
x = h_norm.squeeze(0).unsqueeze(1) # (T, 1, D)
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).squeeze(0).topk(self.num_experts_per_tok) # (T, k) each
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=1).unsqueeze(0) # (1, T, D)
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
# TODO: remove the need for this contiguous
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
return h + self.ffn_down(gated)
@@ -158,7 +161,7 @@ class TransformerBlock:
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
num_experts, num_experts_per_tok) for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
@@ -192,8 +195,8 @@ class Transformer:
max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
# permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
if arch not in ('qwen3', 'qwen3moe'):
# Permute Q/K weights from interleaved to half-split RoPE layout (llama-style models only)
if arch == 'llama':
for name in state_dict:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
@@ -201,8 +204,10 @@ class Transformer:
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict,
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
@@ -232,6 +237,7 @@ models = {
"qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
"qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
"qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
"olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
}
# *** simple OpenAI compatible server on 11434 to match ollama ***