diff --git a/python/docs/src/core-concepts/ai-agents.md b/python/docs/src/core-concepts/ai-agents.md
index 2177fe853..a1f5490c9 100644
--- a/python/docs/src/core-concepts/ai-agents.md
+++ b/python/docs/src/core-concepts/ai-agents.md
@@ -136,9 +136,10 @@ stock_price_tool = FunctionTool(get_stock_price, description="Get the stock pric
### Tool-Equipped Agent
-AGNext does not prescribe a specific way to use tools with agents, and you can
-use tools in any way that makes sense for your application.
-Here is an example agent class that shows one possible way to use tools with agents.
+To use tools with an agent, you can use {py:class}`agnext.components.tool_agent.ToolAgent`,
+either by subclassing it or by using it in a composition pattern.
+Here is an example tool-equipped agent that subclasses {py:class}~`agnext.components.tool_agent.ToolAgent`
+and executes its tools by sending direct messages to itself.
```python
import json
@@ -147,6 +148,7 @@ from typing import List
from dataclasses import dataclass
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler, FunctionCall
+from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.models import (
ChatCompletionClient,
SystemMessage,
@@ -163,17 +165,11 @@ from agnext.core import CancellationToken
class MyMessage:
content: str
-@dataclass
-class FunctionExecutionException(BaseException):
- call_id: str
- content: str
-
-class ToolAgent(TypeRoutedAgent):
+class ToolEquippedAgent(ToolAgent):
def __init__(self, model_client: ChatCompletionClient, tools: List[Tool]) -> None:
- super().__init__("An agent with tools")
+ super().__init__("An agent with tools", tools)
self._system_messages = [SystemMessage("You are a helpful AI assistant.")]
self._model_client = model_client
- self._tools = tools
@message_handler
async def handle_user_message(self, message: MyMessage, cancellation_token: CancellationToken) -> MyMessage:
@@ -181,7 +177,7 @@ class ToolAgent(TypeRoutedAgent):
session = [UserMessage(content=message.content, source="user")]
# Get a response from the model.
response = await self._model_client.create(
- self._system_messages + session, tools=self._tools, cancellation_token=cancellation_token
+ self._system_messages + session, tools=self.tools, cancellation_token=cancellation_token
)
# Add the response to the session.
session.append(AssistantMessage(content=response.content, source="assistant"))
@@ -198,54 +194,37 @@ class ToolAgent(TypeRoutedAgent):
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
- elif isinstance(result, FunctionExecutionException):
+ elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result # Unexpected exception.
session.append(FunctionExecutionResultMessage(content=function_results))
# Query the model again with the new response.
response = await self._model_client.create(
- self._system_messages + session, tools=self._tools, cancellation_token=cancellation_token
+ self._system_messages + session, tools=self.tools, cancellation_token=cancellation_token
)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
# Return the final response.
return MyMessage(content=response.content)
-
- @message_handler
- async def handle_function_call(
- self, message: FunctionCall, cancellation_token: CancellationToken
- ) -> FunctionExecutionResult:
- tool = next((tool for tool in self._tools if tool.name == message.name), None)
- if tool is None:
- raise FunctionExecutionException(call_id=message.id, content=f"Error: Tool not found: {message.name}")
- else:
- try:
- arguments = json.loads(message.arguments)
- result = await tool.run_json(args=arguments, cancellation_token=cancellation_token)
- result_as_str = tool.return_value_as_string(result)
- except json.JSONDecodeError as e:
- raise FunctionExecutionException(call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}") from e
- except Exception as e:
- raise FunctionExecutionException(call_id=message.id, content=f"Error: {e}") from e
- return FunctionExecutionResult(content=result_as_str, call_id=message.id)
```
-The `ToolAgent` class is much more involved than the `SimpleAgent` class, however,
+The `ToolEquippedAgent` class is much more involved than the `SimpleAgent` class, however,
the core idea can be described using a simple control flow graph:
-
+
-The `ToolAgent`'s `handle_my_message` handler handles messages from the user,
+The `ToolEquippedAgent`'s `handle_user_message` handler handles messages from the user,
and determines whether the model has generated a tool call.
If the model has generated tool calls, then the handler sends a function call
-message to itself to execute the tools, and then queries the model again
+message to itself to execute the tools -- implemented by the parent {py:class}~`agnext.components.tool_agent.ToolAgent` class,
+and then queries the model again
with the results of the tool calls.
This process continues until the model stops generating tool calls,
at which point the final response is returned to the user.
-By separating the tool execution logic into a separate handler, we expose the
-model-tool interactions to the agent runtime as messages, so the tool executions
+By having the tool execution logic in a separate handler in the base class,
+we expose the model-tool interactions to the agent runtime as messages, so the tool executions
can be observed externally and intercepted if necessary.
To run the agent, we need to create a runtime and register the agent.
@@ -256,7 +235,7 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
agent = await runtime.register_and_get(
"tool-agent",
- lambda: ToolAgent(
+ lambda: ToolEquippedAgent(
OpenAIChatCompletionClient(model="gpt-4o-mini", api_key="YOUR_API_KEY"),
tools=[
FunctionTool(get_stock_price, description="Get the stock price."),
diff --git a/python/docs/src/core-concepts/tool-agent-cfg.svg b/python/docs/src/core-concepts/tool-agent-cfg.svg
deleted file mode 100644
index 7ffe47097..000000000
--- a/python/docs/src/core-concepts/tool-agent-cfg.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/python/docs/src/core-concepts/tool-equipped-agent-cfg.svg b/python/docs/src/core-concepts/tool-equipped-agent-cfg.svg
new file mode 100644
index 000000000..171051dea
--- /dev/null
+++ b/python/docs/src/core-concepts/tool-equipped-agent-cfg.svg
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/python/samples/tool-use/coding_one_agent_direct.py b/python/samples/tool-use/coding_one_agent_direct.py
index 3b029d354..342fd2f51 100644
--- a/python/samples/tool-use/coding_one_agent_direct.py
+++ b/python/samples/tool-use/coding_one_agent_direct.py
@@ -11,14 +11,13 @@ list of function calls.
"""
import asyncio
-import json
import os
import sys
from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
-from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
+from agnext.components import FunctionCall, message_handler
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import (
AssistantMessage,
@@ -29,6 +28,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
+from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import CancellationToken
@@ -42,13 +42,7 @@ class Message:
content: str
-@dataclass
-class FunctionExecutionException(BaseException):
- call_id: str
- content: str
-
-
-class ToolEnabledAgent(TypeRoutedAgent):
+class ToolEnabledAgent(ToolAgent):
"""An agent that uses tools to perform tasks. It executes the tools
by itself by sending the tool execution task to itself."""
@@ -59,17 +53,16 @@ class ToolEnabledAgent(TypeRoutedAgent):
model_client: ChatCompletionClient,
tools: List[Tool],
) -> None:
- super().__init__(description)
+ super().__init__(description, tools)
self._model_client = model_client
self._system_messages = system_messages
- self._tools = tools
@message_handler
async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
"""Handle a user message, execute the model and tools, and returns the response."""
session: List[LLMMessage] = []
session.append(UserMessage(content=message.content, source="User"))
- response = await self._model_client.create(self._system_messages + session, tools=self._tools)
+ response = await self._model_client.create(self._system_messages + session, tools=self.tools)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
# Keep executing the tools until the response is not a list of function calls.
@@ -83,40 +76,18 @@ class ToolEnabledAgent(TypeRoutedAgent):
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
- elif isinstance(result, FunctionExecutionException):
+ elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result
session.append(FunctionExecutionResultMessage(content=function_results))
# Execute the model again with the new response.
- response = await self._model_client.create(self._system_messages + session, tools=self._tools)
+ response = await self._model_client.create(self._system_messages + session, tools=self.tools)
session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
assert isinstance(response.content, str)
return Message(content=response.content)
- @message_handler
- async def handle_tool_call(
- self, message: FunctionCall, cancellation_token: CancellationToken
- ) -> FunctionExecutionResult:
- """Handle a tool execution task. This method executes the tool and publishes the result."""
- # Find the tool
- tool = next((tool for tool in self._tools if tool.name == message.name), None)
- if tool is None:
- raise FunctionExecutionException(call_id=message.id, content=f"Error: Tool not found: {message.name}")
- else:
- try:
- arguments = json.loads(message.arguments)
- result = await tool.run_json(args=arguments, cancellation_token=cancellation_token)
- result_as_str = tool.return_value_as_string(result)
- except json.JSONDecodeError as e:
- raise FunctionExecutionException(
- call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}"
- ) from e
- except Exception as e:
- raise FunctionExecutionException(call_id=message.id, content=f"Error: {e}") from e
- return FunctionExecutionResult(content=result_as_str, call_id=message.id)
-
async def main() -> None:
# Create the runtime.
diff --git a/python/samples/tool-use/coding_one_agent_direct_intercept.py b/python/samples/tool-use/coding_one_agent_direct_intercept.py
index be26b96c1..71f8c5467 100644
--- a/python/samples/tool-use/coding_one_agent_direct_intercept.py
+++ b/python/samples/tool-use/coding_one_agent_direct_intercept.py
@@ -14,13 +14,14 @@ from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall
from agnext.components.code_executor import LocalCommandLineCodeExecutor
from agnext.components.models import SystemMessage
+from agnext.components.tool_agent import ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import AgentId
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from coding_one_agent_direct import FunctionExecutionException, Message, ToolEnabledAgent
+from coding_one_agent_direct import Message, ToolEnabledAgent
from common.utils import get_chat_completion_client_from_envs
@@ -32,7 +33,7 @@ class ToolInterventionHandler(DefaultInterventionHandler):
f"Function call: {message.name}\nArguments: {message.arguments}\nDo you want to execute the tool? (y/n): "
)
if user_input.strip().lower() != "y":
- raise FunctionExecutionException(content="User denied tool execution.", call_id=message.id)
+ raise ToolException(content="User denied tool execution.", call_id=message.id)
return message
diff --git a/python/src/agnext/components/tool_agent/_tool_agent.py b/python/src/agnext/components/tool_agent/_tool_agent.py
index 6d70e7931..7e6056f55 100644
--- a/python/src/agnext/components/tool_agent/_tool_agent.py
+++ b/python/src/agnext/components/tool_agent/_tool_agent.py
@@ -55,6 +55,10 @@ class ToolAgent(TypeRoutedAgent):
super().__init__(description)
self._tools = tools
+ @property
+ def tools(self) -> List[Tool]:
+ return self._tools
+
@message_handler
async def handle_function_call(
self, message: FunctionCall, cancellation_token: CancellationToken