mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
@@ -1,8 +1,7 @@
|
||||
# https://arxiv.org/pdf/2409.02060
|
||||
import time
|
||||
import time, functools
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True, linewidth=1000)
|
||||
import functools
|
||||
from tinygrad import Tensor, nn, Device, GlobalCounters
|
||||
from tinygrad.helpers import Timing, getenv
|
||||
from extra.models.llama import Transformer, convert_from_huggingface
|
||||
@@ -17,7 +16,7 @@ class MixtureFeedForward:
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
assert x.shape[0] == 1, "only BS=1"
|
||||
assert x.shape[1] == 1, "only length=1"
|
||||
g = self.gate(x).float().softmax(-1)
|
||||
g = self.gate(x).softmax(-1)
|
||||
|
||||
g = g.squeeze() # (BS, length, num_experts) -> (num_experts,)
|
||||
probs, sel = g.topk(self.activated_experts)
|
||||
@@ -25,7 +24,7 @@ class MixtureFeedForward:
|
||||
# run MoE
|
||||
x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
|
||||
x_down = x_up_gate.dot(self.down_proj[sel].permute(0,2,1))
|
||||
return (x_down.float() * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
|
||||
return (x_down * probs.reshape(self.activated_experts, 1, 1)).sum(axis=0)
|
||||
|
||||
# model is bf16, 1.3B active, 6.9B total
|
||||
# M3 Max is 400 GB/s, so 400/2.6 = ~154 tok/s
|
||||
|
||||
Reference in New Issue
Block a user