TeamOne implementation of GAIA (#221)

Port of GAIA benchmark
This commit is contained in:
afourney
2024-07-17 09:51:19 -07:00
committed by GitHub
parent e69dd92c4f
commit 211bfa01c3
17 changed files with 790 additions and 144 deletions

View File

@@ -0,0 +1,7 @@
{
"CHAT_COMPLETION_PROVIDER": "azure",
"CHAT_COMPLETION_KWARGS_JSON": "{\"api_version\": \"2024-05-01-preview\", \"azure_endpoint\": \"YOUR_ENDPOINT_HERE\", \"model_capabilities\": {\"function_calling\": true, \"json_output\": true, \"vision\": true}, \"azure_ad_token_provider\": \"DEFAULT\", \"model\": \"gpt-4o-2024-05-13\"}",
"BING_API_KEY": "YOUR_KEY_HERE",
"HOMEPAGE": "https://www.bing.com/",
"WEB_SURFER_DEBUG_DIR": "/autogen/debug"
}

View File

@@ -0,0 +1,197 @@
import os
import sys
import re
from agbench.tabulate_cmd import default_tabulate
import json
import pandas as pd
import sqlite3
import glob
import numpy as np
EXCLUDE_DIR_NAMES = ["__pycache__"]
def normalize_answer(a):
    """Normalize an answer string for comparison.

    Lowercases and trims the input, rejoins comma-separated values with a
    uniform ", " separator, collapses runs of whitespace to a single space,
    and strips any trailing '.', '!', or '?' punctuation.
    """
    rejoined = ", ".join(a.strip().lower().split(","))
    collapsed = re.sub(r"\s+", " ", rejoined)
    return re.sub(r"[\.\!\?]+$", "", collapsed)
def scorer(instance_dir):
    """Score a single task instance directory.

    Returns None when the expected answer file, the console log, or the
    'FINAL ANSWER:' line is missing; otherwise a tuple of
    (is_correct, normalized_expected, normalized_final).
    """
    # Read the expected answer.
    expected_path = os.path.join(instance_dir, "expected_answer.txt")
    if not os.path.isfile(expected_path):
        return None
    with open(expected_path, "rt") as fh:
        expected_answer = fh.read().strip()

    # Read the console log.
    log_path = os.path.join(instance_dir, "console_log.txt")
    if not os.path.isfile(log_path):
        return None
    with open(log_path, "rt") as fh:
        console_log = fh.read()

    # Extract the final answer line (everything up to the next newline).
    match = re.search(r"FINAL ANSWER:(.*?)\n", console_log, re.DOTALL)
    if match is None:
        # Missing the final answer line.
        return None
    final_answer = match.group(1).strip()

    # Correct when both normalize to the same non-empty string.
    n_ex = normalize_answer(expected_answer)
    n_final = normalize_answer(final_answer)
    return ((n_ex != "" and n_ex == n_final), n_ex, n_final)
def get_number_of_chat_messages(chat_messages_dir):
    """Count all chat messages across every *_messages.json file in a directory.

    Each file maps agent name -> list of messages; the counts of all lists
    in all files are summed.
    """
    total = 0
    for path in glob.glob(f"{chat_messages_dir}/*_messages.json"):
        with open(path, "r") as fh:
            per_agent = json.load(fh)
        for messages in per_agent.values():
            total += len(messages)
    return total
def main(args):
    """Tabulate GAIA run results and optionally export per-trial stats to Excel.

    Runs agbench's default tabulation with the GAIA ``scorer``; when the
    --excel argument is given, joins each trial's telemetry.db with the
    scorer results and writes one worksheet per trial.
    """
    parsed_args, all_results = default_tabulate(args, scorer=scorer)
    excel_path = parsed_args.excel

    if excel_path:
        excel_dir = os.path.dirname(excel_path) or "."
        if not os.path.exists(excel_dir):
            os.makedirs(excel_dir, exist_ok=True)

        if not excel_path.endswith((".xlsx", ".xls")):
            excel_path += ".xlsx"

        # Normalize the runlogs path to end with "/" so the f-string path
        # concatenations below form valid paths.
        runlogs = parsed_args.runlogs if parsed_args.runlogs.endswith("/") else parsed_args.runlogs + "/"

        if os.path.isdir(runlogs):
            # Task directories, ordered oldest-first by modification time.
            task_ids = sorted(
                [task_id for task_id in os.listdir(runlogs) if task_id not in EXCLUDE_DIR_NAMES],
                key=lambda s: os.path.getmtime(os.path.join(parsed_args.runlogs, s)),
            )
        else:
            raise ValueError("please input a valid directory to tabulate result")

        # Trial subdirectories are numeric; assumes every task ran the same
        # trials as the first task — TODO confirm.
        trials = sorted(os.listdir(f"{runlogs}{task_ids[0]}"), key=lambda x: int(x)) if len(task_ids) > 0 else []
        dbnames = [[f"{runlogs}{task_id}/{trial}/telemetry.db" for task_id in task_ids] for trial in trials]

        # Keep only the earliest row per invocation_id (deduplicates retried
        # completions recorded with the same invocation).
        query = """
            SELECT cost, session_id, response, start_time, end_time
            FROM (
                SELECT invocation_id, cost, session_id, response, start_time, end_time,
                       ROW_NUMBER() OVER (PARTITION BY invocation_id ORDER BY start_time) as rn
                FROM chat_completions
            )
            WHERE rn = 1;
        """

        with pd.ExcelWriter(excel_path, engine="openpyxl") as writer:
            # One worksheet per trial.
            for trial_index, each_trial in enumerate(dbnames):
                result_df = pd.DataFrame(
                    columns=[
                        "id",
                        "status",
                        "expected_answer",
                        "final_answer",
                        "cost",
                        "latency",
                        "num_of_llm_requests",
                        "num_of_chat_messages",
                        "prompt_tokens",
                        "completion_tokens",
                        "total_tokens",
                        "model",
                    ]
                )
                result_df_type_mapping = {
                    "id": str,
                    "status": bool,
                    "expected_answer": str,
                    "final_answer": str,
                    "cost": float,
                    "latency": float,
                    "num_of_llm_requests": int,
                    "num_of_chat_messages": int,
                    "prompt_tokens": int,
                    "completion_tokens": int,
                    "total_tokens": int,
                }
                for dbname, scorer_results in zip(each_trial, all_results):
                    # scorer_results[0] is the task id; columns 1..n hold the
                    # per-trial scorer tuples.
                    task_id = scorer_results[0]
                    scorer_result = scorer_results[trial_index + 1]

                    status, expected_answer, final_answer = scorer_result if scorer_result else (False, "", "")

                    # NOTE(review): this connection is never closed — consider
                    # contextlib.closing(sqlite3.connect(dbname)).
                    con = sqlite3.connect(dbname)

                    # TODO: if large amount of data, add chunksize
                    telemetry_df = pd.read_sql_query(query, con)

                    earliest_starttime = pd.to_datetime(telemetry_df["start_time"], format="%Y-%m-%d %H:%M:%S.%f").min()
                    latest_endtime = pd.to_datetime(telemetry_df["end_time"], format="%Y-%m-%d %H:%M:%S.%f").max()

                    num_of_chat_messages = get_number_of_chat_messages(chat_messages_dir=os.path.dirname(dbname))
                    result = {
                        "id": task_id,
                        "status": status,
                        "expected_answer": expected_answer,
                        "final_answer": final_answer,
                        "cost": telemetry_df["cost"].sum(),
                        "latency": (latest_endtime - earliest_starttime).total_seconds(),
                        "num_of_llm_requests": len(telemetry_df),
                        "num_of_chat_messages": num_of_chat_messages,
                        "prompt_tokens": telemetry_df["response"]
                        .apply(
                            lambda x: json.loads(x)["usage"]["prompt_tokens"]
                            if "usage" in json.loads(x) and "prompt_tokens" in json.loads(x)["usage"]
                            else 0
                        )
                        .sum(),
                        "completion_tokens": telemetry_df["response"]
                        .apply(
                            lambda x: json.loads(x)["usage"]["completion_tokens"]
                            if "usage" in json.loads(x) and "completion_tokens" in json.loads(x)["usage"]
                            else 0
                        )
                        .sum(),
                        "total_tokens": telemetry_df["response"]
                        .apply(
                            lambda x: json.loads(x)["usage"]["total_tokens"]
                            if "usage" in json.loads(x) and "total_tokens" in json.loads(x)["usage"]
                            else 0
                        )
                        .sum(),
                        "model": telemetry_df["response"]
                        .apply(lambda x: json.loads(x)["model"] if "model" in json.loads(x) else "")
                        .unique(),
                    }

                    # NOTE(review): astype is applied BEFORE the new row is
                    # concatenated, so the final row of each sheet is written
                    # untyped — confirm whether astype should follow concat.
                    result_df = result_df.astype(result_df_type_mapping)
                    result_df = pd.concat([result_df, pd.DataFrame([result])], ignore_index=True)
                result_df.to_excel(writer, sheet_name=f"trial_{trial_index}", index=False)


if __name__ == "__main__" and __package__ is None:
    main(sys.argv)

View File

@@ -0,0 +1,157 @@
#
# Run this file to download the GAIA benchmark dataset, and create the
# corresponding testbed task files (written to ../Tasks/gaia_*.jsonl).
#
import json
import os
import re
import sys
from huggingface_hub import snapshot_download
SCRIPT_PATH = os.path.realpath(__file__)
SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
SCENARIO_DIR = os.path.realpath(os.path.join(SCRIPT_DIR, os.path.pardir))
TEMPLATES_DIR = os.path.join(SCENARIO_DIR, "Templates")
TASKS_DIR = os.path.join(SCENARIO_DIR, "Tasks")
DOWNLOADS_DIR = os.path.join(SCENARIO_DIR, "Downloads")
REPO_DIR = os.path.join(DOWNLOADS_DIR, "GAIA")
def download_gaia():
    """Download the GAIA benchmark dataset from Hugging Face Hub into REPO_DIR.

    Creates DOWNLOADS_DIR if needed, then snapshots the
    ``gaia-benchmark/GAIA`` dataset repository.
    """
    # makedirs with exist_ok avoids the check-then-create race of
    # isdir()+mkdir(), and also creates missing parents.
    os.makedirs(DOWNLOADS_DIR, exist_ok=True)
    snapshot_download(
        repo_id="gaia-benchmark/GAIA",
        repo_type="dataset",
        local_dir=REPO_DIR,
        local_dir_use_symlinks=True,
    )
def create_jsonl(name, tasks, files_dir, template):
    """Write a JSONL scenario file named ``<name>.jsonl`` under TASKS_DIR.

    Each task becomes one JSON record holding the template copy list and the
    per-file substitutions for scenario.py, expected_answer.txt, and
    prompt.txt.
    """
    if not os.path.isdir(TASKS_DIR):
        os.mkdir(TASKS_DIR)

    out_path = os.path.join(TASKS_DIR, name + ".jsonl")
    with open(out_path, "wt") as out:
        for task in tasks:
            print(f"Converting: [{name}] {task['task_id']}")

            # Figure out what files we need to copy: the template itself,
            # plus the task's attachment (if any).
            attached_file = task["file_name"].strip()
            copy_list = [template]
            if attached_file:
                copy_list.append(
                    [
                        os.path.join(files_dir, attached_file),
                        attached_file,
                    ]
                )

            record = {
                "id": task["task_id"],
                "template": copy_list,
                "substitutions": {
                    "scenario.py": {
                        "__FILE_NAME__": task["file_name"],
                    },
                    "expected_answer.txt": {"__EXPECTED_ANSWER__": task["Final answer"]},
                    "prompt.txt": {"__PROMPT__": task["Question"]},
                },
            }
            out.write(json.dumps(record).strip() + "\n")
###############################################################################
def main():
    """Download GAIA and emit one JSONL task file per (subset, level, template)."""
    download_gaia()

    gaia_validation_files = os.path.join(REPO_DIR, "2023", "validation")
    gaia_test_files = os.path.join(REPO_DIR, "2023", "test")

    if not os.path.isdir(gaia_validation_files) or not os.path.isdir(gaia_test_files):
        sys.exit(f"Error: '{REPO_DIR}' does not appear to be a copy of the GAIA repository.")

    # Bucket validation tasks by level (1-3).
    gaia_validation_tasks = [[], [], []]
    with open(os.path.join(gaia_validation_files, "metadata.jsonl")) as fh:
        for line in fh:
            data = json.loads(line)
            gaia_validation_tasks[data["Level"] - 1].append(data)

    # Bucket test tasks by level, skipping the placeholder entry.
    gaia_test_tasks = [[], [], []]
    with open(os.path.join(gaia_test_files, "metadata.jsonl")) as fh:
        for line in fh:
            data = json.loads(line)
            # A welcome message -- not a real task
            if data["task_id"] == "0-0-0-0-0":
                continue
            gaia_test_tasks[data["Level"] - 1].append(data)

    # Map template name (whitespace removed) -> template directory path.
    templates = {
        re.sub(r"\s", "", entry.name): entry.path
        for entry in os.scandir(TEMPLATES_DIR)
        if entry.is_dir()
    }

    # Create the various combinations of [subsets] x [levels] x [templates].
    for template_name, template_path in templates.items():
        for level in (1, 2, 3):
            create_jsonl(
                f"gaia_validation_level_{level}__{template_name}",
                gaia_validation_tasks[level - 1],
                gaia_validation_files,
                template_path,
            )
        for level in (1, 2, 3):
            create_jsonl(
                f"gaia_test_level_{level}__{template_name}",
                gaia_test_tasks[level - 1],
                gaia_test_files,
                template_path,
            )


if __name__ == "__main__" and __package__ is None:
    main()

View File

@@ -0,0 +1 @@
__EXPECTED_ANSWER__

View File

@@ -0,0 +1 @@
__PROMPT__

View File

@@ -0,0 +1 @@
/agnext/teams/team-one

View File

@@ -0,0 +1,197 @@
import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, List, Tuple, Union

from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
    AssistantMessage,
    AzureOpenAIChatCompletionClient,
    ChatCompletionClient,
    LLMMessage,
    ModelCapabilities,
    UserMessage,
)
from team_one.agents.coder import Coder, Executor
from team_one.agents.file_surfer import FileSurfer
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
from team_one.agents.orchestrator import LedgerOrchestrator
from team_one.markdown_browser import MarkdownConverter, UnsupportedFormatException
from team_one.messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from team_one.utils import LogHandler, message_content_to_str, create_completion_client_from_env
async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]) -> str:
    """Distill a final answer from the team's conversation transcript.

    Replays the transcript to the model as user messages, asks for a
    'FINAL ANSWER:' line, and — if the model reports it is unable to
    determine one — retries once asking for an educated guess, rewriting
    'EDUCATED GUESS:' to 'FINAL ANSWER:' so downstream log parsing stays
    uniform.

    Note: the fallback branch requires ``re`` and ``AssistantMessage`` to be
    imported at module level; the original code raised NameError here.

    Args:
        task: The original question posed to the team.
        source: Name to attach as the message source (typically the orchestrator).
        client: Chat-completion client used to generate the answer.
        transcript: The team's chat history.

    Returns:
        The model's response text containing a 'FINAL ANSWER:' line.
    """
    messages: List[LLMMessage] = [
        UserMessage(
            content=f"Earlier you were asked the following:\n\n{task}\n\nYour team then worked diligently to address that request. Here is a transcript of that conversation:",
            source=source,
        )
    ]

    # Copy the transcript into this context as user messages.
    for message in transcript:
        messages.append(
            UserMessage(
                content=message_content_to_str(message.content),
                source=message.source,
            )
        )

    # Ask for the final answer.
    messages.append(
        UserMessage(
            content=f"""
Read the above conversation and output a FINAL ANSWER to the question. The question is repeated here for convenience:

{task}

To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
""",
            # If you are unable to determine the final answer, output 'FINAL ANSWER: Unable to determine'
            source=source,
        )
    )

    response = await client.create(messages)
    assert isinstance(response.content, str)

    # No definitive answer: ask once more for an educated guess.
    if "unable to determine" in response.content.lower():
        messages.append(AssistantMessage(content=response.content, source="self"))
        messages.append(
            UserMessage(
                content=f"""
I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation.

To output the educated guess, use the following template: EDUCATED GUESS: [YOUR EDUCATED GUESS]
Your EDUCATED GUESS should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. DO NOT OUTPUT 'I don't know', 'Unable to determine', etc.
ADDITIONALLY, your EDUCATED GUESS MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
""".strip(),
                source=source,
            )
        )

        response = await client.create(messages)
        assert isinstance(response.content, str)
        # Normalize the guess to the FINAL ANSWER format.
        return re.sub(r"EDUCATED GUESS:", "FINAL ANSWER:", response.content)
    else:
        return response.content
