Add message context to message handler (#367)

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-16 23:14:09 -04:00
committed by GitHub
parent bc26ec3de4
commit 853b00b0f0
49 changed files with 267 additions and 235 deletions

View File

@@ -15,7 +15,7 @@ from agnext.components.models import (
SystemMessage,
)
from agnext.components.tools import Tool
from agnext.core import AgentId, CancellationToken
from agnext.core import AgentId, CancellationToken, MessageContext
from ..types import (
FunctionCallMessage,
@@ -74,50 +74,48 @@ class ChatCompletionAgent(TypeRoutedAgent):
self._tool_approver = tool_approver
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:
"""Handle a text message. This method adds the message to the memory and
does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None:
async def on_multi_modal_message(self, message: MultiModalMessage, ctx: MessageContext) -> None:
"""Handle a multimodal message. This method adds the message to the memory
and does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
async def on_reset(self, message: Reset, ctx: MessageContext) -> None:
"""Handle a reset message. This method clears the memory."""
# Reset the chat messages.
await self._memory.clear()
@message_handler()
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
async def on_respond_now(self, message: RespondNow, ctx: MessageContext) -> TextMessage | FunctionCallMessage:
"""Handle a respond now message. This method generates a response and
returns it to the sender."""
# Generate a response.
response = await self._generate_response(message.response_format, cancellation_token)
response = await self._generate_response(message.response_format, ctx)
# Return the response.
return response
@message_handler()
async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method generates a response and
publishes it."""
# Generate a response.
response = await self._generate_response(message.response_format, cancellation_token)
response = await self._generate_response(message.response_format, ctx)
# Publish the response.
await self.publish_message(response)
@message_handler()
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
self, message: FunctionCallMessage, ctx: MessageContext
) -> FunctionExecutionResultMessage:
"""Handle a tool call message. This method executes the tools and
returns the results."""
@@ -147,7 +145,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
function_call.name,
arguments,
function_call.id,
cancellation_token=cancellation_token,
cancellation_token=ctx.cancellation_token,
)
# Append the async result.
execution_futures.append(future)
@@ -170,7 +168,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
async def _generate_response(
self,
response_format: ResponseFormat,
cancellation_token: CancellationToken,
ctx: MessageContext,
) -> TextMessage | FunctionCallMessage:
# Get a response from the model.
hisorical_messages = await self._memory.get_messages()
@@ -192,7 +190,7 @@ class ChatCompletionAgent(TypeRoutedAgent):
response = await self.send_message(
message=FunctionCallMessage(content=response.content, source=self.metadata["type"]),
recipient=self.id,
cancellation_token=cancellation_token,
cancellation_token=ctx.cancellation_token,
)
# Make an assistant message from the response.
hisorical_messages = await self._memory.get_messages()