diff --git a/examples/olmoe.py b/examples/olmoe.py index 9789cb1cef..3f216ab8dd 100644 --- a/examples/olmoe.py +++ b/examples/olmoe.py @@ -1,8 +1,7 @@ # https://arxiv.org/pdf/2409.02060 -import time +import time, functools import numpy as np np.set_printoptions(suppress=True, linewidth=1000) -import functools from tinygrad import Tensor, nn, Device, GlobalCounters from tinygrad.helpers import Timing, getenv from extra.models.llama import Transformer, convert_from_huggingface @@ -17,7 +16,7 @@ class MixtureFeedForward: def __call__(self, x:Tensor) -> Tensor: assert x.shape[0] == 1, "only BS=1" assert x.shape[1] == 1, "only length=1" - g = self.gate(x).float().softmax(-1) + g = self.gate(x).softmax(-1) g = g.squeeze() # (BS, length, num_experts) -> (num_experts,) probs, sel = g.topk(self.activated_experts) @@ -25,7 +24,7 @@ class MixtureFeedForward: # run MoE x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1)) x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1)) - return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0) + return (x_down * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0) # model is bf16, 1.3B active, 6.9B total # M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s