support sharegpt format in evaluate.py
evaluate.py: 18 lines changed (16 additions, 2 deletions)
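This change lets evaluate.py score datasets stored in ShareGPT format: when a batch has no "text" column, each example's "conversations" list of {"from", "value"} turns is rendered into a prompt with the tokenizer's chat template, and the first assistant turn is held out as the expected response. For context, a sketch of a ShareGPT-style record (the turn contents are illustrative assumptions, not taken from the project's dataset):

    # hypothetical ShareGPT-format record; only the "conversations" column
    # and its "from"/"value" keys are assumed by the diff below
    example = {
        "conversations": [
            {"from": "system", "value": "You control the devices in a smart home."},
            {"from": "user", "value": "Turn off the kitchen light."},
            {"from": "assistant", "value": "Okay, turning off the kitchen light."},
        ]
    }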
@@ -63,8 +63,22 @@ def main():
     with tqdm(total=len(dataset), desc="Accuracy") as pbar:
         for batch_start in range(0, len(dataset), batch_size):
             batch = dataset[batch_start:batch_start + batch_size]
-            prompts = [ example.split(split)[0] + split for example in batch["text"] ]
-            expected_responses = [ example.split(split)[1] for example in batch["text"] ]
+            if "text" in batch:
+                prompts = [ example.split(split)[0] + split for example in batch["text"] ]
+                expected_responses = [ example.split(split)[1] for example in batch["text"] ]
+            else:
+                prompts = []
+                expected_responses = []
+                for example in batch:
+                    conversation = [ { "role": x["from"], "content": x["value"] } for x in example["conversations"] if x["from"] != "assistant"]
+                    prompts.append(trained_tokenizer.apply_chat_template(
+                        conversation=conversation,
+                        max_length=CTX_SIZE,
+                        truncation=True,
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    ))
+                    expected_responses.append([x["value"] for x in example["conversations"] if x["from"] == "assistant"][0])
             output = generate(trained_model, trained_tokenizer, prompts)
 
             for model_output, expected_response in zip(output, expected_responses):
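For reference, a minimal self-contained sketch of what the new else branch does for a single example, assuming any Hugging Face tokenizer that ships a chat template (the model name is a placeholder, and the CTX_SIZE/truncation arguments are omitted for brevity; the script itself uses its trained_tokenizer):

    from transformers import AutoTokenizer

    # placeholder model; substitute the tokenizer the script actually loads
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

    example = {
        "conversations": [
            {"from": "system", "value": "You control the devices in a smart home."},
            {"from": "user", "value": "Turn off the kitchen light."},
            {"from": "assistant", "value": "Okay, turning off the kitchen light."},
        ]
    }

    # every non-assistant turn becomes part of the prompt
    conversation = [
        {"role": x["from"], "content": x["value"]}
        for x in example["conversations"]
        if x["from"] != "assistant"
    ]

    # render the prompt string; add_generation_prompt appends the template's
    # cue for the assistant's turn, so the model continues from there
    prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # the first assistant turn is the reference answer the model output
    # is compared against
    expected_response = [x["value"] for x in example["conversations"] if x["from"] == "assistant"][0]

    print(prompt)
    print(expected_response)

One design note worth flagging: mapping "from" directly onto "role" assumes the dataset already uses role names the chat template understands ("system", "user"); ShareGPT dumps that use "human"/"gpt" speaker tags would need a small rename before this branch could template them.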