remove .float calls in olmoe (#11610)

Results still match torch.
This commit is contained in:
chenyu
2025-08-10 17:33:22 -07:00
committed by GitHub
parent a67e0917c3
commit 630edcffd8

View File

@@ -1,8 +1,7 @@
# https://arxiv.org/pdf/2409.02060 # https://arxiv.org/pdf/2409.02060
import time import time, functools
import numpy as np import numpy as np
np.set_printoptions(suppress=True, linewidth=1000) np.set_printoptions(suppress=True, linewidth=1000)
import functools
from tinygrad import Tensor, nn, Device, GlobalCounters from tinygrad import Tensor, nn, Device, GlobalCounters
from tinygrad.helpers import Timing, getenv from tinygrad.helpers import Timing, getenv
from extra.models.llama import Transformer, convert_from_huggingface from extra.models.llama import Transformer, convert_from_huggingface
@@ -17,7 +16,7 @@ class MixtureFeedForward:
def __call__(self, x:Tensor) -> Tensor: def __call__(self, x:Tensor) -> Tensor:
assert x.shape[0] == 1, "only BS=1" assert x.shape[0] == 1, "only BS=1"
assert x.shape[1] == 1, "only length=1" assert x.shape[1] == 1, "only length=1"
g = self.gate(x).float().softmax(-1) g = self.gate(x).softmax(-1)
g = g.squeeze() # (BS, length, num_experts) -> (num_experts,) g = g.squeeze() # (BS, length, num_experts) -> (num_experts,)
probs, sel = g.topk(self.activated_experts) probs, sel = g.topk(self.activated_experts)
@@ -25,7 +24,7 @@ class MixtureFeedForward:
# run MoE # run MoE
x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1)) x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1)) x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0) return (x_down * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
# model is bf16, 1.3B active, 6.9B total # model is bf16, 1.3B active, 6.9B total
# M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s # M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s