mirror of https://github.com/tinygrad/tinygrad.git
add olmoe support to llm (#13792)
* add olmoe support to llm
* cleanups
* simpler
* clean
* fix mypy
* lil
* remove dumb assert
@@ -26,5 +26,28 @@ class TestMoEFeedForward(unittest.TestCase):
     expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
     np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
 
+  def test_moe_feed_forward_batched(self):
+    from tinygrad.apps.llm import TransformerBlock
+    dim, hidden, n_heads = 8, 16, 2
+    num_experts, k = 4, 2
+
+    block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
+                             rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
+
+    # same setup as BS=1 test
+    block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
+    block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
+    block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
+    block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T
+    block.ffn_norm.weight = Tensor.ones(dim)
+
+    # test with BS=2, T=3
+    h = Tensor.ones(2, 3, dim)
+    out = block._feed_forward(h)
+
+    # all outputs should match the BS=1 expected value
+    expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
+    np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
+
 if __name__ == '__main__':
   unittest.main()
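Note: the expected value in this test can be derived by hand. With an all-ones input, RMSNorm (with all-ones weight) is the identity, the [1,0,1,0] gating rows give router logits [8, 0, 8, 0], so experts 0 and 2 each take roughly half the softmax mass, and their identity gate/up/down weights reduce the expert outputs to silu(1) and silu(3). A standalone numpy sketch of the arithmetic (illustrative, not part of the diff):

    import numpy as np
    def silu(x): return x / (1 + np.exp(-x))
    # experts 0 and 2 are selected with ~0.5 probability each; their gate
    # multipliers are 1 and 3 (Tensor.eye * (i+1)), up/down are identity
    expected = 1 + (silu(1.0) + silu(3.0)) / 2  # residual 1 + weighted expert mix
    print(round(expected, 4))                   # ~2.7944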
@@ -5,7 +5,7 @@ from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler,
 
 class SimpleTokenizer:
   def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
-    if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
+    if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
     # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
     bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
     self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
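Note: the byte-decoder context line above is the standard GPT-2 byte-to-unicode trick: printable byte values map to themselves and everything else is shifted past 255, so every token string is valid unicode. A quick standalone check (illustrative, not part of the diff):

    bs = [*range(33, 127), *range(161, 173), *range(174, 256)]  # bytes that map to themselves
    byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i, b in enumerate(b for b in range(256) if b not in bs)}
    assert byte_decoder['A'] == 65      # printable bytes are unchanged
    assert byte_decoder[chr(256)] == 0  # byte 0 is the first remapped value
    assert len(byte_decoder) == 256     # every byte gets exactly one character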
@@ -52,9 +52,13 @@ class SimpleTokenizer:
 
   def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
   def role(self, role:str):
+    if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
     if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
     return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
-  def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
+  def end_turn(self, eos_id:int):
+    if self.preset == 'olmo': return self.encode("\n")
+    if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
+    return [eos_id]
 
 @functools.cache
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
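Note: each preset now frames a chat turn differently; notably the 'olmo' end_turn emits only a newline, no EOS token. A sketch of how one user turn would be assembled (`tok` and `eos_id` are assumed to come from the surrounding app code):

    # olmo : <|user|>\n ...text... \n
    # qwen2: <|im_start|>user\n ...text... <eos>\n
    # llama: <|start_header_id|>user<|end_header_id|>\n\n ...text... <eos>
    ids = tok.role("user") + tok.encode("hello") + tok.end_turn(eos_id)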
@@ -67,8 +71,7 @@ class ExpertWeights:
   def __init__(self, num_experts:int, in_features:int, out_features:int):
     self.weight = Tensor.zeros(num_experts, out_features, in_features)
   def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
-    # sel: (T, k), w: (T, k, in, out) -> output: (T, k, out)
-    # x is (T, 1, in) for gate/up (broadcast across experts) or (T, k, in) for down (per-expert)
+    # sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
     return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2)
 
 def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
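Note: the updated comment reflects that ExpertWeights now carries a leading batch dim. The mechanism is a gather followed by a broadcast matmul; a numpy analogue of the shape flow (illustrative sizes, not part of the diff):

    import numpy as np
    E, B, T, k, din, dout = 4, 2, 3, 2, 8, 16
    w, sel = np.ones((E, dout, din)), np.zeros((B, T, k), dtype=int)
    x = np.ones((B, T, 1, din))                  # gate/up case; down uses (B, T, k, din)
    # w[sel]: (B, T, k, dout, din); unsqueeze x, matmul, squeeze back
    out = (x[..., None, :] @ np.swapaxes(w[sel], -1, -2)).squeeze(-2)
    assert out.shape == (B, T, k, dout)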
@@ -79,12 +82,13 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
 
 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
-               max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
+               max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads
     self.head_dim = head_dim
-    self.max_context = max_context
     self.rope_theta = rope_theta
+    self.max_context = max_context
+    self.qk_norm = qk_norm
 
     # --- attention projections (all linear, bias-free) ------------------
     q_proj_out = self.head_dim * n_heads
@@ -97,7 +101,7 @@ class TransformerBlock:
     # --- RMSNorms --------------------------------------------------------
     self.attn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
-    if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
+    if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
 
     # --- feed-forward (MoE or dense) -------------------------------------
     if num_experts > 0:
@@ -114,13 +118,13 @@ class TransformerBlock:
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
+    if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
 
     B, T, _ = x.shape
     q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
     k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
-    if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
+    if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
 
     # TODO: make UOp have SupportsIndex
     freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
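Note: qk_norm changed from a bool to the normalized width, which disambiguates the two QK-norm placements: qwen3-style models normalize per head (qk_norm == head_dim, applied after the reshape), while OLMoE normalizes the whole q/k projection (qk_norm == head_dim * n_heads, applied before it). A condensed restatement of the two branches above (B/T/H/Hd are shorthand for the shapes in the comments):

    q = self.attn_q(x_norm)                             # (B, T, H*Hd)
    if self.qk_norm and self.qk_norm != self.head_dim:  # olmoe: norm over the full projection
      q = self.attn_q_norm(q)
    q = q.reshape(B, T, H, Hd).transpose(1, 2)          # (B, H, T, Hd)
    if self.qk_norm == self.head_dim:                   # qwen3: per-head norm
      q = self.attn_q_norm(q)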
@@ -144,11 +148,10 @@ class TransformerBlock:
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
-      assert h.shape[0] == 1, "only BS=1"
-      x = h_norm.squeeze(0).unsqueeze(1) # (T, 1, D)
-      probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).squeeze(0).topk(self.num_experts_per_tok) # (T, k) each
-      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (T, k, D)
-      return h + (x_down * probs.unsqueeze(-1)).sum(axis=1).unsqueeze(0) # (1, T, D)
+      x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
+      probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
+      x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
+      return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
 
     # TODO: remove the need for this contiguous
     gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
     return h + self.ffn_down(gated)
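Note: the BS=1 assert and the squeeze/unsqueeze gymnastics are gone; the MoE path now keeps the batch dim end to end. Shape trace for reference (D is the model dim, k the experts per token):

    # h      : (B, T, D)     residual stream
    # x      : (B, T, 1, D)  broadcast across the k selected experts for gate/up
    # probs  : (B, T, k)     softmax weight of each selected expert
    # sel    : (B, T, k)     indices of the selected experts
    # x_down : (B, T, k, D)  per-expert FFN outputs
    # output : (B, T, D)     probs-weighted sum over k, plus the residual h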
@@ -158,7 +161,7 @@ class TransformerBlock:
 
 class Transformer:
   def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
-               max_context:int=0, qk_norm:bool=False, num_experts:int=0, num_experts_per_tok:int=0):
+               max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
     self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
                                  num_experts, num_experts_per_tok) for _ in range(num_blocks)]
     self.token_embd = nn.Embedding(vocab_size, dim)
@@ -192,8 +195,8 @@ class Transformer:
     max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
     n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
 
-    # permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
-    if arch not in ('qwen3', 'qwen3moe'):
+    # Permute Q/K weights from interleaved to half-split RoPE layout (llama-style models only)
+    if arch == 'llama':
       for name in state_dict:
         if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
         if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
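Note: the permutation condition flips from a denylist to an allowlist, since olmoe (like qwen3) already stores Q/K in half-split order. The rearrange itself reorders the weight rows of each head from interleaved even/odd pairs to a half-split layout; what it does for one head of size 6 (standalone numpy illustration):

    import numpy as np
    w = np.arange(6)[:, None] * np.ones((6, 4))  # fake (head_dim, dim) weight, row i filled with i
    # "(n h two) d -> (n two h) d" with n=1, h=3, two=2:
    out = w.reshape(1, 3, 2, 4).transpose(0, 2, 1, 3).reshape(6, 4)
    print(out[:, 0])                             # [0. 2. 4. 1. 3. 5.]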
@@ -201,8 +204,10 @@ class Transformer:
     model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
                         hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
                         n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
-                        vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
-                        rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict,
+                        vocab_size=len(kv['tokenizer.ggml.tokens']),
+                        head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
+                        rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
+                        qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
                         num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
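Note: two loader changes make olmoe GGUFs work: head_dim falls back to embedding_length // n_heads (presumably because the olmoe checkpoint has no {arch}.attention.key_length key), and the qk_norm width is read off the checkpoint's own norm weight rather than assumed to equal head_dim. The lookup pattern in brief (kv is the GGUF metadata dict, sd the state dict):

    head_dim = kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads)
    qk_norm = int(sd['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in sd else 0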
@@ -232,6 +237,7 @@ models = {
   "qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
   "qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
   "qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
+  "olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
 }
 
 # *** simple OpenAI compatible server on 11434 to match ollama ***
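Note: once registered in models, the new "olmoe" entry should be reachable through the OpenAI-compatible server on 11434 mentioned in the context line above. A hypothetical client call (the exact endpoint path is an assumption based on the OpenAI convention, not confirmed by this diff):

    import json, urllib.request
    req = urllib.request.Request("http://localhost:11434/v1/chat/completions",  # endpoint path assumed
      data=json.dumps({"model": "olmoe", "messages": [{"role": "user", "content": "hello"}]}).encode(),
      headers={"Content-Type": "application/json"})
    print(urllib.request.urlopen(req).read().decode())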