mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Tool call result summary message (#4755)
* Adding ToolCallResultSummaryMessage * Support for ToolCallResultSummaryMessage * Added ToolCallSummaryMessage * ruff format * Add ToolCallSummaryMessage to ChatMessage * typing and tests for ToolCallSummaryMessage * PR Feedback --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from ..messages import (
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ..state import AssistantAgentState
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
@@ -62,7 +63,7 @@ class AssistantAgent(BaseChatAgent):
|
||||
|
||||
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
|
||||
* When the model returns tool calls, they will be executed right away:
|
||||
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
|
||||
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
|
||||
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
|
||||
|
||||
Hand off behavior:
|
||||
@@ -280,9 +281,12 @@ class AssistantAgent(BaseChatAgent):
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
message_types: List[type[ChatMessage]] = [TextMessage]
|
||||
if self._handoffs:
|
||||
return [TextMessage, HandoffMessage]
|
||||
return [TextMessage]
|
||||
message_types.append(HandoffMessage)
|
||||
if self._tools:
|
||||
message_types.append(ToolCallSummaryMessage)
|
||||
return message_types
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
@@ -379,7 +383,7 @@ class AssistantAgent(BaseChatAgent):
|
||||
)
|
||||
tool_call_summary = "\n".join(tool_call_summaries)
|
||||
yield Response(
|
||||
chat_message=TextMessage(content=tool_call_summary, source=self.name),
|
||||
chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
|
||||
inner_messages=inner_messages,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage):
|
||||
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
|
||||
|
||||
|
||||
ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
|
||||
class ToolCallSummaryMessage(BaseMessage):
|
||||
"""A message signaling the summary of tool call results."""
|
||||
|
||||
content: str
|
||||
"""Summary of the the tool call results."""
|
||||
|
||||
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
|
||||
|
||||
|
||||
ChatMessage = Annotated[
|
||||
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
|
||||
]
|
||||
"""Messages for agent-to-agent communication only."""
|
||||
|
||||
|
||||
@@ -110,7 +121,13 @@ AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(disc
|
||||
|
||||
|
||||
AgentMessage = Annotated[
|
||||
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
|
||||
TextMessage
|
||||
| MultiModalMessage
|
||||
| StopMessage
|
||||
| HandoffMessage
|
||||
| ToolCallRequestEvent
|
||||
| ToolCallExecutionEvent
|
||||
| ToolCallSummaryMessage,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
"""(Deprecated, will be removed in 0.4.0) All message and event types."""
|
||||
@@ -126,6 +143,7 @@ __all__ = [
|
||||
"ToolCallExecutionEvent",
|
||||
"ToolCallMessage",
|
||||
"ToolCallResultMessage",
|
||||
"ToolCallSummaryMessage",
|
||||
"ChatMessage",
|
||||
"AgentEvent",
|
||||
"AgentMessage",
|
||||
|
||||
@@ -21,6 +21,7 @@ from ....messages import (
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ....state import MagenticOneOrchestratorState
|
||||
from .._base_group_chat_manager import BaseGroupChatManager
|
||||
@@ -433,7 +434,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
elif isinstance(m, StopMessage | HandoffMessage):
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
elif m.source == self._name:
|
||||
assert isinstance(m, TextMessage)
|
||||
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage)
|
||||
|
||||
@@ -15,6 +15,7 @@ from ...messages import (
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ...state import SelectorManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
@@ -100,7 +101,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
continue
|
||||
# The agent type must be the same as the topic type, which we use as the agent name.
|
||||
message = f"{msg.source}:"
|
||||
if isinstance(msg, TextMessage | StopMessage | HandoffMessage):
|
||||
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
|
||||
message += f" {msg.content}"
|
||||
elif isinstance(msg, MultiModalMessage):
|
||||
for item in msg.content:
|
||||
|
||||
@@ -14,6 +14,7 @@ from autogen_agentchat.messages import (
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from autogen_core import Image
|
||||
from autogen_core.tools import FunctionTool
|
||||
@@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[1].models_usage.prompt_tokens == 10
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
assert result.messages[2].models_usage is None
|
||||
assert isinstance(result.messages[3], TextMessage)
|
||||
assert isinstance(result.messages[3], ToolCallSummaryMessage)
|
||||
assert result.messages[3].content == "pass"
|
||||
assert result.messages[3].models_usage is None
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from autogen_agentchat.messages import (
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from autogen_agentchat.teams import (
|
||||
RoundRobinGroupChat,
|
||||
@@ -325,7 +326,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||
assert isinstance(result.messages[0], TextMessage) # task
|
||||
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
|
||||
assert isinstance(result.messages[3], TextMessage) # tool use agent response
|
||||
assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
|
||||
assert result.messages[3].content == "pass" # ensure the tool call was executed
|
||||
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
||||
assert isinstance(result.messages[5], TextMessage) # tool use agent response
|
||||
assert isinstance(result.messages[6], TextMessage) # echo agent response
|
||||
|
||||
Reference in New Issue
Block a user