mirror of https://github.com/acon96/home-llm.git (synced 2026-01-09 13:48:05 -05:00)
fix evaluate script
evaluate.py (31 changed lines)
@@ -11,8 +11,8 @@ CTX_SIZE = 2048
 def tokenize(tokenizer, prompt):
     return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=CTX_SIZE)
 
-def generate(model, tokenizer, prompt):
-    inputs = tokenize(tokenizer, prompt)
+def generate(model, tokenizer, prompts):
+    inputs = tokenize(tokenizer, prompts)
     with torch.no_grad():
         outputs = model.generate(**inputs)
     text = tokenizer.batch_decode(outputs)
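
Note: the generate() signature change above is what enables batched evaluation: the tokenizer pads a whole list of prompts into one tensor and the model decodes them in a single call. A minimal runnable sketch of that pattern, using gpt2 purely as a stand-in (the real script loads the model from under ./models/, and the decoding settings here are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CTX_SIZE = 2048

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# decoder-only models should be left-padded for batched generation,
# and gpt2 ships without a pad token, so reuse EOS (this mirrors the
# pad_token_id fallback added further down in this commit)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompts = ["turn on the kitchen light", "set the thermostat to 20"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True,
                   truncation=True, max_length=CTX_SIZE)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=16,
                             pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))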
@@ -21,18 +21,21 @@ def generate(model, tokenizer, prompt):
 def main():
     parser = argparse.ArgumentParser(description="Evaluate the function calling for a model")
     parser.add_argument("model")
-    parser.add_argument("--dataset_file", default="./data/home_assistant_test.json")
-    parser.add_argument("--split", default="<|im_start|>assistant")
+    parser.add_argument("--dataset_file", default="./data/home_assistant_test.jsonl")
+    parser.add_argument("--batch-size", default=8)
 
     args = parser.parse_args()
     model_folder = f"./models/{args.model}"
-    split = args.split
 
     dataset = load_dataset("json", data_files={ "train": args.dataset_file })["train"]
 
     print(f"Got {len(dataset)} examples to test")
 
     # filter out examples that are status requests
-    dataset = dataset.filter(lambda example: "```homeassistant" in example["text"])
+    if "text" in dataset:
+        dataset = dataset.filter(lambda example: "```homeassistant" in example["text"])
+    else:
+        dataset = dataset.filter(lambda example: "```homeassistant" in example["conversations"][2]["value"])
 
     service_call_regex = re.compile(r"```homeassistant\n([\S \t\n]*?)```")
 
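
Note: the filter now handles both dataset layouts: a flat "text" column and ShareGPT-style "conversations". A toy sketch of the same detect-and-filter logic on invented rows (ds.column_names is the explicit form of the column check, and the [2] index assumes a system/user/assistant turn order, as in the script):

from datasets import Dataset

flat = Dataset.from_list([{"text": "user request\n```homeassistant\n...\n```"}])
sharegpt = Dataset.from_list([{"conversations": [
    {"from": "system", "value": "..."},
    {"from": "user", "value": "turn on the light"},
    {"from": "assistant", "value": "```homeassistant\n...\n```"},
]}])

for ds in (flat, sharegpt):
    if "text" in ds.column_names:
        kept = ds.filter(lambda ex: "```homeassistant" in ex["text"])
    else:
        # [2] is the assistant turn under the assumed ordering
        kept = ds.filter(lambda ex: "```homeassistant" in ex["conversations"][2]["value"])
    print(len(kept))  # -> 1 for both layouts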
@@ -50,15 +53,23 @@ def main():
         top_p=1.0,
         repetition_penalty=1.15,
         eos_token_id=trained_model.config.eos_token_id,
-        pad_token_id=trained_model.config.pad_token_id,
+        pad_token_id=trained_model.config.pad_token_id if trained_model.config.pad_token_id else trained_model.config.eos_token_id,
     )
 
+    split = trained_tokenizer.apply_chat_template(conversation=[{"role": "assistant", "content": r"%%%%%%%%%%%%%%%%"}], tokenize=False).split(r"%%%%%%%%%%%%%%%%")[0]
+
     print("Evaluating...")
+    batch_size = int(args.batch_size)
     correct_answers = 0
     total_answers = 0
     color_mismatches = 0
 
+    # pre-allocate cuda buffers
+    inputs = trained_tokenizer([""] * batch_size, return_tensors="pt", max_length=CTX_SIZE, padding="max_length", truncation=True)
+    inputs = {k: v.to(trained_model.device) for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = trained_model(**inputs)
+
     failed_examples = []
     with tqdm(total=len(dataset), desc="Accuracy") as pbar:
         for batch_start in range(0, len(dataset), batch_size):
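
Note: two changes in this hunk are worth spelling out. First, pad_token_id falls back to eos_token_id because many causal LMs ship without a pad token. Second, the hard-coded --split argument is replaced by a sentinel trick: render a lone assistant turn containing a marker through the chat template, and everything before the marker is the assistant prefix to split generations on. A sketch (the tokenizer name is a placeholder; the exact prefix depends on the model's template and may include a default system preamble):

from transformers import AutoTokenizer

# any chat-templated tokenizer works here; this one is just an example
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

sentinel = r"%%%%%%%%%%%%%%%%"
rendered = tokenizer.apply_chat_template(
    conversation=[{"role": "assistant", "content": sentinel}],
    tokenize=False,
)
split = rendered.split(sentinel)[0]
print(repr(split))  # for ChatML-style templates this ends in "<|im_start|>assistant\n"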
@@ -69,8 +80,8 @@ def main():
             else:
                 prompts = []
                 expected_responses = []
-                for example in batch:
-                    conversation = [ { "role": x["from"], "content": x["value"] } for x in example["conversations"] if x["from"] != "assistant"]
+                for example in batch["conversations"]:
+                    conversation = [ { "role": x["from"], "content": x["value"] } for x in example if x["from"] != "assistant"]
                     prompts.append(trained_tokenizer.apply_chat_template(
                         conversation=conversation,
                         max_length=CTX_SIZE,
@@ -78,7 +89,7 @@ def main():
                         tokenize=False,
                         add_generation_prompt=True,
                     ))
-                    expected_responses.append([x["value"] for x in example["conversations"] if x["from"] == "assistant"][0])
+                    expected_responses.append([x["value"] for x in example if x["from"] == "assistant"][0])
                 output = generate(trained_model, trained_tokenizer, prompts)
 
             for model_output, expected_response in zip(output, expected_responses):
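
Note: the batch-loop changes in the last two hunks come down to how Hugging Face datasets slice: dataset[i:i+n] returns a column-major dict, so `for example in batch` iterates column names, not rows. Iterating batch["conversations"] yields one conversation list per example, which is what the prompt-building code needs. A toy demonstration with invented rows:

from datasets import Dataset

ds = Dataset.from_list([
    {"conversations": [
        {"from": "user", "value": "turn on the light"},
        {"from": "assistant", "value": "```homeassistant\n...\n```"},
    ]},
    {"conversations": [
        {"from": "user", "value": "set the thermostat to 20"},
        {"from": "assistant", "value": "```homeassistant\n...\n```"},
    ]},
])

batch = ds[0:2]     # {"conversations": [conv0, conv1]}, not a list of rows
print(list(batch))  # -> ["conversations"]: iterating the dict gives column names
for example in batch["conversations"]:
    conversation = [t for t in example if t["from"] != "assistant"]
    expected = [t["value"] for t in example if t["from"] == "assistant"][0]
    print(len(conversation), expected[:24])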