remove .float calls in olmoe (#11610)

still matches torch
chenyu
2025-08-10 17:33:22 -07:00
committed by GitHub
parent a67e0917c3
commit 630edcffd8
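The "still matches torch" claim is straightforward to spot-check: softmax applied directly to bf16 gate logits lands within bf16 tolerance of torch's bf16 softmax. A minimal sketch, assuming a backend with bfloat16 support (shapes and tolerance are illustrative, not from the commit):

import numpy as np
import torch
from tinygrad import Tensor, dtypes

logits = np.random.randn(1, 1, 64).astype(np.float32)  # 64 experts, as in OLMoE
# softmax directly on the bf16 tensor, i.e. without the removed .float() upcast
tg = Tensor(logits).cast(dtypes.bfloat16).softmax(-1).float().numpy()
th = torch.from_numpy(logits).to(torch.bfloat16).softmax(-1).float().numpy()
np.testing.assert_allclose(tg, th, atol=1e-2)  # bf16-level agreement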


@@ -1,8 +1,7 @@
 # https://arxiv.org/pdf/2409.02060
-import time
+import time, functools
 import numpy as np
 np.set_printoptions(suppress=True, linewidth=1000)
-import functools
 from tinygrad import Tensor, nn, Device, GlobalCounters
 from tinygrad.helpers import Timing, getenv
 from extra.models.llama import Transformer, convert_from_huggingface
@@ -17,7 +16,7 @@ class MixtureFeedForward:
   def __call__(self, x:Tensor) -> Tensor:
     assert x.shape[0] == 1, "only BS=1"
     assert x.shape[1] == 1, "only length=1"
-    g = self.gate(x).float().softmax(-1)
+    g = self.gate(x).softmax(-1)
     g = g.squeeze() # (BS, length, num_experts) -> (num_experts,)
     probs, sel = g.topk(self.activated_experts)
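For context, the gating path this hunk touches reduces to a softmax over per-expert logits followed by a top-k selection. A standalone sketch with OLMoE-like sizes (64 experts, 8 activated; the input is hypothetical):

from tinygrad import Tensor

g = Tensor.randn(1, 1, 64).softmax(-1)  # (BS, length, num_experts)
g = g.squeeze()                         # -> (num_experts,)
probs, sel = g.topk(8)                  # weights and indices of the 8 activated experts
print(probs.numpy(), sel.numpy())       # note: the top-k weights are not renormalized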
@@ -25,7 +24,7 @@ class MixtureFeedForward:
     # run MoE
     x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
     x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
-    return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
+    return (x_down * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)

 # model is bf16, 1.3B active, 6.9B total
 # M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s
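The second removed .float() sat on the expert mix itself: a batched matmul over the selected experts followed by a probability-weighted sum. A reduced sketch with made-up shapes standing in for the real projections:

from tinygrad import Tensor

k, dim, hidden = 8, 16, 32                            # hypothetical sizes
probs = Tensor.rand(k).softmax(-1)                    # stand-in for the top-k gate weights
x_up_gate = Tensor.randn(k, 1, hidden)                # stand-in for the SwiGLU activations
down_proj = Tensor.randn(k, dim, hidden)              # stand-in for self.down_proj[sel]
x_down = x_up_gate.dot(down_proj.permute(0, 2, 1))    # (k, 1, dim)
out = (x_down * probs.reshape(k, 1, 1)).sum(axis=0)   # weighted mix -> (1, dim)
print(out.shape)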
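The tok/s figure in the trailing comment is a memory-bandwidth bound: 1.3B active parameters at 2 bytes each (bf16) means 2.6 GB of weights read per token, and 400 GB/s divided by that gives roughly 154 tok/s:

active_params = 1.3e9                 # active parameters per token (of 6.9B total)
bytes_per_param = 2                   # bf16
bandwidth = 400e9                     # M3 Max memory bandwidth, bytes/s
print(bandwidth / (active_params * bytes_per_param))  # ~153.8 tok/s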