Add message context to message handler (#367)

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-16 23:14:09 -04:00
committed by GitHub
parent bc26ec3de4
commit 853b00b0f0
49 changed files with 267 additions and 235 deletions

View File

@@ -30,10 +30,11 @@ from agnext.components.models import (
)
from agnext.components.tool_agent import ToolAgent, ToolException
from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema
from agnext.core import AgentId, CancellationToken
from agnext.core import AgentId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import MessageContext
from common.utils import get_chat_completion_client_from_envs
@@ -61,7 +62,7 @@ class ToolUseAgent(TypeRoutedAgent):
self._tool_agent = tool_agent
@message_handler
async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:
"""Handle a user message, execute the model and tools, and returns the response."""
session: List[LLMMessage] = []
session.append(UserMessage(content=message.content, source="User"))
@@ -72,7 +73,7 @@ class ToolUseAgent(TypeRoutedAgent):
while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(call, self._tool_agent, cancellation_token=cancellation_token)
self.send_message(call, self._tool_agent, cancellation_token=ctx.cancellation_token)
for call in response.content
],
return_exceptions=True,

View File

@@ -32,10 +32,10 @@ from agnext.components.models import (
UserMessage,
)
from agnext.components.tools import PythonCodeExecutionTool, Tool
from agnext.core import CancellationToken
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import MessageContext
from common.utils import get_chat_completion_client_from_envs
@@ -69,7 +69,7 @@ class ToolExecutorAgent(TypeRoutedAgent):
self._tools = tools
@message_handler
async def handle_tool_call(self, message: ToolExecutionTask, cancellation_token: CancellationToken) -> None:
async def handle_tool_call(self, message: ToolExecutionTask, ctx: MessageContext) -> None:
"""Handle a tool execution task. This method executes the tool and publishes the result."""
# Find the tool
tool = next((tool for tool in self._tools if tool.name == message.function_call.name), None)
@@ -78,7 +78,7 @@ class ToolExecutorAgent(TypeRoutedAgent):
else:
try:
arguments = json.loads(message.function_call.arguments)
result = await tool.run_json(args=arguments, cancellation_token=cancellation_token)
result = await tool.run_json(args=arguments, cancellation_token=ctx.cancellation_token)
result_as_str = tool.return_value_as_string(result)
except json.JSONDecodeError:
result_as_str = f"Error: Invalid arguments: {message.function_call.arguments}"
@@ -112,7 +112,7 @@ class ToolUseAgent(TypeRoutedAgent):
self._tool_counter: Dict[str, int] = {}
@message_handler
async def handle_user_message(self, message: UserRequest, cancellation_token: CancellationToken) -> None:
async def handle_user_message(self, message: UserRequest, ctx: MessageContext) -> None:
"""Handle a user message. This method calls the model. If the model response is a string,
it publishes the response. If the model response is a list of function calls, it publishes
the function calls to the tool executor agent."""
@@ -142,7 +142,7 @@ class ToolUseAgent(TypeRoutedAgent):
await self.publish_message(task)
@message_handler
async def handle_tool_result(self, message: ToolExecutionTaskResult, cancellation_token: CancellationToken) -> None:
async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None:
"""Handle a tool execution result. This method aggregates the tool results and
calls the model again to get another response. If the response is a string, it
publishes the response. If the response is a list of function calls, it publishes