mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
JIT OLMoE (#9396)
* jit the forward * might timeout, idk just send it * this is dumb * naive bitonic lol * idk if this is correct, but that squeeze before is definitly not * vectorized bitonic sort, but still slow * yay 1 layer is correct * alright its pretty good * good enough * rerun CI * nit improve comment
This commit is contained in:
@@ -18,11 +18,8 @@ class MixtureFeedForward:
|
||||
assert x.shape[1] == 1, "only length=1"
|
||||
g = self.gate(x).float().softmax(-1)
|
||||
|
||||
# TODO: don't go to CPU here
|
||||
choice = g.data().tolist()[0][0]
|
||||
top = sorted(enumerate(choice), key=lambda x: -x[1])[:self.activated_experts]
|
||||
sel, probs = Tensor([x[0] for x in top]), Tensor([x[1] for x in top])
|
||||
#print(sel.numpy(), probs.numpy())
|
||||
g = g.squeeze() # (BS, length, num_experts) -> (num_experts,)
|
||||
probs, sel = g.topk(self.activated_experts)
|
||||
|
||||
# run MoE
|
||||
x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
|
||||
@@ -52,7 +49,7 @@ if __name__ == "__main__":
|
||||
|
||||
with Timing("create model: "):
|
||||
model = Transformer(n_layers=16, dim=2048, hidden_dim=1024, n_heads=16, norm_eps=1e-5, qk_norm=1e-5, max_context=1024,
|
||||
vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8), jit=False)
|
||||
vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8))
|
||||
model_state_dict = nn.state.get_state_dict(model)
|
||||
del model_state_dict['freqs_cis']
|
||||
|
||||
@@ -74,7 +71,7 @@ if __name__ == "__main__":
|
||||
toks = [12092]
|
||||
start_pos = 0
|
||||
for i in range(count):
|
||||
tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), temperature).item()
|
||||
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
|
||||
toks.append(tok)
|
||||
start_pos += 1
|
||||
print(toks)
|
||||
|
||||
Reference in New Issue
Block a user