added top k sampling to examples/mamba (#12061)

Author: Steven Shi
Date: 2025-09-14 15:27:34 -04:00
Committed by: GitHub
Parent: 34a05b31fe
Commit: 25b1bc8eff

@@ -279,9 +279,15 @@ def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: boo
   # Loading in the prompt tokens
   logits = model.forward(Tensor([tks]))[:, -1, :]
   for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
-    # TODO: topk
     if sample:
-      tok_Tens = (logits/temp).softmax().multinomial()
+      scaled_logits = logits / temp
+      if top_k is not None:
+        topk_values, topk_indices = scaled_logits.topk(top_k)
+        filtered_logits = Tensor.full_like(scaled_logits, -float("inf"))
+        filtered_logits = filtered_logits.scatter(dim=-1, index=topk_indices, src=topk_values)
+        tok_Tens = filtered_logits.softmax().multinomial()
+      else:
+        tok_Tens = scaled_logits.softmax().multinomial()
     else:
       tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
     tok = tok_Tens.item()
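
Note: the new branch keeps only the top_k largest logits and masks the rest to -inf before the softmax, so excluded tokens receive exactly zero probability. A minimal standalone NumPy sketch of the same idea (the helper name top_k_sample, the rng argument, and the toy logits are illustrative, not part of the commit; ties at the cutoff may keep a few extra tokens, which is fine for a sketch):

import numpy as np

def top_k_sample(logits: np.ndarray, temp: float, top_k: int, rng=None) -> int:
  # Temperature-scale, keep the top_k largest logits, mask the rest to -inf,
  # then sample from the resulting softmax distribution.
  rng = rng or np.random.default_rng()
  scaled = logits / temp
  kth = np.sort(scaled)[-top_k]              # k-th largest scaled logit
  filtered = np.where(scaled >= kth, scaled, -np.inf)
  probs = np.exp(filtered - filtered.max())  # numerically stable softmax
  probs /= probs.sum()
  return int(rng.choice(len(logits), p=probs))

print(top_k_sample(np.array([2.0, 1.0, 0.5, -1.0]), temp=0.8, top_k=2))  # prints 0 or 1
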
@@ -298,6 +304,7 @@ if __name__ == "__main__":
parser.add_argument("--size", type=str, default="370m",
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
parser.add_argument("--top_k", type=int, help="Limit sampling to the top k most likely tokens")
parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
args = parser.parse_args()
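
With the flag wired through below, a typical invocation might look like this (the script path and the flag values are illustrative, and --top_k only takes effect together with --sample, since the greedy argmax branch ignores it):

python3 examples/mamba.py --size 370m --n_tokens 50 --sample --temp 0.8 --top_k 40
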
@@ -308,8 +315,9 @@ if __name__ == "__main__":
   num_toks = args.n_tokens
   sample = args.sample
   temp = args.temp
+  top_k = args.top_k
   s = time.time()
-  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
+  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp, top_k=top_k)
   print(tinyoutput)
   print('TIME: ', time.time() - s)
   TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"