add olmoe support to llm (#13792)

* add olmoe support to llm

* cleanups

* simpler

* clean

* fix mypy

* lil

* remove dumb assert
This commit is contained in:
George Hotz
2025-12-22 10:41:35 -04:00
committed by GitHub
parent 81d9053013
commit df0f9d6860
2 changed files with 48 additions and 19 deletions

View File

@@ -26,5 +26,28 @@ class TestMoEFeedForward(unittest.TestCase):
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
def test_moe_feed_forward_batched(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
# same setup as BS=1 test
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T
block.ffn_norm.weight = Tensor.ones(dim)
# test with BS=2, T=3
h = Tensor.ones(2, 3, dim)
out = block._feed_forward(h)
# all outputs should match the BS=1 expected value
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
if __name__ == '__main__':
unittest.main()