async def main() -> None:
    """Run one GAIA scenario instance.

    Reads the task prompt from prompt.txt, assembles the TeamOne agent team
    (coder, executor, file surfer, web surfer) under a ledger orchestrator,
    publishes the task, and prints the prepared final answer to stdout
    (which the GAIA scorer parses from the console log).
    """
    # Read the prompt
    prompt = ""
    with open("prompt.txt", "rt") as fh:
        prompt = fh.read().strip()
    # __FILE_NAME__ is substituted by the scenario templating system; it is
    # empty when the task has no attached file.
    filename = "__FILE_NAME__".strip()

    # Create the runtime.
    runtime = SingleThreadedAgentRuntime()

    # Create the AzureOpenAI client, with AAD auth
    # NOTE(review): endpoint/model are hard-coded here even though
    # create_completion_client_from_env is imported — consider using it.
    # token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
    client = AzureOpenAIChatCompletionClient(
        api_version="2024-02-15-preview",
        azure_endpoint="https://aif-complex-tasks-west-us-3.openai.azure.com/",
        model="gpt-4o-2024-05-13",
        model_capabilities=ModelCapabilities(
            function_calling=True, json_output=True, vision=True
        ),
        # azure_ad_token_provider=token_provider
    )

    # Register agents.
    coder = runtime.register_and_get_proxy(
        "Coder",
        lambda: Coder(model_client=client),
    )
    executor = runtime.register_and_get_proxy(
        "Executor",
        lambda: Executor(
            "A agent for executing code", executor=LocalCommandLineCodeExecutor()
        ),
    )
    file_surfer = runtime.register_and_get_proxy(
        "file_surfer",
        lambda: FileSurfer(model_client=client),
    )
    web_surfer = runtime.register_and_get_proxy(
        "WebSurfer",
        lambda: MultimodalWebSurfer(),  # Configuration is set later by init()
    )
    orchestrator = runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator(
        agents=[coder, executor, file_surfer, web_surfer],
        model_client=client,
    ))

    run_context = runtime.start()

    # The web surfer needs async initialization after registration; reach
    # into the runtime for the concrete instance.
    actual_surfer = runtime._get_agent(web_surfer.id)  # type: ignore
    assert isinstance(actual_surfer, MultimodalWebSurfer)
    await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")

    # await runtime.send_message(RequestReplyMessage(), user_proxy.id)

    filename_prompt = ""
    if len(filename) > 0:
        # relpath = os.path.join("coding", filename)
        # file_uri = pathlib.Path(os.path.abspath(os.path.expanduser(relpath))).as_uri()
        # NOTE(review): this f-string has no placeholder — '(unknown)' looks
        # like it should be '{filename}'; confirm against the template source.
        filename_prompt = f"The question is about a file, document or image, which can be accessed by the filename '(unknown)' in the current working directory."

        # Inline the file contents when the converter supports the format.
        try:
            mdconverter = MarkdownConverter()
            res = mdconverter.convert(filename)
            if res.text_content:
                # if count_token(res.text_content) < 8000:  # Don't put overly-large documents into the prompt
                filename_prompt += "\n\nHere are the file's contents:\n\n" + res.text_content
        except UnsupportedFormatException:
            pass

    # mdconverter = MarkdownConverter(mlm_client=client)
    # mlm_prompt = f"""Write a detailed caption for this image. Pay special attention to any details that might be useful for someone answering the following:
    # {PROMPT}
    # """.strip()

    task = f"{prompt}\n\n{filename_prompt}"

    # Broadcast the task to the team and run until idle.
    await runtime.publish_message(
        BroadcastMessage(content=UserMessage(content=task.strip(), source="human")),
        namespace="default",
    )

    await run_context.stop_when_idle()

    # Output the final answer
    actual_orchestrator = runtime._get_agent(orchestrator.id)  # type: ignore
    assert isinstance(actual_orchestrator, LedgerOrchestrator)
    transcript: List[LLMMessage] = actual_orchestrator._chat_history  # type: ignore
    print(await response_preparer(task=task, source=orchestrator.metadata["name"], client=client, transcript=transcript))


