mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Use class hierarchy to organize AgentChat message types and introduce StructuredMessage type (#5998)
This PR refactored `AgentEvent` and `ChatMessage` union types to abstract base classes. This allows for user-defined message types that subclass one of the base classes to be used in AgentChat. To support a unified interface for working with the messages, the base classes added abstract methods for: - Convert content to string - Convert content to a `UserMessage` for model client - Convert content for rendering in console. - Dump into a dictionary - Load and create a new instance from a dictionary This way, all agents such as `AssistantAgent` and `SocietyOfMindAgent` can utilize the unified interface to work with any built-in and user-defined message type. This PR also introduces a new message type, `StructuredMessage` for AgentChat (Resolves #5131), which is a generic type that requires a user-specified content type. You can create a `StructuredMessage` as follow: ```python class MessageType(BaseModel): data: str references: List[str] message = StructuredMessage[MessageType](content=MessageType(data="data", references=["a", "b"]), source="user") # message.content is of type `MessageType`. ``` This PR addresses the receving side of this message type. To produce this message type from `AssistantAgent`, the work continue in #5934. Added unit tests to verify this message type works with agents and teams.
This commit is contained in:
@@ -12,6 +12,7 @@ from autogen_agentchat.messages import (
|
||||
MemoryQueryEvent,
|
||||
ModelClientStreamingChunkEvent,
|
||||
MultiModalMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
ThoughtEvent,
|
||||
ToolCallExecutionEvent,
|
||||
@@ -624,6 +625,23 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert len(result.messages) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_structured_task() -> None:
|
||||
class InputTask(BaseModel):
|
||||
input: str
|
||||
data: List[str]
|
||||
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
)
|
||||
|
||||
task = StructuredMessage[InputTask](content=InputTask(input="Test", data=["Test1", "Test2"]), source="user")
|
||||
result = await agent.run(task=task)
|
||||
assert len(result.messages) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_model_capabilities() -> None:
|
||||
model = "random-model"
|
||||
@@ -896,6 +914,7 @@ async def test_model_client_stream() -> None:
|
||||
chunks: List[str] = []
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert isinstance(message.messages[-1], TextMessage)
|
||||
assert message.messages[-1].content == "Response to message 3"
|
||||
elif isinstance(message, ModelClientStreamingChunkEvent):
|
||||
chunks.append(message.content)
|
||||
@@ -929,11 +948,14 @@ async def test_model_client_stream_with_tool_calls() -> None:
|
||||
chunks: List[str] = []
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert isinstance(message.messages[-1], TextMessage)
|
||||
assert isinstance(message.messages[1], ToolCallRequestEvent)
|
||||
assert message.messages[-1].content == "Example response 2 to task"
|
||||
assert message.messages[1].content == [
|
||||
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
|
||||
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
|
||||
]
|
||||
assert isinstance(message.messages[2], ToolCallExecutionEvent)
|
||||
assert message.messages[2].content == [
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),
|
||||
|
||||
@@ -20,6 +20,7 @@ from autogen_agentchat.messages import (
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
StructuredMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
@@ -44,6 +45,7 @@ from autogen_core.tools import FunctionTool
|
||||
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from utils import FileLogHandler
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
@@ -101,6 +103,34 @@ class _FlakyAgent(BaseChatAgent):
|
||||
self._last_message = None
|
||||
|
||||
|
||||
class _UnknownMessageType(ChatMessage):
|
||||
content: str
|
||||
|
||||
def to_model_message(self) -> UserMessage:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
def to_model_text(self) -> str:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
def to_text(self) -> str:
|
||||
raise NotImplementedError("This message type is not supported.")
|
||||
|
||||
|
||||
class _UnknownMessageTypeAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
return (_UnknownMessageType,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(chat_message=_UnknownMessageType(content="Unknown message type", source=self.name))
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _StopAgent(_EchoAgent):
|
||||
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
|
||||
super().__init__(name, description)
|
||||
@@ -122,6 +152,19 @@ def _pass_function(input: str) -> str:
|
||||
return "pass"
|
||||
|
||||
|
||||
class _InputTask1(BaseModel):
|
||||
task: str
|
||||
data: List[str]
|
||||
|
||||
|
||||
class _InputTask2(BaseModel):
|
||||
task: str
|
||||
data: str
|
||||
|
||||
|
||||
TaskType = str | List[ChatMessage] | ChatMessage
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
||||
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
||||
if request.param == "single_threaded":
|
||||
@@ -164,14 +207,11 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
"Hello, world!",
|
||||
"TERMINATE",
|
||||
]
|
||||
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
|
||||
normalized_messages = [
|
||||
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
|
||||
for msg in result.messages
|
||||
]
|
||||
|
||||
# Assert that all expected messages are in the collected messages
|
||||
assert normalized_messages == expected_messages
|
||||
for i in range(len(expected_messages)):
|
||||
produced_message = result.messages[i]
|
||||
assert isinstance(produced_message, TextMessage)
|
||||
content = produced_message.content.replace("\r\n", "\n").rstrip("\n")
|
||||
assert content == expected_messages[i]
|
||||
|
||||
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
||||
|
||||
@@ -202,28 +242,89 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
model_client.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
)
|
||||
assert result.messages[0].content == result_2.messages[0].content[0]
|
||||
task = MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
result_2 = await team.run(task=task)
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result_2.messages[0], MultiModalMessage)
|
||||
assert result.messages[0].content == task.content[0]
|
||||
assert result.messages[1:] == result_2.messages[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_state(runtime: AgentRuntime | None) -> None:
|
||||
async def test_round_robin_group_chat_unknown_task_message_type(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient([])
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(
|
||||
participants=[agent1, agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask2]],
|
||||
)
|
||||
with pytest.raises(ValueError, match=r"Message type .*StructuredMessage\[_InputTask1\].* is not registered"):
|
||||
await team1.run(
|
||||
task=StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_unknown_agent_message_type() -> None:
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent")
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
|
||||
with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"):
|
||||
await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_round_robin_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
|
||||
)
|
||||
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||
termination = TextMentionTermination("TERMINATE")
|
||||
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||
team1 = RoundRobinGroupChat(
|
||||
participants=[agent1, agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination, runtime=runtime)
|
||||
team2 = RoundRobinGroupChat(
|
||||
participants=[agent3, agent4],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
assert state == state2
|
||||
@@ -453,6 +554,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent3"
|
||||
assert result.messages[2].source == "agent2"
|
||||
@@ -485,7 +587,25 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_selector_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
|
||||
)
|
||||
@@ -497,14 +617,18 @@ async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
|
||||
termination_condition=termination,
|
||||
model_client=model_client,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||
team2 = SelectorGroupChat(
|
||||
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
|
||||
participants=[agent3, agent4],
|
||||
termination_condition=termination,
|
||||
model_client=model_client,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
@@ -545,6 +669,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
assert len(result.messages) == 5
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent1"
|
||||
@@ -594,6 +719,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent2"
|
||||
@@ -635,6 +761,7 @@ async def test_selector_group_chat_succcess_after_2_attempts(runtime: AgentRunti
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
|
||||
@@ -659,6 +786,7 @@ async def test_selector_group_chat_fall_back_to_first_after_3_attempts(runtime:
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent1"
|
||||
|
||||
@@ -679,6 +807,7 @@ async def test_selector_group_chat_fall_back_to_previous_after_3_attempts(runtim
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent2"
|
||||
@@ -796,6 +925,12 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
||||
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert isinstance(result.messages[4], HandoffMessage)
|
||||
assert isinstance(result.messages[5], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
assert result.messages[2].content == "Transferred to first_agent."
|
||||
@@ -839,6 +974,65 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
||||
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"task",
|
||||
[
|
||||
"Write a program that prints 'Hello, world!'",
|
||||
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
|
||||
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
|
||||
[
|
||||
StructuredMessage[_InputTask1](
|
||||
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
|
||||
source="user",
|
||||
),
|
||||
StructuredMessage[_InputTask2](
|
||||
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
|
||||
),
|
||||
],
|
||||
],
|
||||
ids=["text", "text_message", "multi_modal_message", "structured_message"],
|
||||
)
|
||||
async def test_swarm_handoff_state(task: TaskType, runtime: AgentRuntime | None) -> None:
|
||||
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
|
||||
termination = MaxMessageTermination(6)
|
||||
team1 = Swarm(
|
||||
[second_agent, first_agent, third_agent],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team1.run(task=task)
|
||||
state = await team1.save_state()
|
||||
|
||||
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
team2 = Swarm(
|
||||
[second_agent2, first_agent2, third_agent2],
|
||||
termination_condition=termination,
|
||||
runtime=runtime,
|
||||
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
assert state == state2
|
||||
|
||||
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
|
||||
SwarmGroupChatManager, # pyright: ignore
|
||||
)
|
||||
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
||||
SwarmGroupChatManager, # pyright: ignore
|
||||
)
|
||||
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
@@ -870,9 +1064,14 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
|
||||
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 7
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert isinstance(result.messages[4], HandoffMessage)
|
||||
assert isinstance(result.messages[5], TextMessage)
|
||||
assert isinstance(result.messages[6], TextMessage)
|
||||
assert result.messages[3].content == "handoff to agent2"
|
||||
assert result.messages[4].content == "Transferred to agent1."
|
||||
assert result.messages[5].content == "Hello"
|
||||
@@ -910,18 +1109,23 @@ async def test_swarm_pause_and_resume(runtime: AgentRuntime | None) -> None:
|
||||
team = Swarm([second_agent, first_agent, third_agent], max_turns=1, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
|
||||
# Resume with a new task.
|
||||
result = await team.run(task="new task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "new task"
|
||||
assert result.messages[1].content == "Transferred to first_agent."
|
||||
|
||||
# Resume with the same task.
|
||||
result = await team.run()
|
||||
assert len(result.messages) == 1
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert result.messages[0].content == "Transferred to second_agent."
|
||||
|
||||
|
||||
@@ -996,8 +1200,10 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
|
||||
source="agent1",
|
||||
context=expected_handoff_context,
|
||||
)
|
||||
assert isinstance(result.messages[4], TextMessage)
|
||||
assert result.messages[4].content == "Hello"
|
||||
assert result.messages[4].source == "agent2"
|
||||
assert isinstance(result.messages[5], TextMessage)
|
||||
assert result.messages[5].content == "TERMINATE"
|
||||
assert result.messages[5].source == "agent2"
|
||||
|
||||
@@ -1020,17 +1226,26 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
||||
# Start
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 2
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
# Resume existing.
|
||||
result = await team.run()
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert result.messages[0].content == "Transferred to first_agent."
|
||||
assert result.messages[1].content == "Transferred to second_agent."
|
||||
assert result.messages[2].content == "Transferred to third_agent."
|
||||
# Resume new task.
|
||||
result = await team.run(task="new task")
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert result.messages[0].content == "new task"
|
||||
assert result.messages[1].content == "Transferred to first_agent."
|
||||
assert result.messages[2].content == "Transferred to second_agent."
|
||||
@@ -1043,6 +1258,9 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
||||
# Start
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 3
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
assert result.messages[2].content == "Transferred to non_existing_agent."
|
||||
@@ -1055,6 +1273,10 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
|
||||
# Resume with a HandoffMessage
|
||||
result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], HandoffMessage)
|
||||
assert isinstance(result.messages[1], HandoffMessage)
|
||||
assert isinstance(result.messages[2], HandoffMessage)
|
||||
assert isinstance(result.messages[3], HandoffMessage)
|
||||
assert result.messages[0].content == "Handoff to first_agent."
|
||||
assert result.messages[1].content == "Transferred to second_agent."
|
||||
assert result.messages[2].content == "Transferred to third_agent."
|
||||
@@ -1081,6 +1303,10 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
|
||||
|
||||
# Verify the messages were processed in order
|
||||
assert len(result.messages) == 4 # Initial messages + echo until termination
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert isinstance(result.messages[2], TextMessage)
|
||||
assert isinstance(result.messages[3], TextMessage)
|
||||
assert result.messages[0].content == "Message 1" # First message
|
||||
assert result.messages[1].content == "Message 2" # Second message
|
||||
assert result.messages[2].content == "Message 3" # Third message
|
||||
|
||||
@@ -4,10 +4,7 @@ from typing import List, Sequence
|
||||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
AgentEvent,
|
||||
ChatMessage,
|
||||
)
|
||||
from autogen_agentchat.messages import AgentEvent, ChatMessage
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
|
||||
@@ -134,8 +134,8 @@ async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> No
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 5
|
||||
assert result.messages[2].content == "Continue task"
|
||||
assert result.messages[4].content == "print('Hello, world!')"
|
||||
assert result.messages[2].to_text() == "Continue task"
|
||||
assert result.messages[4].to_text() == "print('Hello, world!')"
|
||||
assert result.stop_reason is not None and result.stop_reason == "Because"
|
||||
|
||||
# Test save and load.
|
||||
@@ -214,8 +214,8 @@ async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None)
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 6
|
||||
assert isinstance(result.messages[1].content, str)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert result.messages[1].content.startswith("\nWe are working to address the following user request:")
|
||||
assert isinstance(result.messages[4].content, str)
|
||||
assert isinstance(result.messages[4], TextMessage)
|
||||
assert result.messages[4].content.startswith("\nWe are working to address the following user request:")
|
||||
assert result.stop_reason is not None and result.stop_reason == "test"
|
||||
|
||||
93
python/packages/autogen-agentchat/tests/test_messages.py
Normal file
93
python/packages/autogen-agentchat/tests/test_messages.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
from autogen_agentchat.messages import HandoffMessage, MessageFactory, StructuredMessage, TextMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestContent(BaseModel):
|
||||
"""Test content model."""
|
||||
|
||||
field1: str
|
||||
field2: int
|
||||
|
||||
|
||||
def test_structured_message() -> None:
|
||||
# Create a structured message with the test content
|
||||
message = StructuredMessage[TestContent](
|
||||
source="test_agent",
|
||||
content=TestContent(field1="test", field2=42),
|
||||
)
|
||||
|
||||
# Check that the message type is correct
|
||||
assert message.type == "StructuredMessage[TestContent]" # type: ignore
|
||||
|
||||
# Check that the content is of the correct type
|
||||
assert isinstance(message.content, TestContent)
|
||||
|
||||
# Check that the content fields are set correctly
|
||||
assert message.content.field1 == "test"
|
||||
assert message.content.field2 == 42
|
||||
|
||||
# Check that model_dump works correctly
|
||||
dumped_message = message.model_dump()
|
||||
assert dumped_message["source"] == "test_agent"
|
||||
assert dumped_message["content"]["field1"] == "test"
|
||||
assert dumped_message["content"]["field2"] == 42
|
||||
assert dumped_message["type"] == "StructuredMessage[TestContent]"
|
||||
|
||||
|
||||
def test_message_factory() -> None:
|
||||
factory = MessageFactory()
|
||||
|
||||
# Text message data
|
||||
text_data = {
|
||||
"type": "TextMessage",
|
||||
"source": "test_agent",
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
# Create a TextMessage instance
|
||||
text_message = factory.create(text_data)
|
||||
assert isinstance(text_message, TextMessage)
|
||||
assert text_message.source == "test_agent"
|
||||
assert text_message.content == "Hello, world!"
|
||||
assert text_message.type == "TextMessage" # type: ignore
|
||||
|
||||
# Handoff message data
|
||||
handoff_data = {
|
||||
"type": "HandoffMessage",
|
||||
"source": "test_agent",
|
||||
"content": "handoff to another agent",
|
||||
"target": "target_agent",
|
||||
}
|
||||
|
||||
# Create a HandoffMessage instance
|
||||
handoff_message = factory.create(handoff_data)
|
||||
assert isinstance(handoff_message, HandoffMessage)
|
||||
assert handoff_message.source == "test_agent"
|
||||
assert handoff_message.content == "handoff to another agent"
|
||||
assert handoff_message.target == "target_agent"
|
||||
assert handoff_message.type == "HandoffMessage" # type: ignore
|
||||
|
||||
# Structured message data
|
||||
structured_data = {
|
||||
"type": "StructuredMessage[TestContent]",
|
||||
"source": "test_agent",
|
||||
"content": {
|
||||
"field1": "test",
|
||||
"field2": 42,
|
||||
},
|
||||
}
|
||||
# Create a StructuredMessage instance -- this will fail because the type
|
||||
# is not registered in the factory.
|
||||
with pytest.raises(ValueError):
|
||||
structured_message = factory.create(structured_data)
|
||||
# Register the StructuredMessage type in the factory
|
||||
factory.register(StructuredMessage[TestContent])
|
||||
# Create a StructuredMessage instance
|
||||
structured_message = factory.create(structured_data)
|
||||
assert isinstance(structured_message, StructuredMessage)
|
||||
assert isinstance(structured_message.content, TestContent) # type: ignore
|
||||
assert structured_message.source == "test_agent"
|
||||
assert structured_message.content.field1 == "test"
|
||||
assert structured_message.content.field2 == 42
|
||||
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore
|
||||
Reference in New Issue
Block a user