Update ai agent documentation on tool agent (#272)

* Update ai agent documentation on tool agent

* Fix
This commit is contained in:
Eric Zhu
2024-07-25 11:53:59 -07:00
committed by GitHub
parent 84d4e27776
commit e9c3a384f3
6 changed files with 36 additions and 81 deletions

View File

@@ -11,14 +11,13 @@ list of function calls.
"""
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
from agnext.components import FunctionCall, message_handler
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
AssistantMessage,
@@ -29,6 +28,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import CancellationToken
@@ -42,13 +42,7 @@ class Message:
content: str
@dataclass
class FunctionExecutionException(BaseException):
call_id: str
content: str
class ToolEnabledAgent(TypeRoutedAgent):
class ToolEnabledAgent(ToolAgent):
"""An agent that uses tools to perform tasks. It executes the tools
by itself by sending the tool execution task to itself."""
@@ -59,17 +53,16 @@ class ToolEnabledAgent(TypeRoutedAgent):
model_client: ChatCompletionClient,
tools: List[Tool],
) -> None:
super().__init__(description)
super().__init__(description, tools)
self._model_client = model_client
self._system_messages = system_messages
self._tools = tools
@message_handler
async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
"""Handle a user message, execute the model and tools, and returns the response."""
session: List[LLMMessage] = []
session.append(UserMessage(content=message.content, source="User"))
response = await self._model_client.create(self._system_messages + session, tools=self._tools)
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
# Keep executing the tools until the response is not a list of function calls.
@@ -83,40 +76,18 @@ class ToolEnabledAgent(TypeRoutedAgent):
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, FunctionExecutionException):
elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result
session.append(FunctionExecutionResultMessage(content=function_results))
# Execute the model again with the new response.
response = await self._model_client.create(self._system_messages + session, tools=self._tools)
response = await self._model_client.create(self._system_messages + session, tools=self.tools)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
assert isinstance(response.content, str)
return Message(content=response.content)
@message_handler
async def handle_tool_call(
self, message: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Handle a tool execution task. This method executes the tool and publishes the result."""
# Find the tool
tool = next((tool for tool in self._tools if tool.name == message.name), None)
if tool is None:
raise FunctionExecutionException(call_id=message.id, content=f"Error: Tool not found: {message.name}")
else:
try:
arguments = json.loads(message.arguments)
result = await tool.run_json(args=arguments, cancellation_token=cancellation_token)
result_as_str = tool.return_value_as_string(result)
except json.JSONDecodeError as e:
raise FunctionExecutionException(
call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}"
) from e
except Exception as e:
raise FunctionExecutionException(call_id=message.id, content=f"Error: {e}") from e
return FunctionExecutionResult(content=result_as_str, call_id=message.id)
async def main() -> None:
# Create the runtime.