From 25b1bc8effe248b12de36f66c8d7caac761536fd Mon Sep 17 00:00:00 2001 From: Steven Shi <60683228+Steven-Yiran@users.noreply.github.com> Date: Sun, 14 Sep 2025 15:27:34 -0400 Subject: [PATCH] added top k sampling to examples/mamba (#12061) --- examples/mamba.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/mamba.py b/examples/mamba.py index d6093eabf5..d309807484 100644 --- a/examples/mamba.py +++ b/examples/mamba.py @@ -279,9 +279,15 @@ def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: boo # Loading in the prompt tokens logits = model.forward(Tensor([tks]))[:, -1, :] for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"): - # TODO: topk if sample: - tok_Tens = (logits/temp).softmax().multinomial() + scaled_logits = logits / temp + if top_k is not None: + topk_values, topk_indices = scaled_logits.topk(top_k) + filtered_logits = Tensor.full_like(scaled_logits, -float("inf")) + filtered_logits = filtered_logits.scatter(dim=-1, index=topk_indices, src=topk_values) + tok_Tens = filtered_logits.softmax().multinomial() + else: + tok_Tens = scaled_logits.softmax().multinomial() else: tok_Tens = logits.argmax(axis=-1).unsqueeze(0) tok = tok_Tens.item() @@ -298,6 +304,7 @@ if __name__ == "__main__": parser.add_argument("--size", type=str, default="370m", help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]") parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate") + parser.add_argument("--top_k", type=int, help="Limit sampling to the top k most likely tokens") parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag") parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0") args = parser.parse_args() @@ -308,8 +315,9 @@ if __name__ == "__main__": num_toks = args.n_tokens sample = args.sample temp = args.temp + top_k = args.top_k s = time.time() - tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp) + tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp, top_k=top_k) print(tinyoutput) print('TIME: ', time.time() - s) TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"