From 79f4627fbcf012da45c7fca47c1660808d8e03bf Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 14 Jan 2024 13:10:01 -0500 Subject: [PATCH] fix conversation: llama generates token not prob now (#3120) --- examples/conversation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/conversation.py b/examples/conversation.py index 750829c672..e3ff7645cf 100644 --- a/examples/conversation.py +++ b/examples/conversation.py @@ -79,8 +79,7 @@ def llama_generate( outputted = llama.tokenizer.decode(toks) init_length = len(outputted) for _ in range(max_tokens): - probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy() - token = int(np.random.choice(len(probs_np), p=probs_np)) + token = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).item() start_pos = len(toks) toks.append(token)