diff --git a/evaluate.py b/evaluate.py
index 0b7736a..e6984e9 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -63,8 +63,22 @@ def main():
     with tqdm(total=len(dataset), desc="Accuracy") as pbar:
         for batch_start in range(0, len(dataset), batch_size):
             batch = dataset[batch_start:batch_start + batch_size]
-            prompts = [ example.split(split)[0] + split for example in batch["text"] ]
-            expected_responses = [ example.split(split)[1] for example in batch["text"] ]
+            if "text" in batch:
+                prompts = [ example.split(split)[0] + split for example in batch["text"] ]
+                expected_responses = [ example.split(split)[1] for example in batch["text"] ]
+            else:
+                prompts = []
+                expected_responses = []
+                for example in batch:
+                    conversation = [ { "role": x["from"], "content": x["value"] } for x in example["conversations"] if x["from"] != "assistant" ]
+                    prompts.append(trained_tokenizer.apply_chat_template(
+                        conversation=conversation,
+                        max_length=CTX_SIZE,
+                        truncation=True,
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    ))
+                    expected_responses.append([x["value"] for x in example["conversations"] if x["from"] == "assistant"][0])
             output = generate(trained_model, trained_tokenizer, prompts)
             for model_output, expected_response in zip(output, expected_responses):
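
Note for reviewers: a minimal, self-contained sketch of the two record shapes the new branch distinguishes. The field names ("text", "conversations", "from", "value") come from the diff; the example strings and the split marker value here are hypothetical, and in the real code the chat-style prompt is rendered by trained_tokenizer.apply_chat_template rather than by hand.

# Illustrative data only, not from the repo.
split = "### Response:\n"  # hypothetical marker; the real value is defined in evaluate.py

# Shape 1: a plain-text record, prompt and answer joined by the split marker.
text_record = {"text": "### Instruction:\nAdd 2 and 3.\n" + split + "5"}
prompt = text_record["text"].split(split)[0] + split   # everything up to and including the marker
expected = text_record["text"].split(split)[1]         # the reference completion

# Shape 2: a ShareGPT-style record with role-tagged turns.
chat_record = {
    "conversations": [
        {"from": "user", "value": "Add 2 and 3."},
        {"from": "assistant", "value": "5"},
    ]
}
# Non-assistant turns become the prompt (via apply_chat_template in the diff);
# the first assistant turn is taken as the reference completion.
conversation = [
    {"role": x["from"], "content": x["value"]}
    for x in chat_record["conversations"]
    if x["from"] != "assistant"
]
expected_chat = [x["value"] for x in chat_record["conversations"] if x["from"] == "assistant"][0]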