added top k sampling to examples/mamba (#12061)
@@ -279,9 +279,15 @@ def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: boo
   # Loading in the prompt tokens
   logits = model.forward(Tensor([tks]))[:, -1, :]
   for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
-    # TODO: topk
     if sample:
-      tok_Tens = (logits/temp).softmax().multinomial()
+      scaled_logits = logits / temp
+      if top_k is not None:
+        topk_values, topk_indices = scaled_logits.topk(top_k)
+        filtered_logits = Tensor.full_like(scaled_logits, -float("inf"))
+        filtered_logits = filtered_logits.scatter(dim=-1, index=topk_indices, src=topk_values)
+        tok_Tens = filtered_logits.softmax().multinomial()
+      else:
+        tok_Tens = scaled_logits.softmax().multinomial()
     else:
       tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
     tok = tok_Tens.item()
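For reference, here is the same technique in isolation: a minimal, runnable sketch of temperature scaling plus top-k filtering, written in plain NumPy rather than tinygrad. The function name and the threshold-based masking are illustrative assumptions, not the tinygrad API; the diff above achieves the same effect with topk/scatter.

import numpy as np

def sample_top_k(logits: np.ndarray, temp: float = 1.0, top_k: int | None = None) -> int:
  # illustrative sketch only, not the tinygrad implementation above
  scaled = logits / temp
  if top_k is not None:
    # keep the k largest logits; push the rest to -inf so softmax gives them
    # zero probability (ties at the k-th value may keep a few extra tokens)
    kth = np.sort(scaled)[-top_k]
    scaled = np.where(scaled >= kth, scaled, -np.inf)
  probs = np.exp(scaled - scaled.max())  # numerically stable softmax
  probs /= probs.sum()
  return int(np.random.choice(len(probs), p=probs))

print(sample_top_k(np.array([2.0, 1.0, 0.5, -1.0]), temp=0.8, top_k=2))  # always index 0 or 1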
@@ -298,6 +304,7 @@ if __name__ == "__main__":
   parser.add_argument("--size", type=str, default="370m",
                       help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]")
   parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate")
+  parser.add_argument("--top_k", type=int, help="Limit sampling to the top k most likely tokens")
   parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag")
   parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0")
   args = parser.parse_args()
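With the new flag wired in, a run using only the options visible in this diff might look like the line below; the value 40 is just an example, and --sample must be passed since the top-k filter only applies on the sampling path (greedy argmax decoding ignores it):

python3 examples/mamba.py --size 370m --n_tokens 10 --sample --temp 0.8 --top_k 40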
@@ -308,8 +315,9 @@ if __name__ == "__main__":
   num_toks = args.n_tokens
   sample = args.sample
   temp = args.temp
+  top_k = args.top_k
   s = time.time()
-  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp)
+  tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp, top_k=top_k)
   print(tinyoutput)
   print('TIME: ', time.time() - s)
   TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only"
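A note on the design: scattering the top-k values into a tensor pre-filled with -inf means softmax assigns exactly zero probability to every token outside the top k, so multinomial can never draw them, while the surviving k logits are renormalized among themselves. When --top_k is omitted, args.top_k is None (argparse's default for an optional argument with no default), which preserves the previous behavior of sampling over the full distribution.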