mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Update tool use examples to use inner agents rather than subclassing (#286)
* Update tool use examples to use inner agents rather than subclassing * fix * Merge remote-tracking branch 'origin/main' into ekzhu-update-tool-use-example
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
"""
|
||||
This example implements a tool-enabled agent that uses tools to perform tasks.
|
||||
1. The agent receives a user message, and makes an inference using a model.
|
||||
If the response is a list of function calls, the agent executes the tools by
|
||||
sending tool execution task to itself.
|
||||
2. The agent executes the tools and sends the results back to itself, and
|
||||
makes an inference using the model again.
|
||||
3. The agent keeps executing the tools until the inference response is not a
|
||||
1. The tool use agent receives a user message, and makes an inference using a model.
|
||||
If the response is a list of function calls, the tool use agent executes the tools by
|
||||
sending tool execution task to a tool executor agent.
|
||||
2. The tool executor agent executes the tools and sends the results back to the
|
||||
tool use agent, who makes an inference using the model again.
|
||||
3. The agents keep executing the tools until the inference response is not a
|
||||
list of function calls.
|
||||
4. The agent returns the final response to the user.
|
||||
4. The tool use agent returns the final response to the user.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import FunctionCall, message_handler
|
||||
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
|
||||
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@@ -29,8 +29,8 @@ from agnext.components.models import (
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.components.tool_agent import ToolAgent, ToolException
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
||||
from agnext.core import CancellationToken
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -42,7 +42,7 @@ class Message:
|
||||
content: str
|
||||
|
||||
|
||||
class ToolEnabledAgent(ToolAgent):
|
||||
class ToolUseAgent(TypeRoutedAgent):
|
||||
"""An agent that uses tools to perform tasks. It executes the tools
|
||||
by itself by sending the tool execution task to itself."""
|
||||
|
||||
@@ -51,24 +51,30 @@ class ToolEnabledAgent(ToolAgent):
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
tools: List[Tool],
|
||||
tool_schema: List[ToolSchema],
|
||||
tool_agent: AgentId,
|
||||
) -> None:
|
||||
super().__init__(description, tools)
|
||||
super().__init__(description)
|
||||
self._model_client = model_client
|
||||
self._system_messages = system_messages
|
||||
self._tool_schema = tool_schema
|
||||
self._tool_agent = tool_agent
|
||||
|
||||
@message_handler
|
||||
async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
|
||||
"""Handle a user message, execute the model and tools, and returns the response."""
|
||||
session: List[LLMMessage] = []
|
||||
session.append(UserMessage(content=message.content, source="User"))
|
||||
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
|
||||
response = await self._model_client.create(self._system_messages + session, tools=self._tool_schema)
|
||||
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
|
||||
|
||||
# Keep executing the tools until the response is not a list of function calls.
|
||||
while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
|
||||
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
|
||||
*[self.send_message(call, self.id) for call in response.content],
|
||||
*[
|
||||
self.send_message(call, self._tool_agent, cancellation_token=cancellation_token)
|
||||
for call in response.content
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
# Combine the results into a single response and handle exceptions.
|
||||
@@ -82,7 +88,7 @@ class ToolEnabledAgent(ToolAgent):
|
||||
raise result
|
||||
session.append(FunctionExecutionResultMessage(content=function_results))
|
||||
# Execute the model again with the new response.
|
||||
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
|
||||
response = await self._model_client.create(self._system_messages + session, tools=self._tool_schema)
|
||||
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
@@ -100,20 +106,30 @@ async def main() -> None:
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_agent = await runtime.register_and_get(
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolEnabledAgent(
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tools=tools,
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
response = await runtime.send_message(Message("Run the following Python code: print('Hello, World!')"), tool_agent)
|
||||
response = await runtime.send_message(
|
||||
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
|
||||
)
|
||||
print(response.content)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
@@ -14,14 +14,14 @@ from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import FunctionCall
|
||||
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import SystemMessage
|
||||
from agnext.components.tool_agent import ToolException
|
||||
from agnext.components.tool_agent import ToolAgent, ToolException
|
||||
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
||||
from agnext.core import AgentId
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from coding_one_agent_direct import Message, ToolEnabledAgent
|
||||
from coding_direct import Message, ToolUseAgent
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
||||
@@ -48,20 +48,30 @@ async def main() -> None:
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_agent = await runtime.register_and_get(
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolEnabledAgent(
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tools=tools,
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
response = await runtime.send_message(Message("Run the following Python code: print('Hello, World!')"), tool_agent)
|
||||
response = await runtime.send_message(
|
||||
Message("Run the following Python code: print('Hello, World!')"), tool_use_agent
|
||||
)
|
||||
print(response.content)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
@@ -7,18 +7,20 @@ import asyncio
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components.models import (
|
||||
SystemMessage,
|
||||
)
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.components.tool_agent import ToolAgent
|
||||
from agnext.components.tools import FunctionTool, Tool
|
||||
from typing_extensions import Annotated
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from coding_one_agent_direct import Message, ToolEnabledAgent
|
||||
from coding_direct import Message, ToolUseAgent
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
||||
@@ -31,28 +33,37 @@ async def get_stock_price(ticker: str, date: Annotated[str, "The date in YYYY/MM
|
||||
async def main() -> None:
|
||||
# Create the runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
tools: List[Tool] = [
|
||||
# A tool that gets the stock price.
|
||||
FunctionTool(
|
||||
get_stock_price,
|
||||
description="Get the stock price of a company given the ticker and date.",
|
||||
name="get_stock_price",
|
||||
)
|
||||
]
|
||||
# Register agents.
|
||||
tool_agent = await runtime.register_and_get(
|
||||
tool_executor_agent = await runtime.register_and_get(
|
||||
"tool_executor_agent",
|
||||
lambda: ToolAgent(
|
||||
description="Tool Executor Agent",
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
tool_use_agent = await runtime.register_and_get(
|
||||
"tool_enabled_agent",
|
||||
lambda: ToolEnabledAgent(
|
||||
lambda: ToolUseAgent(
|
||||
description="Tool Use Agent",
|
||||
system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
tools=[
|
||||
# Define a tool that gets the stock price.
|
||||
FunctionTool(
|
||||
get_stock_price,
|
||||
description="Get the stock price of a company given the ticker and date.",
|
||||
name="get_stock_price",
|
||||
)
|
||||
],
|
||||
tool_schema=[tool.schema for tool in tools],
|
||||
tool_agent=tool_executor_agent,
|
||||
),
|
||||
)
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a task to the tool user.
|
||||
response = await runtime.send_message(Message("What is the stock price of NVDA on 2024/06/01"), tool_agent)
|
||||
response = await runtime.send_message(Message("What is the stock price of NVDA on 2024/06/01"), tool_use_agent)
|
||||
# Print the result.
|
||||
assert isinstance(response, Message)
|
||||
print(response.content)
|
||||
Reference in New Issue
Block a user