mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
added top k sampling to examples/mamba (#12061)
This commit is contained in:
@@ -279,9 +279,15 @@ def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: boo
|
||||
# Loading in the prompt tokens
|
||||
logits = model.forward(Tensor([tks]))[:, -1, :]
|
||||
for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
|
||||
# TODO: topk
|
||||
if sample:
|
||||
tok_Tens = (logits/temp).softmax().multinomial()
|
||||
scaled_logits = logits / temp
|
||||
if top_k is not None:
|
||||
topk_values, topk_indices = scaled_logits.topk(top_k)
|
||||
filtered_logits = Tensor.full_like(scaled_logits, -float("inf"))
|
||||
filtered_logits = filtered_logits.scatter(dim=-1, index=topk_indices, src=topk_values)
|
||||
tok_Tens = filtered_logits.softmax().multinomial()
|
||||
else:
|
||||
tok_Tens = scaled_logits.softmax().multinomial()
|
||||
else:
|
||||
tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
|
||||
tok = tok_Tens.item()
|
||||
@@ -298,6 +304,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--size", type=str, default="370m",
|
||||
help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
|
||||
parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
|
||||
parser.add_argument("--top_k", type=int, help="Limit sampling to the top k most likely tokens")
|
||||
parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
|
||||
parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
|
||||
args = parser.parse_args()
|
||||
@@ -308,8 +315,9 @@ if __name__ == "__main__":
|
||||
num_toks = args.n_tokens
|
||||
sample = args.sample
|
||||
temp = args.temp
|
||||
top_k = args.top_k
|
||||
s = time.time()
|
||||
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
|
||||
tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp, top_k=top_k)
|
||||
print(tinyoutput)
|
||||
print('TIME: ', time.time() - s)
|
||||
TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
|
||||
|
||||
Reference in New Issue
Block a user