TeamOne implementation of GAIA (#221)

Port of GAIA benchmark
This commit is contained in:
afourney
2024-07-17 09:51:19 -07:00
committed by GitHub
parent e69dd92c4f
commit 211bfa01c3
17 changed files with 790 additions and 144 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
@@ -36,7 +37,7 @@ async def main() -> None:
actual_surfer = runtime._get_agent(web_surfer.id) # type: ignore
assert isinstance(actual_surfer, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, browser_channel="chromium")
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
await runtime.send_message(RequestReplyMessage(), user_proxy.id)
await run_context.stop_when_idle()

View File

@@ -23,9 +23,9 @@ dependencies = [
"markdownify",
"numpy",
"python-pptx",
"easyocr",
"pandas",
"pdfminer.six",
"easyocr",
"puremagic",
"binaryornot",
"pydub",

View File

@@ -8,7 +8,7 @@ from agnext.components.models import (
)
from agnext.core import CancellationToken
from team_one.messages import BroadcastMessage, RequestReplyMessage, UserContent
from team_one.messages import BroadcastMessage, RequestReplyMessage, ResetMessage, UserContent
from team_one.utils import message_content_to_str
@@ -28,6 +28,11 @@ class BaseAgent(TypeRoutedAgent):
assert isinstance(message.content, UserMessage)
self._chat_history.append(message.content)
@message_handler
async def handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None:
    """Handle a reset message.

    Nothing is read from the message itself; receiving it is the trigger.
    All of the work is delegated to ``self._reset``, so a subclass that
    overrides ``_reset`` changes what a reset does without touching this
    handler.

    Args:
        message: The received reset request (its contents are not used).
        cancellation_token: Passed through unchanged to ``_reset``.
    """
    await self._reset(cancellation_token)
@message_handler
async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
"""Respond to a reply request."""
@@ -42,3 +47,6 @@ class BaseAgent(TypeRoutedAgent):
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""
raise NotImplementedError()
async def _reset(self, cancellation_token: CancellationToken) -> None:
    """Reset this agent by discarding its accumulated chat history.

    Args:
        cancellation_token: Accepted for interface compatibility; this
            implementation does not use it.
    """
    # Rebind to a fresh list (rather than clearing the existing one in place).
    self._chat_history = []

View File

@@ -20,7 +20,7 @@ class Coder(BaseAgent):
DEFAULT_DESCRIPTION = "A Python coder assistant."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage("""You are a helpful AI assistant. Solve tasks using your Python coding skills. The code you output must be formatted in Markdown code blocks demarcated by triple backticks (```). As an example:
SystemMessage("""You are a helpful AI assistant. Solve tasks using your Python coding skills. The code you output must be formatted in Markdown code blocks demarcated by triple backticks (```), and must print their final output to console. As an example:
```python
@@ -86,10 +86,18 @@ class Executor(BaseAgent):
)
cancellation_token.link_future(future)
result = await future
return (
False,
f"The script ran, then exited with Unix exit code: {result.exit_code}\nIts output was:\n{result.output}",
)
if result.output.strip() == "":
# Sometimes agents forget to print(). Remind them to print something.
return (
False,
f"The script ran but produced no output to console. The Unix exit code was: {result.exit_code}. If you were expecting output, consider revising the script to ensure content is printed to stdout.",
)
else:
return (
False,
f"The script ran, then exited with Unix exit code: {result.exit_code}\nIts output was:\n{result.output}",
)
else:
n_messages_checked += 1
if n_messages_checked > self._check_last_n_message:

View File

@@ -2,6 +2,7 @@ import base64
import hashlib
import io
import json
import logging
import os
import pathlib
import re
@@ -10,6 +11,7 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callabl
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
import aiofiles
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components import Image as AGImage
from agnext.components.models import (
AssistantMessage,
@@ -24,10 +26,10 @@ from playwright._impl._errors import Error as PlaywrightError
from playwright._impl._errors import TimeoutError
# from playwright._impl._async_base.AsyncEventInfo
from playwright.async_api import BrowserContext, Page, Playwright, async_playwright
from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright
from team_one.agents.base_agent import BaseAgent
from team_one.messages import UserContent
from team_one.messages import UserContent, WebSurferEvent
from team_one.utils import SentinelMeta, message_content_to_str
# TODO: Fix mdconvert
@@ -68,6 +70,8 @@ VIEWPORT_WIDTH = 1440
MLM_HEIGHT = 765
MLM_WIDTH = 1224
logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
# Sentinels
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
@@ -92,6 +96,7 @@ class MultimodalWebSurfer(BaseAgent):
self._playwright: Playwright | None = None
self._context: BrowserContext | None = None
self._page: Page | None = None
self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None
# Read page_script
@@ -99,6 +104,12 @@ class MultimodalWebSurfer(BaseAgent):
with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), "rt") as fh:
self._page_script = fh.read()
# Define the download handler
def _download_handler(download: Download) -> None:
    # Remember only the most recent download; an earlier one is overwritten.
    self._last_download = download

# Keep the closure on the instance so it can be registered on pages with
# ``page.on("download", self._download_handler)``.
self._download_handler = _download_handler
async def init(
self,
model_client: ChatCompletionClient,
@@ -115,12 +126,7 @@ class MultimodalWebSurfer(BaseAgent):
self.start_page = start_page or self.DEFAULT_START_PAGE
self.downloads_folder = downloads_folder
self._chat_history: List[LLMMessage] = []
# def _download_handler(download):
# self._last_download = download
#
# self._download_handler = _download_handler
# self._last_download = None
self._last_download = None
self._prior_metadata_hash = None
## Create or use the provided MarkdownConverter
@@ -148,14 +154,13 @@ class MultimodalWebSurfer(BaseAgent):
self._context.set_default_timeout(60000) # One minute
self._page = await self._context.new_page()
# self._page.route(lambda x: True, self._route_handler)
# self._page.on("download", self._download_handler)
self._page.on("download", self._download_handler)
await self._page.set_viewport_size({"width": VIEWPORT_WIDTH, "height": VIEWPORT_HEIGHT})
await self._page.add_init_script(
path=os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js")
)
await self._page.goto(self.start_page)
await self._page.wait_for_load_state()
# self._sleep(1)
# Prepare the debug directory -- which stores the screenshots generated throughout the process
await self._set_debug_dir(debug_dir)
@@ -194,15 +199,21 @@ setInterval(function() {{
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
print(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
# def reset(self):
# super().reset()
# self._log_to_console(fname="reset", args={"home": self.start_page})
# self._visit_page(self.start_page)
# self._page.wait_for_load_state()
# if self.debug_dir:
# screenshot = self._page.screenshot()
# with open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as png:
# png.write(screenshot)
async def _reset(self, cancellation_token: CancellationToken) -> None:
    """Reset the agent and send the browser back to its start page.

    Runs the base-class reset first, then navigates to ``self.start_page``
    and waits for the page load state. When a debug directory is
    configured, ``screenshot.png`` in that directory is overwritten with a
    capture of the page. Finally a ``WebSurferEvent`` is logged.

    Args:
        cancellation_token: Forwarded to the base-class ``_reset``.
    """
    assert self._page is not None

    await super()._reset(cancellation_token)

    # Return to the configured home page and let it settle.
    await self._visit_page(self.start_page)
    await self._page.wait_for_load_state()

    if self.debug_dir:
        screenshot_path = os.path.join(self.debug_dir, "screenshot.png")
        await self._page.screenshot(path=screenshot_path)

    reset_event = WebSurferEvent(
        source=self.metadata["name"],
        url=self._page.url,
        message="Resetting browser.",
    )
    logger.info(reset_event)
def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None:
try:
@@ -358,122 +369,131 @@ When deciding between tools, consider if the request can be best addressed by:
message = response.content
action_description = ""
# self._last_download = None
# try:
if True:
if isinstance(message, str):
# Answer directly
return False, message
self._last_download = None
elif isinstance(message, list):
# Take an action
if isinstance(message, str):
# Answer directly
return False, message
name = message[0].name
args = json.loads(message[0].arguments)
elif isinstance(message, list):
# Take an action
if name == "visit_url":
url = args.get("url")
action_description = f"I typed '{url}' into the browser address bar."
# Check if the argument starts with a known protocol
if url.startswith(("https://", "http://", "file://", "about:")):
await self._visit_page(url)
# If the argument contains a space, treat it as a search query
elif " " in url:
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH")
# Otherwise, prefix with https://
else:
await self._visit_page("https://" + url)
name = message[0].name
args = json.loads(message[0].arguments)
elif name == "history_back":
action_description = "I clicked the browser back button."
await self._back()
elif name == "web_search":
query = args.get("query")
action_description = f"I typed '{query}' into the browser search bar."
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH")
elif name == "page_up":
action_description = "I scrolled up one page in the browser."
await self._page_up()
elif name == "page_down":
action_description = "I scrolled down one page in the browser."
await self._page_down()
elif name == "click":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I clicked '{target_name}'."
else:
action_description = "I clicked the control."
await self._click_id(target_id)
elif name == "input_text":
input_field_id = str(args.get("input_field_id"))
text_value = str(args.get("text_value"))
input_field_name = self._target_name(input_field_id, rects)
if input_field_name:
action_description = f"I typed '{text_value}' into '{input_field_name}'."
else:
action_description = f"I input '{text_value}'."
await self._fill_id(input_field_id, text_value)
elif name == "scroll_element_up":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' up."
else:
action_description = "I scrolled the control up."
await self._scroll_id(target_id, "up")
elif name == "scroll_element_down":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' down."
else:
action_description = "I scrolled the control down."
await self._scroll_id(target_id, "down")
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
await self._sleep(3) # There's a 2s sleep below too
logger.info(
WebSurferEvent(
source=self.metadata["name"],
url=self._page.url,
action=name,
arguments=args,
message=f"{name}( {json.dumps(args)} )",
)
)
if name == "visit_url":
url = args.get("url")
action_description = f"I typed '{url}' into the browser address bar."
# Check if the argument starts with a known protocol
if url.startswith(("https://", "http://", "file://", "about:")):
await self._visit_page(url)
# If the argument contains a space, treat it as a search query
elif " " in url:
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(url)}&FORM=QBLH")
# Otherwise, prefix with https://
else:
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
await self._visit_page("https://" + url)
# except ValueError as e:
# return True, str(e)
elif name == "history_back":
action_description = "I clicked the browser back button."
await self._back()
elif name == "web_search":
query = args.get("query")
action_description = f"I typed '{query}' into the browser search bar."
await self._visit_page(f"https://www.bing.com/search?q={quote_plus(query)}&FORM=QBLH")
elif name == "page_up":
action_description = "I scrolled up one page in the browser."
await self._page_up()
elif name == "page_down":
action_description = "I scrolled down one page in the browser."
await self._page_down()
elif name == "click":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I clicked '{target_name}'."
else:
action_description = "I clicked the control."
await self._click_id(target_id)
elif name == "input_text":
input_field_id = str(args.get("input_field_id"))
text_value = str(args.get("text_value"))
input_field_name = self._target_name(input_field_id, rects)
if input_field_name:
action_description = f"I typed '{text_value}' into '{input_field_name}'."
else:
action_description = f"I input '{text_value}'."
await self._fill_id(input_field_id, text_value)
elif name == "scroll_element_up":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' up."
else:
action_description = "I scrolled the control up."
await self._scroll_id(target_id, "up")
elif name == "scroll_element_down":
target_id = str(args.get("target_id"))
target_name = self._target_name(target_id, rects)
if target_name:
action_description = f"I scrolled '{target_name}' down."
else:
action_description = "I scrolled the control down."
await self._scroll_id(target_id, "down")
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
await self._sleep(3) # There's a 2s sleep below too
else:
raise ValueError(f"Unknown tool '{name}'. Please choose from:\n\n{tool_names}")
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
await self._page.wait_for_load_state()
await self._sleep(3)
# # Handle downloads
# if self._last_download is not None and self.downloads_folder is not None:
# fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
# self._last_download.save_as(fname)
# page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
# self._page.goto("data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8"))
# self._page.wait_for_load_state()
# Handle downloads
if self._last_download is not None and self.downloads_folder is not None:
fname = os.path.join(self.downloads_folder, self._last_download.suggested_filename)
# TODO: Fix this type
await self._last_download.save_as(fname) # type: ignore
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
await self._page.goto(
"data:text/html;base64," + base64.b64encode(page_body.encode("utf-8")).decode("utf-8")
)
await self._page.wait_for_load_state()
# Handle metadata
page_metadata = json.dumps(await self._get_page_metadata(), indent=4)
@@ -571,7 +591,7 @@ When deciding between tools, consider if the request can be best addressed by:
async def _on_new_page(self, page: Page) -> None:
self._page = page
# self._page.route(lambda x: True, self._route_handler)
# self._page.on("download", self._download_handler)
self._page.on("download", self._download_handler)
await self._page.set_viewport_size({"width": VIEWPORT_WIDTH, "height": VIEWPORT_HEIGHT})
await self._sleep(0.2)
self._prior_metadata_hash = None
@@ -644,6 +664,15 @@ When deciding between tools, consider if the request can be best addressed by:
assert isinstance(new_page, Page)
await self._on_new_page(new_page)
logger.info(
WebSurferEvent(
source=self.metadata["name"],
url=self._page.url,
message="New tab or window.",
)
)
except TimeoutError:
pass

View File

@@ -21,7 +21,7 @@ TOOL_VISIT_URL: ToolSchema = _load_tool(
"type": "function",
"function": {
"name": "visit_url",
"description": "Inputs the given url into the browser's address bar, navigating directly to the requested page.",
"description": "Navigate directly to a provided URL using the browser's address bar. Prefer this tool over other navigation techniques in cases where the user provides a fully-qualified URL (e.g., choose it over clicking links, or inputing queries into search boxes).",
"parameters": {
"type": "object",
"properties": {

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from agnext.core import AgentProxy
from ..messages import BroadcastMessage, OrchestrationEvent
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger
from .orchestrator_prompts import (
ORCHESTRATOR_CLOSED_BOOK_PROMPT,
@@ -186,6 +186,10 @@ class LedgerOrchestrator(BaseOrchestrator):
f"New plan:\n{plan_str}",
)
)
# Reset
self._chat_history = [self._chat_history[0]]
await self.publish_message(ResetMessage())
self._chat_history.append(plan_user_message)
ledger_dict = await self.update_ledger()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import Any, Dict, List, Union
from agnext.components import FunctionCall, Image
from agnext.components.models import FunctionExecutionResult, LLMMessage
@@ -22,7 +22,21 @@ class RequestReplyMessage:
pass
@dataclass
class ResetMessage:
    """Payload-free message asking the agents that receive it to reset their state."""

    pass
@dataclass
class OrchestrationEvent:
    """Structured log record consisting of a source and a message."""

    # Identifier of whatever emitted the event
    # (presumably the orchestrator/agent name -- confirm at call sites).
    source: str
    # Human-readable text of the event.
    message: str
@dataclass
class WebSurferEvent:
    """Structured log record describing something the web surfer did.

    Instances are handed to ``logger.info`` as the message object; the log
    handler converts them with ``dataclasses.asdict`` and ``json.dumps``,
    so field values need to be JSON-serialisable.
    """

    # Name of the emitting agent (callers pass ``self.metadata["name"]``).
    source: str
    # Human-readable description of the event.
    message: str
    # URL of the browser page at the moment the event was logged.
    url: str
    # Name of the tool/action, when the event records a tool call; else None.
    action: str | None = None
    # Parsed arguments of that tool call, when applicable; else None.
    arguments: Dict[str, Any] | None = None

View File

@@ -1,10 +1,12 @@
import json
import logging
import os
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Literal
from agnext.application.logging.events import LLMCallEvent
from agnext.components import Image
from agnext.components.models import (
AzureOpenAIChatCompletionClient,
ChatCompletionClient,
@@ -12,7 +14,14 @@ from agnext.components.models import (
OpenAIChatCompletionClient,
)
from .messages import AssistantContent, FunctionExecutionContent, OrchestrationEvent, SystemContent, UserContent
from .messages import (
AssistantContent,
FunctionExecutionContent,
OrchestrationEvent,
SystemContent,
UserContent,
WebSurferEvent,
)
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER = "CHAT_COMPLETION_PROVIDER"
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON = "CHAT_COMPLETION_KWARGS_JSON"
@@ -82,6 +91,8 @@ def message_content_to_str(
for item in message_content:
if isinstance(item, str):
converted.append(item.rstrip())
elif isinstance(item, Image):
converted.append("<Image>")
else:
converted.append(str(item).rstrip())
return "\n".join(converted)
@@ -95,7 +106,8 @@ class LogHandler(logging.FileHandler):
super().__init__(filename)
def emit(self, record: logging.LogRecord) -> None:
try:
# try:
if True:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, OrchestrationEvent):
console_message = (
@@ -111,6 +123,16 @@ class LogHandler(logging.FileHandler):
}
)
super().emit(record)
if isinstance(record.msg, WebSurferEvent):
console_message = f"\033[96m[{ts}], {record.msg.source}: {record.msg.message}\033[0m"
print(console_message, flush=True)
payload: Dict[str, Any] = {
"timestamp": ts,
"type": "WebSurferEvent",
}
payload.update(asdict(record.msg))
record.msg = json.dumps(payload)
super().emit(record)
if isinstance(record.msg, LLMCallEvent):
record.msg = json.dumps(
{
@@ -120,9 +142,9 @@ class LogHandler(logging.FileHandler):
"type": "LLMCallEvent",
}
)
super().emit(record)
except Exception:
self.handleError(record)
super().emit(record)
# except Exception:
# self.handleError(record)
class SentinelMeta(type):