From 75a6a03664c1f79c2f30751d6a48b6291b682135 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 21 Dec 2025 12:36:02 -0400 Subject: [PATCH] add qwen3 moe support to tinygrad.apps.llm (#13775) * qwen moe works * simple moe * one test * integration --- test/unit/test_llm_moe.py | 30 +++++++++++++++++++++++++ tinygrad/apps/llm.py | 47 ++++++++++++++++++++++++++++++--------- 2 files changed, 66 insertions(+), 11 deletions(-) create mode 100644 test/unit/test_llm_moe.py diff --git a/test/unit/test_llm_moe.py b/test/unit/test_llm_moe.py new file mode 100644 index 0000000000..df55180f22 --- /dev/null +++ b/test/unit/test_llm_moe.py @@ -0,0 +1,30 @@ +import unittest +import numpy as np +from tinygrad import Tensor + +class TestMoEFeedForward(unittest.TestCase): + def test_moe_feed_forward(self): + from tinygrad.apps.llm import TransformerBlock + dim, hidden, n_heads = 8, 16, 2 + num_experts, k = 4, 2 + + block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads, + rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k) + + # set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2 + block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)]) + block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)]) + block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)]) + block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T # router strongly prefers experts 0 and 2 + block.ffn_norm.weight = Tensor.ones(dim) # identity norm + + # input of ones -> after norm still ~ones -> experts 0,2 selected -> weighted sum of silu outputs + h = Tensor.ones(1, 1, dim) + out = block._feed_forward(h) + + # expected: residual + moe_output ≈ 1 + avg(silu(1), silu(3)) + expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2 + np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index bc07934396..342c3bcbf5 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -62,6 +62,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor: freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0) return freqs.cos().cat(freqs.sin(), dim=-1).contiguous() +class ExpertWeights: + """Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features).""" + def __init__(self, num_experts:int, in_features:int, out_features:int): + self.weight = Tensor.zeros(num_experts, out_features, in_features) + def __call__(self, sel:Tensor, x:Tensor) -> Tensor: + # sel: (T, k), w: (T, k, in, out) -> output: (T, k, out) + # x is (T, 1, in) for gate/up (broadcast across experts) or (T, k, in) for down (per-expert) + return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2) + def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor: assert x.shape[-1] % 2 == 0 cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1) @@ -70,7 +79,7 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor: class TransformerBlock: def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float, - max_context:int=0, qk_norm:bool=False): + max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0): self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim @@ -90,10 +99,17 @@ class TransformerBlock: self.ffn_norm = nn.RMSNorm(dim, norm_eps) if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps) - # --- feed-forward ---------------------------------------------------- - self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False) - self.ffn_up = nn.Linear(dim, hidden_dim, bias=False) - self.ffn_down = nn.Linear(hidden_dim, dim, bias=False) + # --- feed-forward (MoE or dense) ------------------------------------- + if num_experts > 0: + self.num_experts_per_tok = num_experts_per_tok + self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router + self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim) + self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim) + self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim) + else: + self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False) + self.ffn_up = nn.Linear(dim, hidden_dim, bias=False) + self.ffn_down = nn.Linear(hidden_dim, dim, bias=False) def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: x_norm = self.attn_norm(x) # (B,T,D) @@ -127,6 +143,12 @@ class TransformerBlock: def _feed_forward(self, h: Tensor) -> Tensor: h_norm = self.ffn_norm(h) + if hasattr(self, 'ffn_gate_exps'): + assert h.shape[0] == 1, "only BS=1" + x = h_norm.squeeze(0).unsqueeze(1) # (T, 1, D) + probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).squeeze(0).topk(self.num_experts_per_tok) # (T, k) each + x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (T, k, D) + return h + (x_down * probs.unsqueeze(-1)).sum(axis=1).unsqueeze(0) # (1, T, D) # TODO: remove the need for this contiguous gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm) return h + self.ffn_down(gated) @@ -136,9 +158,9 @@ class TransformerBlock: class Transformer: def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float, - max_context:int=0, qk_norm:bool=False): - self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm) - for _ in range(num_blocks)] + max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0): + self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm, + num_experts, num_experts_per_tok) for _ in range(num_blocks)] self.token_embd = nn.Embedding(vocab_size, dim) self.output_norm = nn.RMSNorm(dim, norm_eps) self.output = nn.Linear(dim, vocab_size, bias=False) @@ -171,15 +193,17 @@ class Transformer: n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv'] # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...] - if arch != 'qwen3': + if arch not in ('qwen3', 'qwen3moe'): for name in state_dict: if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2) if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2) - model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'], + model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], + hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']), n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'], - rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict) + rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict, + num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0)) nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous()) @@ -207,6 +231,7 @@ models = { "qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf", "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf", "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf", + "qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf", } # *** simple OpenAI compatible server on 11434 to match ollama ***