Use class hierarchy to organize AgentChat message types and introduce StructuredMessage type (#5998)

This PR refactored `AgentEvent` and `ChatMessage` union types to
abstract base classes. This allows for user-defined message types that
subclass one of the base classes to be used in AgentChat.

To support a unified interface for working with the messages, the base
classes added abstract methods for:
- Convert content to string
- Convert content to a `UserMessage` for model client
- Convert content for rendering in console.
- Dump into a dictionary
- Load and create a new instance from a dictionary

This way, all agents such as `AssistantAgent` and `SocietyOfMindAgent`
can utilize the unified interface to work with any built-in and
user-defined message type.

This PR also introduces a new message type, `StructuredMessage` for
AgentChat (Resolves #5131), which is a generic type that requires a
user-specified content type.

You can create a `StructuredMessage` as follow:

```python

class MessageType(BaseModel):
  data: str
  references: List[str]

message = StructuredMessage[MessageType](content=MessageType(data="data", references=["a", "b"]), source="user")

# message.content is of type `MessageType`. 
```

This PR addresses the receving side of this message type. To produce
this message type from `AssistantAgent`, the work continue in #5934.

Added unit tests to verify this message type works with agents and
teams.
This commit is contained in:
Eric Zhu
2025-03-26 16:19:52 -07:00
committed by GitHub
parent 8a5ee3de6a
commit 025490a1bd
42 changed files with 4241 additions and 3627 deletions

View File

@@ -12,6 +12,7 @@ from autogen_agentchat.messages import (
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
StructuredMessage,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
@@ -624,6 +625,23 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(result.messages) == 2
@pytest.mark.asyncio
async def test_run_with_structured_task() -> None:
class InputTask(BaseModel):
input: str
data: List[str]
model_client = ReplayChatCompletionClient(["Hello"])
agent = AssistantAgent(
name="assistant",
model_client=model_client,
)
task = StructuredMessage[InputTask](content=InputTask(input="Test", data=["Test1", "Test2"]), source="user")
result = await agent.run(task=task)
assert len(result.messages) == 2
@pytest.mark.asyncio
async def test_invalid_model_capabilities() -> None:
model = "random-model"
@@ -896,6 +914,7 @@ async def test_model_client_stream() -> None:
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert isinstance(message.messages[-1], TextMessage)
assert message.messages[-1].content == "Response to message 3"
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
@@ -929,11 +948,14 @@ async def test_model_client_stream_with_tool_calls() -> None:
chunks: List[str] = []
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert isinstance(message.messages[-1], TextMessage)
assert isinstance(message.messages[1], ToolCallRequestEvent)
assert message.messages[-1].content == "Example response 2 to task"
assert message.messages[1].content == [
FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'),
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
]
assert isinstance(message.messages[2], ToolCallExecutionEvent)
assert message.messages[2].content == [
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),

View File

@@ -20,6 +20,7 @@ from autogen_agentchat.messages import (
HandoffMessage,
MultiModalMessage,
StopMessage,
StructuredMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
@@ -44,6 +45,7 @@ from autogen_core.tools import FunctionTool
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import FileLogHandler
logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -101,6 +103,34 @@ class _FlakyAgent(BaseChatAgent):
self._last_message = None
class _UnknownMessageType(ChatMessage):
content: str
def to_model_message(self) -> UserMessage:
raise NotImplementedError("This message type is not supported.")
def to_model_text(self) -> str:
raise NotImplementedError("This message type is not supported.")
def to_text(self) -> str:
raise NotImplementedError("This message type is not supported.")
class _UnknownMessageTypeAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
return (_UnknownMessageType,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(chat_message=_UnknownMessageType(content="Unknown message type", source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
class _StopAgent(_EchoAgent):
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
super().__init__(name, description)
@@ -122,6 +152,19 @@ def _pass_function(input: str) -> str:
return "pass"
class _InputTask1(BaseModel):
task: str
data: List[str]
class _InputTask2(BaseModel):
task: str
data: str
TaskType = str | List[ChatMessage] | ChatMessage
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
if request.param == "single_threaded":
@@ -164,14 +207,11 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
"Hello, world!",
"TERMINATE",
]
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
normalized_messages = [
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
for msg in result.messages
]
# Assert that all expected messages are in the collected messages
assert normalized_messages == expected_messages
for i in range(len(expected_messages)):
produced_message = result.messages[i]
assert isinstance(produced_message, TextMessage)
content = produced_message.content.replace("\r\n", "\n").rstrip("\n")
assert content == expected_messages[i]
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
@@ -202,28 +242,89 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
model_client.reset()
index = 0
await team.reset()
result_2 = await team.run(
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
)
assert result.messages[0].content == result_2.messages[0].content[0]
task = MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
result_2 = await team.run(task=task)
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result_2.messages[0], MultiModalMessage)
assert result.messages[0].content == task.content[0]
assert result.messages[1:] == result_2.messages[1:]
@pytest.mark.asyncio
async def test_round_robin_group_chat_state(runtime: AgentRuntime | None) -> None:
async def test_round_robin_group_chat_unknown_task_message_type(runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient([])
agent1 = AssistantAgent("agent1", model_client=model_client)
agent2 = AssistantAgent("agent2", model_client=model_client)
termination = TextMentionTermination("TERMINATE")
team1 = RoundRobinGroupChat(
participants=[agent1, agent2],
termination_condition=termination,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask2]],
)
with pytest.raises(ValueError, match=r"Message type .*StructuredMessage\[_InputTask1\].* is not registered"):
await team1.run(
task=StructuredMessage[_InputTask1](
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
source="user",
)
)
@pytest.mark.asyncio
async def test_round_robin_group_chat_unknown_agent_message_type() -> None:
model_client = ReplayChatCompletionClient(["Hello"])
agent1 = AssistantAgent("agent1", model_client=model_client)
agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent")
termination = TextMentionTermination("TERMINATE")
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"):
await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user"))
@pytest.mark.asyncio
@pytest.mark.parametrize(
"task",
[
"Write a program that prints 'Hello, world!'",
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
[
StructuredMessage[_InputTask1](
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
source="user",
),
StructuredMessage[_InputTask2](
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
),
],
],
ids=["text", "text_message", "multi_modal_message", "structured_message"],
)
async def test_round_robin_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
)
agent1 = AssistantAgent("agent1", model_client=model_client)
agent2 = AssistantAgent("agent2", model_client=model_client)
termination = TextMentionTermination("TERMINATE")
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, runtime=runtime)
await team1.run(task="Write a program that prints 'Hello, world!'")
team1 = RoundRobinGroupChat(
participants=[agent1, agent2],
termination_condition=termination,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team1.run(task=task)
state = await team1.save_state()
agent3 = AssistantAgent("agent1", model_client=model_client)
agent4 = AssistantAgent("agent2", model_client=model_client)
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination, runtime=runtime)
team2 = RoundRobinGroupChat(
participants=[agent3, agent4],
termination_condition=termination,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
@@ -453,6 +554,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
task="Write a program that prints 'Hello, world!'",
)
assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent3"
assert result.messages[2].source == "agent2"
@@ -485,7 +587,25 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
@pytest.mark.asyncio
async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
@pytest.mark.parametrize(
"task",
[
"Write a program that prints 'Hello, world!'",
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
[
StructuredMessage[_InputTask1](
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
source="user",
),
StructuredMessage[_InputTask2](
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
),
],
],
ids=["text", "text_message", "multi_modal_message", "structured_message"],
)
async def test_selector_group_chat_state(task: TaskType, runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
)
@@ -497,14 +617,18 @@ async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
termination_condition=termination,
model_client=model_client,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team1.run(task="Write a program that prints 'Hello, world!'")
await team1.run(task=task)
state = await team1.save_state()
agent3 = AssistantAgent("agent1", model_client=model_client)
agent4 = AssistantAgent("agent2", model_client=model_client)
team2 = SelectorGroupChat(
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
participants=[agent3, agent4],
termination_condition=termination,
model_client=model_client,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team2.load_state(state)
state2 = await team2.save_state()
@@ -545,6 +669,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
task="Write a program that prints 'Hello, world!'",
)
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent2"
assert result.messages[2].source == "agent1"
@@ -594,6 +719,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent2"
assert result.messages[2].source == "agent2"
@@ -635,6 +761,7 @@ async def test_selector_group_chat_succcess_after_2_attempts(runtime: AgentRunti
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 2
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent2"
@@ -659,6 +786,7 @@ async def test_selector_group_chat_fall_back_to_first_after_3_attempts(runtime:
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 2
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent1"
@@ -679,6 +807,7 @@ async def test_selector_group_chat_fall_back_to_previous_after_3_attempts(runtim
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 3
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].source == "agent2"
assert result.messages[2].source == "agent2"
@@ -796,6 +925,12 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
result = await team.run(task="task")
assert len(result.messages) == 6
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert isinstance(result.messages[2], HandoffMessage)
assert isinstance(result.messages[3], HandoffMessage)
assert isinstance(result.messages[4], HandoffMessage)
assert isinstance(result.messages[5], HandoffMessage)
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
assert result.messages[2].content == "Transferred to first_agent."
@@ -839,6 +974,65 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
@pytest.mark.asyncio
@pytest.mark.parametrize(
"task",
[
"Write a program that prints 'Hello, world!'",
[TextMessage(content="Write a program that prints 'Hello, world!'", source="user")],
[MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")],
[
StructuredMessage[_InputTask1](
content=_InputTask1(task="Write a program that prints 'Hello, world!'", data=["a", "b", "c"]),
source="user",
),
StructuredMessage[_InputTask2](
content=_InputTask2(task="Write a program that prints 'Hello, world!'", data="a"), source="user"
),
],
],
ids=["text", "text_message", "multi_modal_message", "structured_message"],
)
async def test_swarm_handoff_state(task: TaskType, runtime: AgentRuntime | None) -> None:
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
termination = MaxMessageTermination(6)
team1 = Swarm(
[second_agent, first_agent, third_agent],
termination_condition=termination,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team1.run(task=task)
state = await team1.save_state()
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team2 = Swarm(
[second_agent2, first_agent2, third_agent2],
termination_condition=termination,
runtime=runtime,
custom_message_types=[StructuredMessage[_InputTask1], StructuredMessage[_InputTask2]],
)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
SwarmGroupChatManager, # pyright: ignore
)
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
SwarmGroupChatManager, # pyright: ignore
)
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
@pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(
@@ -870,9 +1064,14 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
result = await team.run(task="task")
assert len(result.messages) == 7
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert isinstance(result.messages[3], HandoffMessage)
assert isinstance(result.messages[4], HandoffMessage)
assert isinstance(result.messages[5], TextMessage)
assert isinstance(result.messages[6], TextMessage)
assert result.messages[3].content == "handoff to agent2"
assert result.messages[4].content == "Transferred to agent1."
assert result.messages[5].content == "Hello"
@@ -910,18 +1109,23 @@ async def test_swarm_pause_and_resume(runtime: AgentRuntime | None) -> None:
team = Swarm([second_agent, first_agent, third_agent], max_turns=1, runtime=runtime)
result = await team.run(task="task")
assert len(result.messages) == 2
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
# Resume with a new task.
result = await team.run(task="new task")
assert len(result.messages) == 2
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert result.messages[0].content == "new task"
assert result.messages[1].content == "Transferred to first_agent."
# Resume with the same task.
result = await team.run()
assert len(result.messages) == 1
assert isinstance(result.messages[0], HandoffMessage)
assert result.messages[0].content == "Transferred to second_agent."
@@ -996,8 +1200,10 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
source="agent1",
context=expected_handoff_context,
)
assert isinstance(result.messages[4], TextMessage)
assert result.messages[4].content == "Hello"
assert result.messages[4].source == "agent2"
assert isinstance(result.messages[5], TextMessage)
assert result.messages[5].content == "TERMINATE"
assert result.messages[5].source == "agent2"
@@ -1020,17 +1226,26 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
# Start
result = await team.run(task="task")
assert len(result.messages) == 2
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
# Resume existing.
result = await team.run()
assert len(result.messages) == 3
assert isinstance(result.messages[0], HandoffMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert isinstance(result.messages[2], HandoffMessage)
assert result.messages[0].content == "Transferred to first_agent."
assert result.messages[1].content == "Transferred to second_agent."
assert result.messages[2].content == "Transferred to third_agent."
# Resume new task.
result = await team.run(task="new task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert isinstance(result.messages[2], HandoffMessage)
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[0].content == "new task"
assert result.messages[1].content == "Transferred to first_agent."
assert result.messages[2].content == "Transferred to second_agent."
@@ -1043,6 +1258,9 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
# Start
result = await team.run(task="task")
assert len(result.messages) == 3
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert isinstance(result.messages[2], HandoffMessage)
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
assert result.messages[2].content == "Transferred to non_existing_agent."
@@ -1055,6 +1273,10 @@ async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> N
# Resume with a HandoffMessage
result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
assert len(result.messages) == 4
assert isinstance(result.messages[0], HandoffMessage)
assert isinstance(result.messages[1], HandoffMessage)
assert isinstance(result.messages[2], HandoffMessage)
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[0].content == "Handoff to first_agent."
assert result.messages[1].content == "Transferred to second_agent."
assert result.messages[2].content == "Transferred to third_agent."
@@ -1081,6 +1303,10 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
# Verify the messages were processed in order
assert len(result.messages) == 4 # Initial messages + echo until termination
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], TextMessage)
assert isinstance(result.messages[2], TextMessage)
assert isinstance(result.messages[3], TextMessage)
assert result.messages[0].content == "Message 1" # First message
assert result.messages[1].content == "Message 2" # Second message
assert result.messages[2].content == "Message 3" # Third message

View File

@@ -4,10 +4,7 @@ from typing import List, Sequence
import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
)
from autogen_agentchat.messages import AgentEvent, ChatMessage
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ChatCompletionClient

View File

@@ -134,8 +134,8 @@ async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> No
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 5
assert result.messages[2].content == "Continue task"
assert result.messages[4].content == "print('Hello, world!')"
assert result.messages[2].to_text() == "Continue task"
assert result.messages[4].to_text() == "print('Hello, world!')"
assert result.stop_reason is not None and result.stop_reason == "Because"
# Test save and load.
@@ -214,8 +214,8 @@ async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None)
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 6
assert isinstance(result.messages[1].content, str)
assert isinstance(result.messages[1], TextMessage)
assert result.messages[1].content.startswith("\nWe are working to address the following user request:")
assert isinstance(result.messages[4].content, str)
assert isinstance(result.messages[4], TextMessage)
assert result.messages[4].content.startswith("\nWe are working to address the following user request:")
assert result.stop_reason is not None and result.stop_reason == "test"

View File

@@ -0,0 +1,93 @@
import pytest
from autogen_agentchat.messages import HandoffMessage, MessageFactory, StructuredMessage, TextMessage
from pydantic import BaseModel
class TestContent(BaseModel):
"""Test content model."""
field1: str
field2: int
def test_structured_message() -> None:
# Create a structured message with the test content
message = StructuredMessage[TestContent](
source="test_agent",
content=TestContent(field1="test", field2=42),
)
# Check that the message type is correct
assert message.type == "StructuredMessage[TestContent]" # type: ignore
# Check that the content is of the correct type
assert isinstance(message.content, TestContent)
# Check that the content fields are set correctly
assert message.content.field1 == "test"
assert message.content.field2 == 42
# Check that model_dump works correctly
dumped_message = message.model_dump()
assert dumped_message["source"] == "test_agent"
assert dumped_message["content"]["field1"] == "test"
assert dumped_message["content"]["field2"] == 42
assert dumped_message["type"] == "StructuredMessage[TestContent]"
def test_message_factory() -> None:
factory = MessageFactory()
# Text message data
text_data = {
"type": "TextMessage",
"source": "test_agent",
"content": "Hello, world!",
}
# Create a TextMessage instance
text_message = factory.create(text_data)
assert isinstance(text_message, TextMessage)
assert text_message.source == "test_agent"
assert text_message.content == "Hello, world!"
assert text_message.type == "TextMessage" # type: ignore
# Handoff message data
handoff_data = {
"type": "HandoffMessage",
"source": "test_agent",
"content": "handoff to another agent",
"target": "target_agent",
}
# Create a HandoffMessage instance
handoff_message = factory.create(handoff_data)
assert isinstance(handoff_message, HandoffMessage)
assert handoff_message.source == "test_agent"
assert handoff_message.content == "handoff to another agent"
assert handoff_message.target == "target_agent"
assert handoff_message.type == "HandoffMessage" # type: ignore
# Structured message data
structured_data = {
"type": "StructuredMessage[TestContent]",
"source": "test_agent",
"content": {
"field1": "test",
"field2": 42,
},
}
# Create a StructuredMessage instance -- this will fail because the type
# is not registered in the factory.
with pytest.raises(ValueError):
structured_message = factory.create(structured_data)
# Register the StructuredMessage type in the factory
factory.register(StructuredMessage[TestContent])
# Create a StructuredMessage instance
structured_message = factory.create(structured_data)
assert isinstance(structured_message, StructuredMessage)
assert isinstance(structured_message.content, TestContent) # type: ignore
assert structured_message.source == "test_agent"
assert structured_message.content.field1 == "test"
assert structured_message.content.field2 == 42
assert structured_message.type == "StructuredMessage[TestContent]" # type: ignore