add qwen3 moe support to tinygrad.apps.llm (#13775)
* qwen moe works
* simple moe
* one test
* integration
test/unit/test_llm_moe.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import unittest
+import numpy as np
+from tinygrad import Tensor
+
+class TestMoEFeedForward(unittest.TestCase):
+  def test_moe_feed_forward(self):
+    from tinygrad.apps.llm import TransformerBlock
+    dim, hidden, n_heads = 8, 16, 2
+    num_experts, k = 4, 2
+
+    block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
+                             rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
+
+    # set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2
+    block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
+    block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
+    block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
+    block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T  # router strongly prefers experts 0 and 2
+    block.ffn_norm.weight = Tensor.ones(dim)  # identity norm
+
+    # input of ones -> after norm still ~ones -> experts 0,2 selected -> weighted sum of silu outputs
+    h = Tensor.ones(1, 1, dim)
+    out = block._feed_forward(h)
+
+    # expected: residual + moe_output ≈ 1 + avg(silu(1), silu(3))
+    expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
+    np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
+
+if __name__ == '__main__':
+  unittest.main()
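For reference, the expected value asserted above works out to roughly 2.79. A plain-numpy restatement of the same arithmetic (silu(x) = x * sigmoid(x)):

import numpy as np
silu = lambda x: x / (1 + np.exp(-x))            # silu(x) = x * sigmoid(x)
# experts 0 and 2 see the normalized all-ones input scaled by 1 and 3 respectively,
# each selected with router probability ~0.5, plus the residual of 1
expected = 1 + 0.5 * silu(1.0) + 0.5 * silu(3.0)
print(round(float(expected), 3))                 # ~2.794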
tinygrad/apps/llm.py

@@ -62,6 +62,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
 
+class ExpertWeights:
+  """Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features)."""
+  def __init__(self, num_experts:int, in_features:int, out_features:int):
+    self.weight = Tensor.zeros(num_experts, out_features, in_features)
+  def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
+    # sel: (T, k), w: (T, k, in, out) -> output: (T, k, out)
+    # x is (T, 1, in) for gate/up (broadcast across experts) or (T, k, in) for down (per-expert)
+    return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2)
+
 def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
   assert x.shape[-1] % 2 == 0
   cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
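The gather-then-batched-matmul in ExpertWeights.__call__ may be easier to follow in isolation. A minimal numpy sketch of the same shape flow (sizes and names here are illustrative, not taken from the repo):

import numpy as np
T, k, num_experts, d_in, d_out = 3, 2, 4, 8, 16
W = np.random.randn(num_experts, d_out, d_in)    # plays the role of ExpertWeights.weight
sel = np.random.randint(0, num_experts, (T, k))  # top-k expert ids per token
x = np.random.randn(T, 1, d_in)                  # token activations, broadcast over k
# gather per-token expert weights, then one batched matmul (same math as __call__)
out = (x[..., None, :] @ W[sel].transpose(0, 1, 3, 2)).squeeze(-2)  # (T, k, d_out)
# reference: apply each selected expert's weight to its token explicitly
ref = np.stack([[W[sel[t, j]] @ x[t, 0] for j in range(k)] for t in range(T)])
assert np.allclose(out, ref)

Indexing the (num_experts, out, in) weight tensor with the (T, k) selection pulls one weight matrix per selected expert, so a single batched matmul covers every token and each of its top-k experts at once.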
@@ -70,7 +79,7 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
 
 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
-               max_context:int=0, qk_norm:bool=False):
+               max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads
     self.head_dim = head_dim
@@ -90,10 +99,17 @@ class TransformerBlock:
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
 
-    # --- feed-forward ----------------------------------------------------
-    self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
-    self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
-    self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
+    # --- feed-forward (MoE or dense) -------------------------------------
+    if num_experts > 0:
+      self.num_experts_per_tok = num_experts_per_tok
+      self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False)  # router
+      self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
+      self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
+      self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
+    else:
+      self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
+      self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
+      self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
 
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x)  # (B,T,D)
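To see what the two constructor branches create, here is a small sketch using the toy sizes from the unit test above (dim=8, hidden=16, 4 experts); the printed shapes follow from ExpertWeights and nn.Linear as defined in this diff:

from tinygrad.apps.llm import TransformerBlock

# MoE block: per-expert weights carry a leading num_experts axis
moe = TransformerBlock(8, 16, 2, 2, norm_eps=1e-5, head_dim=4, rope_theta=10000, max_context=16,
                       num_experts=4, num_experts_per_tok=2)
print(moe.ffn_gate_exps.weight.shape)  # (4, 16, 8)
print(moe.ffn_gate_inp.weight.shape)   # (4, 8) -> router logits over 4 experts

# dense block: num_experts=0 keeps the original ffn_gate/up/down Linears
dense = TransformerBlock(8, 16, 2, 2, norm_eps=1e-5, head_dim=4, rope_theta=10000, max_context=16)
print(dense.ffn_gate.weight.shape)     # (16, 8)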
@@ -127,6 +143,12 @@ class TransformerBlock:
 
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
+    if hasattr(self, 'ffn_gate_exps'):
+      assert h.shape[0] == 1, "only BS=1"
+      x = h_norm.squeeze(0).unsqueeze(1)  # (T, 1, D)
+      probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).squeeze(0).topk(self.num_experts_per_tok)  # (T, k) each
+      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x))  # (T, k, D)
+      return h + (x_down * probs.unsqueeze(-1)).sum(axis=1).unsqueeze(0)  # (1, T, D)
     # TODO: remove the need for this contiguous
     gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
     return h + self.ffn_down(gated)
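The routing itself is a softmax over expert logits followed by top-k, with the selected softmax values used directly as mixing weights. A standalone numpy sketch, using roughly the logits the unit test's router produces for its all-ones input:

import numpy as np
logits = np.array([8.0, 0.0, 8.0, 0.0])         # hypothetical ffn_gate_inp output for one token
probs = np.exp(logits) / np.exp(logits).sum()   # softmax over all experts
sel = np.argsort(-probs)[:2]                    # indices of the top-2 experts -> [0, 2]
topk_probs = probs[sel]                         # ~[0.4998, 0.4998]
# the MoE output is sum_j topk_probs[j] * expert_j(x); in the code path above the
# selected probabilities are taken as-is, not renormalized to sum to 1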
@@ -136,9 +158,9 @@ class TransformerBlock:
 
 class Transformer:
   def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
-               max_context:int=0, qk_norm:bool=False):
-    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm)
-                for _ in range(num_blocks)]
+               max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
+    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
+                                 num_experts, num_experts_per_tok) for _ in range(num_blocks)]
     self.token_embd = nn.Embedding(vocab_size, dim)
     self.output_norm = nn.RMSNorm(dim, norm_eps)
     self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -171,15 +193,17 @@ class Transformer:
     n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
 
     # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
-    if arch != 'qwen3':
+    if arch not in ('qwen3', 'qwen3moe'):
       for name in state_dict:
         if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
         if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
 
-    model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
+    model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
+                        hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
                         n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
                         vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
-                        rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict)
+                        rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict,
+                        num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
     for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
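The expert hyperparameters come from GGUF metadata, with kv.get falling back to the dense values so non-MoE checkpoints still load. A sketch of the lookups above with hypothetical placeholder values (not read from a real file):

# illustrative metadata for a 'qwen3moe' GGUF; key names follow the code above, values are placeholders
kv = {"qwen3moe.expert_count": 128, "qwen3moe.expert_used_count": 8,
      "qwen3moe.expert_feed_forward_length": 768, "qwen3moe.feed_forward_length": 6144}
arch = "qwen3moe"
num_experts = kv.get(f"{arch}.expert_count", 0)   # 0 for dense checkpoints -> ffn_gate/up/down path
hidden_dim = kv.get(f"{arch}.expert_feed_forward_length", kv[f"{arch}.feed_forward_length"])  # per-expert width when MoE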
@@ -207,6 +231,7 @@ models = {
   "qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
   "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
   "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
+  "qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
 }
 
 # *** simple OpenAI compatible server on 11434 to match ollama ***
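Since the server is OpenAI compatible on port 11434, a request against it could look like the sketch below; the /v1/chat/completions route and the response shape are assumed from standard OpenAI conventions rather than confirmed by this diff:

import json, urllib.request

# assumes the server is running locally; "qwen3:30b-a3b" is the model name added to the table above
req = urllib.request.Request(
  "http://localhost:11434/v1/chat/completions",
  data=json.dumps({"model": "qwen3:30b-a3b",
                   "messages": [{"role": "user", "content": "hello"}]}).encode(),
  headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as r:
  print(json.load(r)["choices"][0]["message"]["content"])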