Tool call result summary message (#4755)

* Adding ToolCallResultSummaryMessage

* Support for ToolCallResultSummaryMessage

* Added ToolCallSummaryMessage

* ruff format

* Add ToolCallSummaryMessage to ChatMessage

* typing and tests for ToolCallSummaryMessage

* PR Feedback

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
jspv
2024-12-20 00:23:18 -05:00
committed by GitHub
parent 3dd4be949e
commit a271708a97
6 changed files with 37 additions and 10 deletions

View File

@@ -28,6 +28,7 @@ from ..messages import (
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
@@ -62,7 +63,7 @@ class AssistantAgent(BaseChatAgent):
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
Hand off behavior:
@@ -280,9 +281,12 @@ class AssistantAgent(BaseChatAgent):
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
return [TextMessage, HandoffMessage]
return [TextMessage]
message_types.append(HandoffMessage)
if self._tools:
message_types.append(ToolCallSummaryMessage)
return message_types
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
@@ -379,7 +383,7 @@ class AssistantAgent(BaseChatAgent):
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name),
chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

View File

@@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
class ToolCallSummaryMessage(BaseMessage):
"""A message signaling the summary of tool call results."""
content: str
"""Summary of the the tool call results."""
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""
@@ -110,7 +121,13 @@ AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(disc
AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
TextMessage
| MultiModalMessage
| StopMessage
| HandoffMessage
| ToolCallRequestEvent
| ToolCallExecutionEvent
| ToolCallSummaryMessage,
Field(discriminator="type"),
]
"""(Deprecated, will be removed in 0.4.0) All message and event types."""
@@ -126,6 +143,7 @@ __all__ = [
"ToolCallExecutionEvent",
"ToolCallMessage",
"ToolCallResultMessage",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"AgentMessage",

View File

@@ -21,6 +21,7 @@ from ....messages import (
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
@@ -433,7 +434,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
elif isinstance(m, StopMessage | HandoffMessage):
context.append(UserMessage(content=m.content, source=m.source))
elif m.source == self._name:
assert isinstance(m, TextMessage)
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
context.append(AssistantMessage(content=m.content, source=m.source))
else:
assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage)

View File

@@ -15,6 +15,7 @@ from ...messages import (
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
@@ -100,7 +101,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
if isinstance(msg, TextMessage | StopMessage | HandoffMessage):
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
message += f" {msg.content}"
elif isinstance(msg, MultiModalMessage):
for item in msg.content:

View File

@@ -14,6 +14,7 @@ from autogen_agentchat.messages import (
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.tools import FunctionTool
@@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert result.messages[3].models_usage is None

View File

@@ -22,6 +22,7 @@ from autogen_agentchat.messages import (
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_agentchat.teams import (
RoundRobinGroupChat,
@@ -325,7 +326,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
assert result.messages[3].content == "pass" # ensure the tool call was executed
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert isinstance(result.messages[6], TextMessage) # echo agent response