* jit the forward

* might timeout, idk just send it

* this is dumb

* naive bitonic lol

* idk if this is correct, but that squeeze before is definitly not

* vectorized bitonic sort, but still slow

* yay 1 layer is correct

* alright its pretty good

* good enough

* rerun CI

* nit improve comment
This commit is contained in:
geohotstan
2025-03-19 02:49:02 +08:00
committed by GitHub
parent 5c56cac0a0
commit f7506c6c25

View File

@@ -18,11 +18,8 @@ class MixtureFeedForward:
assert x.shape[1] == 1, "only length=1"
g = self.gate(x).float().softmax(-1)
# TODO: don't go to CPU here
choice = g.data().tolist()[0][0]
top = sorted(enumerate(choice), key=lambda x: -x[1])[:self.activated_experts]
sel, probs = Tensor([x[0] for x in top]), Tensor([x[1] for x in top])
#print(sel.numpy(), probs.numpy())
g = g.squeeze() # (BS, length, num_experts) -> (num_experts,)
probs, sel = g.topk(self.activated_experts)
# run MoE
x_up_gate = x.dot(self.gate_proj[sel].permute(0,2,1)).silu() * x.dot(self.up_proj[sel].permute(0,2,1))
@@ -52,7 +49,7 @@ if __name__ == "__main__":
with Timing("create model: "):
model = Transformer(n_layers=16, dim=2048, hidden_dim=1024, n_heads=16, norm_eps=1e-5, qk_norm=1e-5, max_context=1024,
vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8), jit=False)
vocab_size=50304, feed_forward=functools.partial(MixtureFeedForward, 64, 8))
model_state_dict = nn.state.get_state_dict(model)
del model_state_dict['freqs_cis']
@@ -74,7 +71,7 @@ if __name__ == "__main__":
toks = [12092]
start_pos = 0
for i in range(count):
tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), temperature).item()
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
toks.append(tok)
start_pos += 1
print(toks)