Update tool use examples to use inner agents rather than subclassing (#286)

* Update tool use examples to use inner agents rather than subclassing

* fix

* Merge remote-tracking branch 'origin/main' into ekzhu-update-tool-use-example
This commit is contained in:
Eric Zhu
2024-07-26 15:04:52 -07:00
committed by GitHub
parent 6437374f63
commit 47e1cf464f
8 changed files with 198 additions and 102 deletions

View File

@@ -1,13 +1,13 @@
"""
This example implements a tool-enabled agent that uses tools to perform tasks.
1. The agent receives a user message, and makes an inference using a model.
If the response is a list of function calls, the agent executes the tools by
sending tool execution task to itself.
2. The agent executes the tools and sends the results back to itself, and
makes an inference using the model again.
3. The agent keeps executing the tools until the inference response is not a
1. The tool use agent receives a user message, and makes an inference using a model.
If the response is a list of function calls, the tool use agent executes the tools by
sending tool execution task to a tool executor agent.
2. The tool executor agent executes the tools and sends the results back to the
tool use agent, who makes an inference using the model again.
3. The agents keep executing the tools until the inference response is not a
list of function calls.
4. The agent returns the final response to the user.
4. The tool use agent returns the final response to the user.
"""
import asyncio
@@ -17,7 +17,7 @@ from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, message_handler
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
AssistantMessage,
@@ -29,8 +29,8 @@ from agnext.components.models import (
UserMessage,
)
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import CancellationToken
from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema
from agnext.core import AgentId, CancellationToken
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -42,7 +42,7 @@ class Message:
content: str
class ToolEnabledAgent(ToolAgent):
class ToolUseAgent(TypeRoutedAgent):
"""An agent that uses tools to perform tasks. It executes the tools
by itself by sending the tool execution task to itself."""
@@ -51,24 +51,30 @@ class ToolEnabledAgent(ToolAgent):
description: str,
system_messages: List[SystemMessage],
model_client: ChatCompletionClient,
tools: List[Tool],
tool_schema: List[ToolSchema],
tool_agent: AgentId,
) -> None:
super().__init__(description, tools)
super().__init__(description)
self._model_client = model_client
self._system_messages = system_messages
self._tool_schema = tool_schema
self._tool_agent = tool_agent
@message_handler
async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
"""Handle a user message, execute the model and tools, and returns the response."""
session: List[LLMMessage] = []
session.append(UserMessage(content=message.content, source="User"))
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
response = await self._model_client.create(self._system_messages + session, tools=self._tool_schema)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
# Keep executing the tools until the response is not a list of function calls.
while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[self.send_message(call, self.id) for call in response.content],
*[
self.send_message(call, self._tool_agent, cancellation_token=cancellation_token)
for call in response.content
],
return_exceptions=True,
)
# Combine the results into a single response and handle exceptions.
@@ -82,7 +88,7 @@ class ToolEnabledAgent(ToolAgent):
raise result
session.append(FunctionExecutionResultMessage(content=function_results))
# Execute the model again with the new response.
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
response = await self._model_client.create(self._system_messages + session, tools=self._tool_schema)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
assert isinstance(response.content, str)
@@ -100,20 +106,30 @@ async def main() -> None:
)
]
# Register agents.
tool_agent = await runtime.register_and_get(
tool_executor_agent = await runtime.register_and_get(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
"tool_enabled_agent",
lambda: ToolEnabledAgent(
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tools=tools,
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
),
)
run_context = runtime.start()
# Send a task to the tool user.
response = await runtime.send_message(Message("Run the following Python code: print('Hello, World!')"), tool_agent)
response = await runtime.send_message(
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
)
print(response.content)
# Run the runtime until the task is completed.

View File

@@ -14,14 +14,14 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import SystemMessage
from agnext.components.tool_agent import ToolException
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import AgentId
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from coding_one_agent_direct import Message, ToolEnabledAgent
from coding_direct import Message, ToolUseAgent
from common.utils import get_chat_completion_client_from_envs
@@ -48,20 +48,30 @@ async def main() -> None:
)
]
# Register agents.
tool_agent = await runtime.register_and_get(
tool_executor_agent = await runtime.register_and_get(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
"tool_enabled_agent",
lambda: ToolEnabledAgent(
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tools=tools,
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
),
)
run_context = runtime.start()
# Send a task to the tool user.
response = await runtime.send_message(Message("Run the following Python code: print('Hello, World!')"), tool_agent)
response = await runtime.send_message(
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
)
print(response.content)
# Run the runtime until the task is completed.

View File

@@ -7,18 +7,20 @@ import asyncio
import os
import random
import sys
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components.models import (
SystemMessage,
)
from agnext.components.tools import FunctionTool
from agnext.components.tool_agent import ToolAgent
from agnext.components.tools import FunctionTool, Tool
from typing_extensions import Annotated
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__))))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from coding_one_agent_direct import Message, ToolEnabledAgent
from coding_direct import Message, ToolUseAgent
from common.utils import get_chat_completion_client_from_envs
@@ -31,28 +33,37 @@ async def get_stock_price(ticker: str, date: Annotated[str, "The date in YYYY/MM
async def main() -> None:
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
tools: List[Tool] = [
# A tool that gets the stock price.
FunctionTool(
get_stock_price,
description="Get the stock price of a company given the ticker and date.",
name="get_stock_price",
)
]
# Register agents.
tool_agent = await runtime.register_and_get(
tool_executor_agent = await runtime.register_and_get(
"tool_executor_agent",
lambda: ToolAgent(
description="Tool Executor Agent",
tools=tools,
),
)
tool_use_agent = await runtime.register_and_get(
"tool_enabled_agent",
lambda: ToolEnabledAgent(
lambda: ToolUseAgent(
description="Tool Use Agent",
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
tools=[
# Define a tool that gets the stock price.
FunctionTool(
get_stock_price,
description="Get the stock price of a company given the ticker and date.",
name="get_stock_price",
)
],
tool_schema=[tool.schema for tool in tools],
tool_agent=tool_executor_agent,
),
)
run_context = runtime.start()
# Send a task to the tool user.
response = await runtime.send_message(Message("What is the stock price of NVDA on 2024/06/01"), tool_agent)
response = await runtime.send_message(Message("What is the stock price of NVDA on 2024/06/01"), tool_use_agent)
# Print the result.
assert isinstance(response, Message)
print(response.content)