if __name__ == "__main__":
    # Route agnext event logs through the team_one log handler.
    logger = logging.getLogger(EVENT_LOGGER_NAME)
    logger.setLevel(logging.INFO)
    log_handler = LogHandler()
    logger.handlers = [log_handler]
    asyncio.run(main())

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
@@ -36,7 +37,7 @@ async def main() -> None:
actual_surfer = runtime._get_agent(web_surfer.id) # type: ignore
assert isinstance(actual_surfer, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, browser_channel="chromium")
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
await runtime.send_message(RequestReplyMessage(), user_proxy.id)
await run_context.stop_when_idle()

View File

@@ -23,9 +23,9 @@ dependencies = [
"markdownify",
"numpy",
"python-pptx",
"easyocr",
"pandas",
"pdfminer.six",
"easyocr",
"puremagic",
"binaryornot",
"pydub",

View File

@@ -8,7 +8,7 @@ from agnext.components.models import (
)
from agnext.core import CancellationToken
from team_one.messages import BroadcastMessage, RequestReplyMessage, UserContent
from team_one.messages import BroadcastMessage, RequestReplyMessage, ResetMessage, UserContent
from team_one.utils import message_content_to_str
@@ -28,6 +28,11 @@ class BaseAgent(TypeRoutedAgent):
assert isinstance(message.content, UserMessage)
self._chat_history.append(message.content)
@message_handler
async def handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
"""Handle a reset message."""
await self._reset(cancellation_token)
@message_handler
async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
"""Respond to a reply request."""
@@ -42,3 +47,6 @@ class BaseAgent(TypeRoutedAgent):
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""
raise NotImplementedError()
async def _reset(self, cancellation_token: CancellationToken) -> None:
self._chat_history = []

View File

@@ -20,7 +20,7 @@ class Coder(BaseAgent):
DEFAULT_DESCRIPTION = "A Python coder assistant."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage("""You are a helpful AI assistant. Solve tasks using your Python coding skills. The code you output must be formatted in Markdown code blocks demarcated by triple backticks (```). As an example:
SystemMessage("""You are a helpful AI assistant. Solve tasks using your Python coding skills. The code you output must be formatted in Markdown code blocks demarcated by triple backticks (```), and must print their final output to console. As an example:
```python
@@ -86,10 +86,18 @@ class Executor(BaseAgent):
)
cancellation_token.link_future(future)
result = await future
return (
False,
f"The script ran, then exited with Unix exit code: {result.exit_code}\nIts output was:\n{result.output}",
)
if result.output.strip() == "":
# Sometimes agents forget to print(). Remind the to print something
return (
False,
f"The script ran but produced no output to console. The Unix exit code was: {result.exit_code}. If you were expecting output, consider revising the script to ensure content is printed to stdout.",
)
else:
return (
False,
f"The script ran, then exited with Unix exit code: {result.exit_code}\nIts output was:\n{result.output}",
)
else:
n_messages_checked += 1
if n_messages_checked > self._check_last_n_message:

View File

@@ -2,6 +2,7 @@ import base64
import hashlib
import io
import json
import logging
import os
import pathlib
import re
@@ -10,6 +11,7 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callabl
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
import aiofiles
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components import Image as AGImage
from agnext.components.models import (
AssistantMessage,
@@ -24,10 +26,10 @@ from playwright._impl._errors import Error as PlaywrightError
from playwright._impl._errors import TimeoutError
# from playwright._impl._async_base.AsyncEventInfo
from playwright.async_api import BrowserContext, Page, Playwright, async_playwright
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from team_one.agents.base_agent import BaseAgent
from team_one.messages import UserContent
from team_one.messages import UserContent, WebSurferEvent
from team_one.utils import SentinelMeta, message_content_to_str
# TODO: Fix mdconvert
@@ -68,6 +70,8 @@ VIEWPORT_WIDTH = 1440
MLM_HEIGHT = 765
MLM_WIDTH = 1224
logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
# Sentinels
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
@@ -92,6 +96,7 @@ class MultimodalWebSurfer(BaseAgent):
self._playwright: Playwright | None = None
self._context: BrowserContext | None = None
self._page: Page | None = None
self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None
# Read page_script
@@ -99,6 +104,12 @@ class MultimodalWebSurfer(BaseAgent):
with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt") as fh:
self._page_script = fh.read()
# Define the download handler
def _download_handler(download: Download) -> None:
self._last_download = download
self._download_handler = _download_handler
async def init(
self,
model_client: ChatCompletionClient,
@@ -115,12 +126,7 @@ class MultimodalWebSurfer(BaseAgent):
self.start_page = start_page or self.DEFAULT_START_PAGE
self.downloads_folder = downloads_folder
self._chat_history: List[LLMMessage] = []
# def _download_handler(download):
# self._last_download = download
#
# self._download_handler = _download_handler
# self._last_download = None
self._last_download = None
self._prior_metadata_hash = None
## Create or use the provided MarkdownConverter
@@ -148,14 +154,13 @@ class MultimodalWebSurfer(BaseAgent):
self._context.set_default_timeout(60000) # One minute
self._page = await self._context.new_page()
# self._page.route(lambda x: True, self._route_handler)
# self._page.on("download", self._download_handler)
self._page.on("download", self._download_handler)
await self._page.set_viewport_size({"width": VIEWPORT_WIDTH, "height": VIEWPORT_HEIGHT})
await self._page.add_init_script(
path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
)
await self._page.goto(self.start_page)
await self._page.wait_for_load_state()
# self._sleep(1)
# Prepare the debug directory -- which stores the screenshots generated throughout the process
await self._set_debug_dir(debug_dir)
@@ -194,15 +199,21 @@ setInterval(function() {{
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
print(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
# def reset(self):
# super().reset()
# self._log_to_console(fname="reset", args={"home": self.start_page})
# self._visit_page(self.start_page)
# self._page.wait_for_load_state()
# if self.debug_dir:
# screenshot = self._page.screenshot()
# with open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as png:
# png.write(screenshot)
async def _reset(self, cancellation_token: CancellationToken) -> None:
assert self._page is not None
future = super()._reset(cancellation_token)
await future
await self._visit_page(self.start_page)
await self._page.wait_for_load_state()
if self.debug_dir:
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(
WebSurferEvent(
source=self.metadata["name"],
url=self._page.url,
message="Resetting browser.",
)
)
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
try:
@@ -358,122 +369,131 @@ When deciding between tools, consider if the request can be best addressed by:
message = response.content
action_description = ""
# self._last_download = None
# try:
if True:
if isinstance(message, str):
# Answer directly
return False, message
self._last_download = None
elif isinstance(message, list):
# Take an action
if isinstance(message, str):
# Answer directly
return False, message
name = message[0].name
args = json.loads(message[0].arguments)
elif isinstance(message, list):
# Take an action
if name == "visit_url":
url = args.get("url")
action_description = f"I typed '{url}' into the browser address bar."
# Check if the argument starts with a known protocol
if url.startswith(("https://", "http://", "file://", "about:")):
await self._visit_page(url)
# If the argument contains a space, treat it as a search query
elif " " in url:
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH")
# Otherwise, prefix with https://
else:
await self._visit_page("https://" + url)
name = message[0].name
args = json.loads(message[0].arguments)
elif name == "history_back":
action_description = "I clicked the browser back button."
await self._back()
elif name == "web_search":
query = args.get("query")
action_description = f"I typed '{query}' into the browser search bar."
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH")
elif name == "page_up":
action_description = "I scrolled up one page in the browser."
await self._page_up()
elif name == "page_down":
action_description = "I scrolled down one page in the browser."
await self._page_down()
elif name == "click":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I clicked '{target_name}'."
else:
action_description = "I clicked the control."
await self._click_id(target_id)
elif name == "input_text":
input_field_id = str(args.get("input_field_id"))
text_value = str(args.get("text_value"))
input_field_name = self._target_name(input_field_id, rects)
if input_field_name:
action_description = f"I typed '{text_value}' into '{input_field_name}'."
else:
action_description = f"I input '{text_value}'."
await self._fill_id(input_field_id, text_value)
elif name == "scroll_element_up":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' up."
else:
action_description = "I scrolled the control up."
await self._scroll_id(target_id, "up")
elif name == "scroll_element_down":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' down."
else:
action_description = "I scrolled the control down."
await self._scroll_id(target_id, "down")
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
await self._sleep(3) # There's a 2s sleep below too
logger.info(
WebSurferEvent(
source=self.metadata["name"],
url=self._page.url,
action=name,
arguments=args,
message=f"{name}( {json.dumps(args)} )",
)
)
if name == "visit_url":
url = args.get("url")
action_description = f"I typed '{url}' into the browser address bar."
# Check if the argument starts with a known protocol
if url.startswith(("https://", "http://", "file://", "about:")):
await self._visit_page(url)
# If the argument contains a space, treat it as a search query
elif " " in url:
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH")
# Otherwise, prefix with https://
else:
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
await self._visit_page("https://" + url)
# except ValueError as e:
# return True, str(e)
elif name == "history_back":
action_description = "I clicked the browser back button."
await self._back()
elif name == "web_search":
query = args.get("query")
action_description = f"I typed '{query}' into the browser search bar."
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH")
elif name == "page_up":
action_description = "I scrolled up one page in the browser."
await self._page_up()
elif name == "page_down":
action_description = "I scrolled down one page in the browser."
await self._page_down()
elif name == "click":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I clicked '{target_name}'."
else:
action_description = "I clicked the control."
await self._click_id(target_id)
elif name == "input_text":
input_field_id = str(args.get("input_field_id"))
text_value = str(args.get("text_value"))
input_field_name = self._target_name(input_field_id, rects)
if input_field_name:
action_description = f"I typed '{text_value}' into '{input_field_name}'."
else:
action_description = f"I input '{text_value}'."
await self._fill_id(input_field_id, text_value)
elif name == "scroll_element_up":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' up."
else:
action_description = "I scrolled the control up."
await self._scroll_id(target_id, "up")
elif name == "scroll_element_down":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' down."
else:
action_description = "I scrolled the control down."
await self._scroll_id(target_id, "down")
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
await self._sleep(3) # There's a 2s sleep below too
else:
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
await self._page.wait_for_load_state()
await self._sleep(3)
# # Handle downloads
# if self._last_download is not None and self.downloads_folder is not None:
# fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
# self._last_download.save_as(fname)
# page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
# self._page.goto("data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8"))
# self._page.wait_for_load_state()
# Handle downloads
if self._last_download is not None and self.downloads_folder is not None:
fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
# TODO: Fix this type
await self._last_download.save_as(fname) # type: ignore
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
await self._page.goto(
"data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
)
await self._page.wait_for_load_state()
# Handle metadata
page_metadata = json.dumps(await self._get_page_metadata(), indent=4)
@@ -571,7 +591,7 @@ When deciding between tools, consider if the request can be best addressed by:
async def _on_new_page(self, page: Page) -> None:
    """Adopt *page* as the active browser page.

    Wires the download event handler, sizes the viewport to the module's
    VIEWPORT_WIDTH x VIEWPORT_HEIGHT constants, and clears the cached
    page-metadata hash so the next read is treated as fresh.
    """
    self._page = page
    # Request routing/interception is intentionally disabled:
    # self._page.route(lambda x: True, self._route_handler)
    self._page.on("download", self._download_handler)
    viewport = {"width": VIEWPORT_WIDTH, "height": VIEWPORT_HEIGHT}
    await self._page.set_viewport_size(viewport)
    # Brief settle time before the page is used.
    await self._sleep(0.2)
    self._prior_metadata_hash = None
@@ -644,6 +664,15 @@ When deciding between tools, consider if the request can be best addressed by:
assert isinstance(new_page, Page)
await self._on_new_page(new_page)
logger.info(
WebSurferEvent(
source=self.metadata["name"],
url=self._page.url,
message="New tab or window.",
)
)
except TimeoutError:
pass

View File

@@ -21,7 +21,7 @@ TOOL_VISIT_URL: ToolSchema = _load_tool(
"type": "function",
"function": {
"name": "visit_url",
"description": "Inputs the given url into the browser's address bar, navigating directly to the requested page.",
"description": "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputting queries into search boxes).",
"parameters": {
"type": "object",
"properties": {

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from agnext.core import AgentProxy
from ..messages import BroadcastMessage, OrchestrationEvent
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger
from .orchestrator_prompts import (
ORCHESTRATOR_CLOSED_BOOK_PROMPT,
@@ -186,6 +186,10 @@ class LedgerOrchestrator(BaseOrchestrator):
f"New plan:\n{plan_str}",
)
)
# Reset
self._chat_history = [self._chat_history[0]]
await self.publish_message(ResetMessage())
self._chat_history.append(plan_user_message)
ledger_dict = await self.update_ledger()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import Any, Dict, List, Union
from agnext.components import FunctionCall, Image
from agnext.components.models import FunctionExecutionResult, LLMMessage
@@ -22,7 +22,21 @@ class RequestReplyMessage:
pass
@dataclass
class ResetMessage:
    """Payload-less signal message.

    Published by the orchestrator when it builds a new plan (its own chat
    history is trimmed at the same time); presumably tells subscribing
    agents to clear their accumulated state — confirm against subscribers.
    """

    pass
@dataclass
class OrchestrationEvent:
    """Loggable event emitted by an orchestrator (consumed by LogHandler)."""

    source: str  # name of the emitting agent/component
    message: str  # human-readable description of the event
@dataclass
class WebSurferEvent:
    """Loggable event emitted by the web-surfer agent (consumed by LogHandler,
    which serializes it via dataclasses.asdict)."""

    source: str  # name of the emitting agent
    message: str  # human-readable description of the event
    url: str  # URL of the page the event relates to
    action: str | None = None  # tool/action name, when the event records a tool call
    arguments: Dict[str, Any] | None = None  # arguments of that tool call, if any

View File

@@ -1,10 +1,12 @@
import json
import logging
import os
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Literal
from agnext.application.logging.events import LLMCallEvent
from agnext.components import Image
from agnext.components.models import (
AzureOpenAIChatCompletionClient,
ChatCompletionClient,
@@ -12,7 +14,14 @@ from agnext.components.models import (
OpenAIChatCompletionClient,
)
from .messages import AssistantContent, FunctionExecutionContent, OrchestrationEvent, SystemContent, UserContent
from .messages import (
AssistantContent,
FunctionExecutionContent,
OrchestrationEvent,
SystemContent,
UserContent,
WebSurferEvent,
)
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER = "CHAT_COMPLETION_PROVIDER"
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON = "CHAT_COMPLETION_KWARGS_JSON"
@@ -82,6 +91,8 @@ def message_content_to_str(
for item in message_content:
if isinstance(item, str):
converted.append(item.rstrip())
elif isinstance(item, Image):
converted.append("<Image>")
else:
converted.append(str(item).rstrip())
return "\n".join(converted)
@@ -95,7 +106,8 @@ class LogHandler(logging.FileHandler):
super().__init__(filename)
def emit(self, record: logging.LogRecord) -> None:
try:
# try:
if True:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, OrchestrationEvent):
console_message = (
@@ -111,6 +123,16 @@ class LogHandler(logging.FileHandler):
}
)
super().emit(record)
if isinstance(record.msg, WebSurferEvent):
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
print(console_message, flush=True)
payload: Dict[str, Any] = {
"timestamp": ts,
"type": "WebSurferEvent",
}
payload.update(asdict(record.msg))
record.msg = json.dumps(payload)
super().emit(record)
if isinstance(record.msg, LLMCallEvent):
record.msg = json.dumps(
{
@@ -120,9 +142,9 @@ class LogHandler(logging.FileHandler):
"type": "LLMCallEvent",
}
)
super().emit(record)
except Exception:
self.handleError(record)
super().emit(record)
# except Exception:
# self.handleError(record)
class SentinelMeta(type):

View File

@@ -18,8 +18,7 @@ RUN pip install openai pillow aiohttp typing-extensions pydantic types-aiofiles
RUN pip install numpy pandas matplotlib seaborn scikit-learn requests urllib3 nltk pytest
# Pre-load packages needed for mdconvert file utils
RUN pip install python-docx pdfminer.six python-pptx SpeechRecognition openpyxl pydub mammoth puremagic youtube_transcript_api==0.6.0
# easyocr
RUN pip install python-docx pdfminer.six python-pptx SpeechRecognition openpyxl pydub mammoth puremagic youtube_transcript_api==0.6.0 easyocr
# Pre-load Playwright
RUN pip install playwright
@@ -30,7 +29,7 @@ RUN pip uninstall --yes numpy
RUN pip install "numpy<2.0"
# Pre-load the OCR model
#RUN /usr/bin/echo -e "import easyocr\nreader = easyocr.Reader(['en'])" | python
RUN /usr/bin/echo -e "import easyocr\nreader = easyocr.Reader(['en'])" | python
# Webarena (evaluation code)
RUN pip install beartype aiolimiter