update gaia script

fix erros in api_server.py
update gaia script
This commit is contained in:
LividWo
2024-04-22 11:18:15 +08:00
parent 366247a82a
commit cfed25f198
4 changed files with 70 additions and 16 deletions

View File

@@ -1,42 +1,92 @@
import json
import requests
from oscopilot import FridayAgent
from oscopilot import FridayExecutor, FridayPlanner, FridayRetriever, ToolManager
from oscopilot.utils import setup_config, GAIALoader, GAIA_postprocess
args = setup_config()
args.dataset_type = 'validation'
model = 'gpt4-turbo'
write_file_path = 'gaia_{}_{}_level{}_results.jsonl'.format(model, args.dataset_type, args.level)
def get_numbers(path):
correct = 0
incomplete = 0
with open(path, 'r', encoding='utf-8') as file:
data = [json.loads(line) for line in file]
print(data)
for d in data:
if d["model_answer"] == d["groundtruth"]:
correct += 1
if d["model_answer"] == "" or d["model_answer"] == "incomplete":
incomplete += 1
if len(data) > 0:
return correct, incomplete, data[-1]["index"]
return correct, incomplete, -1 # -1 denotes no previous running
agent = FridayAgent(FridayPlanner, FridayRetriever, FridayExecutor, ToolManager, config=args)
gaia = GAIALoader(args.level, args.dataset_cache)
args.gaia_task_id = "07ed8ebc-535a-4c2f-9677-3e434a08f7fd"
# args.gaia_task_id = "e1fc63a2-da7a-432f-be78-7c4a95598703"
if args.gaia_task_id:
task = gaia.get_data_by_task_id(args.gaia_task_id, args.dataset_type)
query = gaia.task2query(task)
agent.run(query)
if agent.inner_monologue.result != '':
result = GAIA_postprocess(task['Question'], agent.inner_monologue.result)
# agent.run(query)
# if agent.inner_monologue.result != '':
if True:
# print(agent.inner_monologue.result)
result = """17000
"""
# result = GAIA_postprocess(task['Question'], agent.inner_monologue.result)
result = GAIA_postprocess(task['Question'], result)
print('The answer of GAIA Task {0} : {1}'.format(args.gaia_task_id, result))
else:
task_lst = gaia.dataset[args.dataset_type]
with open('gaia_{}_level{}_results.jsonl'.format(args.dataset_type, args.level), 'w', encoding='utf-8') as file:
correct, incomplete, last_run_index = get_numbers(write_file_path)
print(correct, incomplete, last_run_index)
with open(write_file_path, 'a', encoding='utf-8') as file:
count = 0
for task in task_lst:
if count <= last_run_index:
print("\t\t\t skip current run:", count)
count += 1
continue
query = gaia.task2query(task)
agent.run(query)
if agent.inner_monologue.result != '':
result = GAIA_postprocess(task['Question'], agent.inner_monologue.result)
else:
result = ''
result = ''
# agent.run(query)
try:
agent.run(query)
print("$$$$$$" * 30)
if agent.inner_monologue.result != '':
result = GAIA_postprocess(task['Question'], agent.inner_monologue.result)
except requests.exceptions.ConnectionError as ConnectionError:
print(f"Connection error.: {ConnectionError}")
exit()
except Exception as e:
print("$$$$$$" * 30)
# Code to handle any other type of exception
print(f"An error occurred: {e}")
print("$$$$$$" * 30)
result = "incomplete"
incomplete += 1
output_dict = {
"task_id": task['task_id'],
"model_answer": result,
"index": count,
"task_id": task['task_id'],
"model_answer": result,
"groundtruth": task["Final answer"],
"reasoning_trace": ""
}
if result == task["Final answer"]:
correct += 1
json_str = json.dumps(output_dict)
file.write(json_str + '\n')
file.flush()
count += 1
if count > 10:
break
# if count > 2:
# break
print("accuracy:", correct / count)
print("incomplete:", incomplete / count)
print("correct incomplete total,", correct, incomplete, count)

View File

@@ -69,6 +69,7 @@ class FridayPlanner(BaseModule):
self.topological_sort()
else:
print(response)
print('No JSON data found in the string.')
sys.exit()
def replan_task(self, reasoning, current_task, relevant_tool_description_pair):

View File

@@ -4,6 +4,8 @@ prompt = {
1. For numerical questions: Extract numerical values directly from the response.
2. For non-numerical questions: Follow the provided example to guide your extraction.
3. Note that sometimes you need to post-process the values you get follow the instruction in the question.
4. You need to follow the return format specified in the question.
Here are some examples of answer extraction:
Question: Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order.
Response: The page numbers extracted by the 'extract_page_numbers' subtask, already arranged in ascending order, are: 132, 133, 134, 197, 245.
@@ -61,7 +63,7 @@ prompt = {
This statement is not a standard logical equivalence. The correct equivalence for the implication (¬A → B) would be (A B), not (A ¬B). Therefore, this is the full statement that doesn't fit with the others.
Answer: (¬A → B) ↔ (A ¬B)
Based solely on the response provided, extract the answer following these guidelines. Your response should exclusively contain the extracted answer, devoid of any extraneous information.
Based on the Response provided below, extract the answer following above guidelines and instructions in the Question. Your response should only contain the extracted answer.
Question: {question}
Response: {response}
Answer:

View File

@@ -1,8 +1,9 @@
import os
import dotenv
from fastapi import FastAPI
from oscopilot.utils.server_config import ConfigManager
dotenv.load_dotenv(override=True)
app = FastAPI()
# Import your services